File size: 2,822 Bytes
8c972d1
1d8afda
c4c69b3
8c972d1
 
c4c69b3
8c972d1
 
c4c69b3
8c972d1
 
c4c69b3
8c972d1
c4c69b3
 
 
 
 
8c972d1
 
15381fd
c4c69b3
db6f8f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4c69b3
 
 
 
 
 
 
 
 
 
 
 
 
15381fd
c4c69b3
 
 
 
 
 
 
 
 
 
 
 
 
 
15381fd
 
c4c69b3
15381fd
c4c69b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15381fd
db6f8f4
 
 
 
 
 
 
 
15381fd
db6f8f4
c4c69b3
db6f8f4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import gradio as gr
import difflib
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

# Hugging Face token (needed if model is private/gated)
HF_TOKEN = os.getenv("HF_TOKEN")

# Your fine-tuned model on the Hub
model_id = "karthiksagarn/llama3-3.2b-finetuned-financial"

# Load model + tokenizer.
# NOTE: `use_auth_token` is deprecated in transformers; the supported
# keyword is `token` (passing None is fine for public models).
tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=HF_TOKEN,
    device_map="auto",  # let accelerate place layers across available devices
)

# The model object is already placed by `device_map="auto"` above, so do NOT
# pass `device_map` to the pipeline again — newer transformers warns/errors
# when a pre-loaded model is combined with a second device_map.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Closed set of categories the classifier is allowed to return.
labels = [
    "Food & Drinks",
    "Travel & Transport",
    "Shopping",
    "Bills & Utilities",
    "Entertainment",
    "Health & Fitness",
    "Groceries",
    "Education",
    "Income",
    "Investments",
    "Withdrawals",
    "Miscellaneous"
]

# (description, category) pairs used as few-shot demonstrations in the prompt
# (adjust to your dataset).
FEW_SHOT_EXAMPLES = [
    ("Sent Rs.510.00 From HDFC Bank A/C *0552 To Swiggy Limited", "Food & Drinks"),
    ("Salary credited from ACME Technologies", "Income"),
    ("Paid electricity bill online", "Bills & Utilities"),
    ("Bought vegetables at the local market", "Groceries"),
    ("ATM cash withdrawal 3000", "Withdrawals"),
]

# Render each example as a two-line "Description/Category" snippet and join
# the snippets with blank lines for the prompt.
_example_snippets = []
for _desc, _cat in FEW_SHOT_EXAMPLES:
    _example_snippets.append("Description: " + _desc + "\nCategory: " + _cat)
examples_text = "\n\n".join(_example_snippets)

# Classifier function
def classify_transaction(description: str) -> str:
    """Classify a bank-transaction description into one of the canonical labels.

    Builds a few-shot prompt, runs greedy decoding on the text-generation
    pipeline, then maps the raw generation back onto a canonical label via
    exact, substring, and fuzzy matching.

    Args:
        description: Free-text transaction description (e.g. an SMS line).

    Returns:
        One of the strings in ``labels``; "Miscellaneous" when the input is
        empty or nothing matches.
    """
    # Empty/whitespace input can't be classified — skip the model call.
    if not description or not description.strip():
        return "Miscellaneous"

    prompt = (
        "You are a classifier. Return EXACTLY one of the labels below and NOTHING ELSE.\n\n"
        f"Labels: {', '.join(labels)}\n\n"
        "Examples:\n"
        f"{examples_text}\n\n"
        "Now classify the following:\n"
        f"Description: {description}\n"
        "Category:"
    )

    # Greedy decoding: do_sample=False alone selects it. Do NOT also pass
    # temperature=0.0 — newer transformers rejects sampling parameters when
    # sampling is disabled.
    out = pipe(
        prompt,
        max_new_tokens=8,
        do_sample=False,
    )

    # generated_text echoes the full prompt; everything after the final
    # "Category:" (the one we appended) is the model's answer. Keep only the
    # first line and drop stray trailing punctuation before matching.
    generated = out[0]["generated_text"].split("Category:")[-1].strip()
    generated = generated.split("\n")[0].strip().strip(".,;:!?\"'")

    # 1. Exact match against the canonical labels.
    if generated in labels:
        return generated

    # 2. Case-insensitive substring match in either direction.
    for lbl in labels:
        if lbl.lower() in generated.lower() or generated.lower() in lbl.lower():
            return lbl

    # 3. Fuzzy match to tolerate minor spelling/formatting drift.
    match = difflib.get_close_matches(generated, labels, n=1, cutoff=0.55)
    if match:
        return match[0]

    # 4. Fallback bucket when nothing matched.
    return "Miscellaneous"

# --- Gradio front end ---
# Single text box in, predicted label out; wired directly to the classifier.
_description_input = gr.Textbox(
    label="Transaction Description",
    placeholder="Enter bank transaction...",
)
_category_output = gr.Label(label="Predicted Category")

demo = gr.Interface(
    fn=classify_transaction,
    inputs=_description_input,
    outputs=_category_output,
    title="🏦 Bank Transaction Classifier",
    description="Classify bank transactions into categories using a fine-tuned LLaMA model.",
)

if __name__ == "__main__":
    # For HF Spaces, keep share=False (the default for launch()).
    demo.launch()