nitya001 committed on
Commit
29684f5
·
verified ·
1 Parent(s): d3326e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -67
app.py CHANGED
@@ -1,102 +1,108 @@
1
  import torch
2
  import gradio as gr
3
- from transformers import AutoTokenizer, AutoModelForCausalLM
4
  from peft import PeftModel
5
 
6
- # ---------------- CONFIG ---------------- #
 
 
7
 
8
- BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # Base model
9
- LORA_REPO = "nitya001/autotrain-4n1y9-5ekvs" # Your AutoTrain LoRA repo
10
 
11
- SYSTEM_PROMPT = (
12
- "You are a helpful banking and loan support assistant. "
13
- "You answer short, clear, and factual responses about UTRs, EMIs, loan summaries, "
14
- "payment issues, and basic loan help. If unsure, respond generically."
15
- )
16
-
17
- device = "cpu"
18
-
19
-
20
- # ---------------- LOAD TOKENIZER ---------------- #
21
-
22
- print("Loading tokenizer...")
23
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
24
 
25
- if tokenizer.pad_token is None:
26
- tokenizer.pad_token = tokenizer.eos_token
 
27
 
28
-
29
- # ---------------- LOAD BASE MODEL ---------------- #
30
-
31
- print("Loading base model...")
32
  base_model = AutoModelForCausalLM.from_pretrained(
33
  BASE_MODEL,
34
- torch_dtype=torch.float32,
35
- device_map=device,
36
- )
37
-
38
- # ---------------- LOAD LORA ADAPTER ---------------- #
39
-
40
- print(f"Loading LoRA adapter from {LORA_REPO} ...")
41
- model = PeftModel.from_pretrained(
42
- base_model,
43
- LORA_REPO,
44
  )
 
45
 
 
 
 
46
  model.eval()
47
 
48
-
49
- # ---------------- CHAT FUNCTION ---------------- #
50
-
51
- def chat_fn(message, history):
52
  """
53
- Gradio ChatInterface callback.
54
- history: list of [user, bot]
55
  """
56
-
57
- # Build conversation text
58
- conversation = f"System: {SYSTEM_PROMPT}\n"
59
- for user_msg, bot_msg in history:
60
- conversation += f"User: {user_msg}\nAssistant: {bot_msg}\n"
61
- conversation += f"User: {message}\nAssistant:"
62
-
63
- inputs = tokenizer(conversation, return_tensors="pt").to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  with torch.no_grad():
66
  outputs = model.generate(
67
  **inputs,
68
- max_new_tokens=150,
69
  do_sample=True,
70
  top_p=0.9,
71
  temperature=0.7,
72
  pad_token_id=tokenizer.eos_token_id,
73
  )
74
 
75
- full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
76
 
77
- # Extract only the latest answer after the last "Assistant:"
78
- if "Assistant:" in full_output:
79
- reply = full_output.split("Assistant:")[-1].strip()
80
- else:
81
- reply = full_output.strip()
82
 
83
- history.append((message, reply))
84
- return history, history
85
 
 
 
 
86
 
87
- # ---------------- GRADIO UI ---------------- #
 
 
 
 
88
 
89
- demo = gr.ChatInterface(
90
- fn=chat_fn,
91
- title="💬 TinyLoan Assistant (TinyLlama + AutoTrain LoRA)",
92
- description="Ask about UTR, loan summaries, EMIs, transactions, or payment issues.",
93
- examples=[
94
- "What is my latest UTR?",
95
- "Generate my loan summary.",
96
- "Show my transactions.",
97
- "My payment is stuck, what should I do?",
98
- ],
99
- )
100
 
101
  if __name__ == "__main__":
102
  demo.launch()
 
1
  import torch
2
  import gradio as gr
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from peft import PeftModel
5
 
6
# ----- CONFIG -----
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
LORA_REPO = "nitya001/autotrain-4n1y9-5ekvs"

# Prefer the GPU when one is present; otherwise run on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# ----- LOAD TOKENIZER -----
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# generate() needs a pad token id; reuse EOS when the tokenizer defines none.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# ----- LOAD BASE MODEL -----
# Half precision keeps GPU memory down; CPU inference requires float32.
dtype = torch.float16 if device == "cuda" else torch.float32
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=dtype,
)
base_model.to(device)

# ----- LOAD LORA ADAPTER -----
# Wrap the base model with the fine-tuned LoRA weights from the Hub repo.
model = PeftModel.from_pretrained(base_model, LORA_REPO)
model.to(device)
model.eval()  # inference only: disables dropout etc.
30
 
31
# ----- HELPER: BUILD PROMPT FROM HISTORY -----
def build_prompt(history, user_message: str) -> str:
    """Render the conversation as a TinyLlama-style chat prompt.

    history: list of (user, assistant) pairs; None is treated as empty.
    user_message: the latest user text, appended as the final turn.
    Returns a single prompt string ending with an open "<|assistant|>" tag
    so the model continues with its reply.
    """
    turns = []
    # Replay prior exchanges using the <|user|> / <|assistant|> chat tokens,
    # dropping pairs where both sides are empty.
    for past_user, past_bot in (history or []):
        if past_user or past_bot:
            turns.append(f"<|user|>\n{past_user}\n<|assistant|>\n{past_bot}\n")
    # The new message ends with an open assistant tag for generation.
    turns.append(f"<|user|>\n{user_message}\n<|assistant|>\n")
    return "".join(turns)
50
+
51
# ----- CHAT FUNCTION (THIS IS WHAT GRADIO CALLS) -----
def chat_fn(user_message, history):
    """Generate a reply for `user_message` and return the updated history.

    user_message: latest text typed by the user.
    history: list of (user, assistant) pairs from the Chatbot; may be None
             on the first turn.
    Returns the history list with the new (user_message, answer) pair
    appended — the single value that refreshes the Chatbot component.
    """
    if history is None:
        history = []

    prompt = build_prompt(history, user_message)

    # Tokenize on the model's device; truncate long conversations so the
    # input stays within the 2048-token context window.
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=2048,
    ).to(device)

    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=256,
            do_sample=True,
            top_p=0.9,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
        )

    # generate() echoes the prompt tokens first; slice them off and decode
    # only what the model produced after the prompt.
    prompt_len = encoded["input_ids"].shape[-1]
    answer = tokenizer.decode(
        generated[0][prompt_len:], skip_special_tokens=True
    ).strip()

    history.append((user_message, answer))

    # 🔴 IMPORTANT: return ONLY `history`, NOT `(history, history)` 🔴
    return history
84
 
85
# ----- GRADIO UI -----
with gr.Blocks() as demo:
    gr.Markdown("## TinyLlama + LoRA – Custom Chatbot")

    chatbot = gr.Chatbot(
        label="Chat",
        # BUG FIX: gr.Chatbot's `type` accepts "tuples" or "messages";
        # the previous value "tuple" is rejected with a ValueError, so the
        # app crashed at startup. "tuples" matches chat_fn's history format
        # (a list of (user, assistant) pairs).
        type="tuples",
        height=500,
    )

    msg = gr.Textbox(
        label="Your message",
        placeholder="Ask something...",
    )

    clear = gr.Button("Clear")

    # On submit: send msg + current chatbot history into chat_fn
    # and update ONLY the chatbot with the returned history
    msg.submit(chat_fn, inputs=[msg, chatbot], outputs=[chatbot])
    # Clearing resets the conversation to an empty history.
    clear.click(lambda: [], None, chatbot)
106
 
# Start the web app only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()