nitya001 commited on
Commit
31e91dc
·
verified ·
1 Parent(s): d959b36

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +147 -0
app.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import gradio as gr
4
+
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM
6
+ from peft import PeftModel
7
+
8
# --------- CONFIG ---------
# Base checkpoint the LoRA adapter was trained on top of.
BASE_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# LoRA adapter repo produced by AutoTrain.
LORA_MODEL_ID = "nitya001/autotrain-fngb8-wqn4c" # <-- change if your repo name differs

# Generation knobs used by generate_reply().
MAX_NEW_TOKENS = 128
TEMPERATURE = 0.7
TOP_P = 0.9

# Prepended as the "System:" turn of every prompt (see format_chat_history).
SYSTEM_PROMPT = (
"You are a helpful banking and loan support assistant. "
"You answer short, clear, and factual responses about UTRs, EMIs, "
"loan summaries, and payment issues based ONLY on the given question. "
"If you don't know something (like actual live data), say that you "
"cannot access real-time systems and answer generically."
)

# --------- LOAD MODEL ---------
# Runs once at import time: downloads/loads the base model and attaches
# the LoRA adapter. Falls back to CPU when no GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading base model: {BASE_MODEL_ID} on {device}...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)

# TinyLlama uses eos_token as pad_token sometimes; ensure it's set
# (generate() below passes pad_token_id explicitly, so it must not be None).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# float32 + no device_map: CPU-friendly load; the model is moved to
# `device` explicitly after the adapter is attached.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.float32,
    device_map=None,
)

print(f"Loading LoRA adapter: {LORA_MODEL_ID}...")
model = PeftModel.from_pretrained(
    base_model,
    LORA_MODEL_ID,
)
model.to(device)
model.eval()  # inference only; generate_reply wraps calls in torch.no_grad()
47
+
48
# --------- CHAT LOGIC ---------
def format_chat_history(history, user_message, system_prompt=None):
    """Convert chat history + new user message into a single prompt string.

    The prompt is a plain-text transcript: a "System:" line, alternating
    "User:"/"Assistant:" turns from *history*, the new user turn, and a
    trailing "Assistant:" marker for the model to complete.

    Args:
        history: list of (user_text, assistant_text) tuples (past turns).
        user_message: the new user message to append as the last turn.
        system_prompt: optional override for the system turn; defaults to
            the module-level SYSTEM_PROMPT (kept as a lazy fallback so the
            function stays backward-compatible and unit-testable).

    Returns:
        The assembled prompt as one newline-joined string.
    """
    if system_prompt is None:
        system_prompt = SYSTEM_PROMPT
    parts = [f"System: {system_prompt}"]
    for old_user, old_bot in history:
        parts.append(f"User: {old_user}")
        parts.append(f"Assistant: {old_bot}")
    parts.append(f"User: {user_message}")
    # Trailing marker: the model's completion after this is the reply.
    parts.append("Assistant:")
    return "\n".join(parts)
61
+
62
+
63
def generate_reply(user_message, history):
    """Generate an assistant reply and append the turn to *history*.

    Args:
        user_message: raw text from the input box; blank/whitespace-only
            messages are ignored.
        history: mutable list of (user, assistant) tuples; the new turn is
            appended in place.

    Returns:
        The updated history (same list object), as Gradio expects for the
        Chatbot output.
    """
    # Ignore empty submissions instead of prompting the model with nothing.
    if not user_message.strip():
        return history

    # Build prompt from history
    prompt = format_chat_history(history, user_message)

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    ).to(device)

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # FIX: decode ONLY the newly generated tokens. The previous version
    # decoded the full sequence (prompt + completion) and split on the last
    # "Assistant:" marker, which returns the wrong text whenever the model
    # emits "Assistant:" inside its own reply.
    prompt_len = inputs["input_ids"].shape[1]
    new_text = tokenizer.decode(
        output_ids[0][prompt_len:], skip_special_tokens=True
    )

    # If the model keeps role-playing and starts a fake "User:" turn,
    # keep only the text before it.
    bot_reply = new_text.split("User:")[0].strip()

    history.append((user_message, bot_reply))
    return history
99
+
100
+
101
# --------- GRADIO UI ---------
with gr.Blocks(title="TinyLoan Assistant") as demo:
    gr.Markdown(
        """
        # 💬 TinyLoan Assistant (TinyLlama + LoRA)
        Ask about UTRs, EMIs, loan summaries, payment issues, etc.

        > **Note:** This demo does not access real bank systems.
        > It answers based on patterns learned from example data.
        """
    )

    chatbot = gr.Chatbot(
        label="Chat",
        height=400,
        # FIX: "pairs" is not a valid value for `type` (Gradio accepts
        # "tuples" or "messages") and raises at startup. The history
        # handled by generate_reply() is a list of (user, bot) tuples.
        type="tuples",
    )

    with gr.Row():
        user_input = gr.Textbox(
            show_label=False,
            placeholder="Type your question, e.g. 'What is my latest UTR?'",
            scale=4,
        )
        send_btn = gr.Button("Send", scale=1)

    clear_btn = gr.Button("Clear chat")

    def respond(message, chat_history):
        """Event handler: run the model, clear the textbox, update the chat."""
        if chat_history is None:
            chat_history = []
        # Return "" first so the input box is emptied after each send.
        return "", generate_reply(message, chat_history)

    send_btn.click(
        respond,
        inputs=[user_input, chatbot],
        outputs=[user_input, chatbot],
    )
    user_input.submit(
        respond,
        inputs=[user_input, chatbot],
        outputs=[user_input, chatbot],
    )

    clear_btn.click(lambda: [], outputs=[chatbot])

demo.launch()