# app.py — author: karthiksagarn; last commit c4c69b3 ("Update app.py", verified)
import os
import gradio as gr
import difflib
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
# Hugging Face access token, needed only if the model repo is private/gated.
# os.getenv returns None when the variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# Fine-tuned model repository on the Hugging Face Hub.
model_id = "karthiksagarn/llama3-3.2b-finetuned-financial"

# Load tokenizer and model. `token=` is the supported spelling of the
# deprecated `use_auth_token=` argument.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=HF_TOKEN,
    device_map="auto",  # device placement happens here, at load time
)

# The model object is already loaded and placed above; a `device_map` passed
# to `pipeline()` only applies when the pipeline loads the model itself from
# an id, so it is not repeated here.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
# The closed set of categories the classifier may return. Model output is
# normalized onto one of these; "Miscellaneous" is the catch-all.
labels = [
    "Food & Drinks",
    "Travel & Transport",
    "Shopping",
    "Bills & Utilities",
    "Entertainment",
    "Health & Fitness",
    "Groceries",
    "Education",
    "Income",
    "Investments",
    "Withdrawals",
    "Miscellaneous",
]

# (description, category) pairs shown to the model as in-context examples.
# Tune these to match the dataset the model was fine-tuned on.
FEW_SHOT_EXAMPLES = [
    ("Sent Rs.510.00 From HDFC Bank A/C *0552 To Swiggy Limited", "Food & Drinks"),
    ("Salary credited from ACME Technologies", "Income"),
    ("Paid electricity bill online", "Bills & Utilities"),
    ("Bought vegetables at the local market", "Groceries"),
    ("ATM cash withdrawal 3000", "Withdrawals"),
]

# Render each example as a two-line "Description/Category" record and
# separate records with a blank line.
examples_text = "\n\n".join(
    f"Description: {text}\nCategory: {category}"
    for text, category in FEW_SHOT_EXAMPLES
)
# Classifier function
def classify_transaction(description: str) -> str:
prompt = (
"You are a classifier. Return EXACTLY one of the labels below and NOTHING ELSE.\n\n"
f"Labels: {', '.join(labels)}\n\n"
"Examples:\n"
f"{examples_text}\n\n"
"Now classify the following:\n"
f"Description: {description}\n"
"Category:"
)
out = pipe(
prompt,
max_new_tokens=8,
do_sample=False,
temperature=0.0,
)
generated = out[0]["generated_text"].split("Category:")[-1].strip()
generated = generated.split("\n")[0].strip()
# 1. Exact match
if generated in labels:
return generated
# 2. Substring match
for lbl in labels:
if lbl.lower() in generated.lower() or generated.lower() in lbl.lower():
return lbl
# 3. Fuzzy match
match = difflib.get_close_matches(generated, labels, n=1, cutoff=0.55)
if match:
return match[0]
# 4. Fallback
return "Miscellaneous"
# Gradio front end: one free-text box in, the predicted category out.
demo = gr.Interface(
    classify_transaction,
    inputs=gr.Textbox(
        label="Transaction Description",
        placeholder="Enter bank transaction...",
    ),
    outputs=gr.Label(label="Predicted Category"),
    title="🏦 Bank Transaction Classifier",
    description="Classify bank transactions into categories using a fine-tuned LLaMA model.",
)

if __name__ == "__main__":
    # Launched with no arguments so `share` stays at its default; keep it
    # that way for HF Spaces.
    demo.launch()