"""Gradio app that classifies bank transactions with a fine-tuned LLaMA model.

The model is prompted with a fixed label list plus a few worked examples and
asked to complete the ``Category:`` line; the raw completion is then mapped
back onto one of the canonical labels.
"""

import difflib
import os

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Hugging Face token (needed if model is private/gated)
HF_TOKEN = os.getenv("HF_TOKEN")

# Your fine-tuned model on the Hub
model_id = "karthiksagarn/llama3-3.2b-finetuned-financial"

# Load model + tokenizer.
# `token=` is the current keyword; `use_auth_token=` is deprecated.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=HF_TOKEN,
    device_map="auto",
)
# The model object is already placed on its device(s) by `device_map` above,
# so no device argument is passed to the pipeline.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Canonical categories
labels = [
    "Food & Drinks",
    "Travel & Transport",
    "Shopping",
    "Bills & Utilities",
    "Entertainment",
    "Health & Fitness",
    "Groceries",
    "Education",
    "Income",
    "Investments",
    "Withdrawals",
    "Miscellaneous",
]

# Label returned whenever the model output cannot be mapped to a category.
FALLBACK_LABEL = "Miscellaneous"

# Few-shot examples (adjust to your dataset)
FEW_SHOT_EXAMPLES = [
    ("Sent Rs.510.00 From HDFC Bank A/C *0552 To Swiggy Limited", "Food & Drinks"),
    ("Salary credited from ACME Technologies", "Income"),
    ("Paid electricity bill online", "Bills & Utilities"),
    ("Bought vegetables at the local market", "Groceries"),
    ("ATM cash withdrawal 3000", "Withdrawals"),
]

examples_text = "\n\n".join(
    f"Description: {d}\nCategory: {c}" for d, c in FEW_SHOT_EXAMPLES
)


def _match_label(generated: str) -> str:
    """Map raw model output onto one of the canonical ``labels``.

    Tries, in order: case-insensitive exact match, substring containment in
    either direction, fuzzy match via ``difflib``; otherwise returns
    ``FALLBACK_LABEL``.
    """
    generated = generated.strip()
    # An empty string is a substring of every label, so without this guard a
    # blank completion would silently resolve to the first label in the list.
    if not generated:
        return FALLBACK_LABEL

    lowered = generated.lower()

    # 1. Exact match (case-insensitive)
    for lbl in labels:
        if lbl.lower() == lowered:
            return lbl

    # 2. Substring match: label embedded in extra text, or output truncated
    #    to a prefix/fragment of a label.
    for lbl in labels:
        lbl_lower = lbl.lower()
        if lbl_lower in lowered or lowered in lbl_lower:
            return lbl

    # 3. Fuzzy match
    match = difflib.get_close_matches(generated, labels, n=1, cutoff=0.55)
    if match:
        return match[0]

    # 4. Fallback
    return FALLBACK_LABEL


# Classifier function
def classify_transaction(description: str) -> str:
    """Classify a bank transaction description into one of ``labels``.

    Args:
        description: Free-text transaction description entered by the user.

    Returns:
        One entry of ``labels``; ``"Miscellaneous"`` when the input is blank
        or the model output cannot be matched to any label.
    """
    # Collapse all whitespace (including newlines) so user input cannot break
    # the "Description: ... / Category:" line structure of the prompt.
    description = " ".join((description or "").split())
    if not description:
        return FALLBACK_LABEL

    prompt = (
        "You are a classifier. Return EXACTLY one of the labels below and NOTHING ELSE.\n\n"
        f"Labels: {', '.join(labels)}\n\n"
        "Examples:\n"
        f"{examples_text}\n\n"
        "Now classify the following:\n"
        f"Description: {description}\n"
        "Category:"
    )

    # Greedy decoding (do_sample=False); no temperature is passed because it
    # only applies when sampling. return_full_text=False yields only the newly
    # generated text, so the answer is not lost if the model goes on to emit
    # another "Category:" line of its own.
    out = pipe(
        prompt,
        max_new_tokens=8,
        do_sample=False,
        return_full_text=False,
    )

    completion = out[0]["generated_text"].strip()
    # Only the first line is the answer; anything after it is continuation.
    first_line = completion.split("\n")[0]
    return _match_label(first_line)


# Gradio UI
demo = gr.Interface(
    fn=classify_transaction,
    inputs=gr.Textbox(
        label="Transaction Description",
        placeholder="Enter bank transaction...",
    ),
    outputs=gr.Label(label="Predicted Category"),
    title="🏦 Bank Transaction Classifier",
    description="Classify bank transactions into categories using a fine-tuned LLaMA model.",
)

if __name__ == "__main__":
    # For HF Spaces, keep share=False
    demo.launch()