File size: 3,406 Bytes
abf6e04
3b329f8
0b139fc
 
 
abf6e04
3b329f8
abf6e04
1ff47dc
3b329f8
 
 
 
323e55b
3b329f8
 
 
1ff47dc
abf6e04
1ff47dc
 
 
3b329f8
 
d287b8a
3b329f8
 
 
 
0b139fc
334f958
3b329f8
1ff47dc
3b329f8
334f958
7a94bba
1ff47dc
 
334f958
1ff47dc
334f958
3b329f8
 
 
 
 
334f958
3b329f8
334f958
 
cbb6831
1ff47dc
 
 
cbb6831
334f958
1ff47dc
 
 
 
24f83a8
1ff47dc
 
 
 
 
 
 
24f83a8
cc215f9
1ff47dc
 
 
 
323e55b
1ff47dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ae273b
abf6e04
1ff47dc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import json
import re

MODEL_ID = "rawsun00001/banking-sms-json-parser-v6-merged"

print("🔄 Loading banking‑SMS JSON parser model...")
# Decide device/dtype once up front: fp16 on GPU, fp32 on CPU.
_use_gpu = torch.cuda.is_available()
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16 if _use_gpu else torch.float32,
    device_map="auto" if _use_gpu else None,
)
# Some causal-LM tokenizers ship without a pad token; reuse EOS so that
# generation with padding works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print("✅ Model loaded successfully!")

def parse_banking_sms(raw_text: str) -> dict:
    """Parse a raw banking SMS/email into a structured transaction dict.

    The model is prompted with the whitespace-normalized SMS followed by a
    ``|`` separator and is expected to generate a flat JSON object after it.

    Args:
        raw_text: The SMS/email body as pasted by the user.

    Returns:
        A dict that always contains the keys ``date``, ``type``, ``amount``,
        ``category``, ``last4`` (each possibly ``None``) and a strict
        boolean ``is_transaction``.
    """
    # Collapse whitespace runs so the prompt matches the training format.
    sms_text = " ".join(raw_text.strip().split())
    prompt = sms_text + "|"
    inputs = tokenizer(prompt, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = {k: v.cuda() for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=120,
            do_sample=False,  # deterministic greedy decoding
            repetition_penalty=1.05,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # BUGFIX: the old code sliced with decoded[len(prompt):], which assumes
    # the tokenize/decode round-trip reproduces the prompt byte-for-byte —
    # tokenizers do not guarantee that.  Slice only when the prefix really
    # matches; otherwise fall back to scanning the full decoded output.
    if decoded.startswith(prompt):
        m = re.search(r"\{[^{}]+\}", decoded[len(prompt):])
        blob = m.group() if m else None
    else:
        # The generated JSON follows the echoed SMS text, so any braces
        # inside the SMS itself come earlier — take the LAST flat object.
        candidates = re.findall(r"\{[^{}]+\}", decoded)
        blob = candidates[-1] if candidates else None
    if blob is not None:
        try:
            parsed = json.loads(blob)
            return {
                "date": parsed.get("date"),
                "type": parsed.get("type"),
                "amount": parsed.get("amount"),
                "category": parsed.get("category"),
                "last4": parsed.get("last4"),
                # Coerce to a strict bool so callers can rely on the type
                # even when the model emits e.g. a string or 0/1.
                "is_transaction": bool(parsed.get("is_transaction", False)),
            }
        except json.JSONDecodeError:
            pass  # malformed JSON → fall through to the default below
    # No parseable JSON in the generation: report a non-transaction.
    return {
        "date": None, "type": None,
        "amount": None, "category": None,
        "last4": None, "is_transaction": False,
    }

def predict(raw_text: str) -> str:
    """Format the parsed SMS as a Markdown summary for the Gradio UI.

    Args:
        raw_text: The SMS/email body as pasted by the user.

    Returns:
        Markdown text: either a transaction summary plus the full parsed
        JSON, or a non-transaction notice with the classification JSON.
    """
    parsed = parse_banking_sms(raw_text)
    pretty_json = json.dumps(parsed, indent=2)
    if parsed["is_transaction"]:
        # BUGFIX: parse_banking_sms always emits every key (possibly None),
        # so `.get(key, default)` never falls back to the default.  The old
        # `parsed.get('type', '').title()` crashed with AttributeError on a
        # None value, and 'N/A' never showed for a missing date.  Coalesce
        # None with `or` before formatting.
        txn_type = (parsed.get("type") or "").title()
        summary = (
            f"✅ Transaction Detected!\n\n"
            f"- 📅 Date: **{parsed.get('date') or 'N/A'}**\n"
            f"- 💳 Type: **{txn_type}**\n"
            f"- 💰 Amount: **{parsed.get('amount')}**\n"
            f"- 🏪 Category: **{parsed.get('category')}**\n"
            f"- 🔢 Last 4 Digits: **{parsed.get('last4')}**\n\n"
            "**Full Parsed JSON:**\n```json\n"
            f"{pretty_json}\n```"
        )
    else:
        summary = (
            "ℹ️ Non‑transactional SMS / Promotional / Info message.\n\n"
            "**Parsed Classification JSON:**\n```json\n"
            f"{pretty_json}\n```"
        )
    return summary

# Input widget: free-text box for the raw SMS/email body.
sms_input = gr.Textbox(
    lines=3,
    placeholder="Paste your banking SMS or email here…",
    label="Input SMS / Email",
)

iface = gr.Interface(
    fn=predict,
    inputs=sms_input,
    outputs=gr.Markdown(label="Parsed Output"),
    title="🏦 Banking SMS JSON Parser",
    description=(
        "Paste any banking SMS (or email) below — the app will detect transaction "
        "vs non-transaction, and display structured JSON output."
    ),
    allow_flagging="never",
    analytics_enabled=False,  # disable Gradio telemetry in Spaces
)

if __name__ == "__main__":
    # Bind to all interfaces on the standard HF Spaces port.
    iface.launch(server_name="0.0.0.0", server_port=7860)