simran40 committed
Commit e94cf6c · verified · 1 Parent(s): 7be0281

Update app.py

Files changed (1)
  1. app.py +292 -61
app.py CHANGED
@@ -1,62 +1,293 @@
  import gradio as gr
- from transformers import AutoModelForCausalLM, AutoTokenizer
- import torch
-
- # -------------------------------
- # Load a lightweight GPT-like model (CPU)
- # -------------------------------
- model_name = "microsoft/DialoGPT-medium"
-
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- model = AutoModelForCausalLM.from_pretrained(model_name)
-
- # -------------------------------
- # Chat function
- # -------------------------------
- def generate_response(history, message):
-     inputs = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")
-
-     outputs = model.generate(
-         inputs,
-         max_length=300,
-         pad_token_id=tokenizer.eos_token_id,
-         do_sample=True,
-         top_p=0.90,
-         temperature=0.75
-     )
-
-     reply = tokenizer.decode(outputs[:, inputs.shape[-1]:][0], skip_special_tokens=True)
-
-     history.append((message, reply))
-     return history
-
- # -------------------------------
- # Interface (Creative UI)
- # -------------------------------
- with gr.Blocks(
-     theme=gr.themes.Soft(
-         primary_hue="purple",
-         secondary_hue="blue",
-         neutral_hue="slate"
-     )
- ) as demo:
-
-     # Header
-     gr.Markdown("""
-     <h1 style='text-align:center; color:#6D28D9;'>🤖 GPT-Lite Chatbot</h1>
-     <p style='text-align:center; font-size:18px;'>
-     A smart, lightweight, multi-language chatbot that runs <b>100% on CPU</b>.
-     Ask anything — I'll answer like a mini GPT!
-     </p>
-     <br>
-     """)
-
-     chatbot = gr.Chatbot(height=450, label="ChatGPT-Style Assistant")
-     user_input = gr.Textbox(placeholder="Type your message here...", label="Your Message")
-     clear_btn = gr.Button("Clear Chat")
-
-     user_input.submit(generate_response, [chatbot, user_input], chatbot)
-     user_input.submit(lambda: "", None, user_input)
-     clear_btn.click(lambda: None, None, chatbot)
-
- demo.launch()
+ # app.py
+ import os
+ import json
+ import uuid
+ import ast
+ import math
+ import traceback
+ from typing import Dict, Any
+
  import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import PyPDF2
+ import nltk
+
+ # Ensure the WordNet corpus is available
+ try:
+     nltk.data.find("corpora/wordnet")
+ except Exception:
+     nltk.download("wordnet")
+ from nltk.corpus import wordnet
+
+ # ---------------------------
+ # Config
+ # ---------------------------
+ PRIMARY_MODEL = "microsoft/Phi-3-mini-4k-instruct"  # CPU-friendly instruction-tuned model
+ FALLBACK_MODEL = "distilgpt2"  # small causal-LM fallback (a seq2seq model such as BlenderBot would not load under AutoModelForCausalLM)
+ MEMORY_FILE = "memory.json"
+
+ # Ensure memory file exists
+ if not os.path.exists(MEMORY_FILE):
+     with open(MEMORY_FILE, "w", encoding="utf-8") as f:
+         json.dump({}, f)
+
+ # ---------------------------
+ # Safe model load with fallback
+ # ---------------------------
+ def safe_load(model_name):
+     try:
+         tok = AutoTokenizer.from_pretrained(model_name)
+         model = AutoModelForCausalLM.from_pretrained(model_name)
+         return tok, model, model_name
+     except Exception as e:
+         print(f"Could not load {model_name}: {e}")
+         return None, None, None
+
+ tokenizer, model, used_model = safe_load(PRIMARY_MODEL)
+ if tokenizer is None:
+     tokenizer, model, used_model = safe_load(FALLBACK_MODEL)
+ if tokenizer is None:
+     raise RuntimeError("Failed to load both primary and fallback models. Try smaller model names or a larger memory limit.")
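+ # Note: Phi-3-mini is ~3.8B parameters, so CPU generation is slow; older
+ # transformers releases may also need trust_remote_code=True to load it.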
+
+ # ---------------------------
+ # Helpers: memory
+ # ---------------------------
+ def load_memory() -> Dict[str, Any]:
+     try:
+         with open(MEMORY_FILE, "r", encoding="utf-8") as f:
+             return json.load(f)
+     except Exception:
+         return {}
+
+ def save_memory(mem: Dict[str, Any]):
+     with open(MEMORY_FILE, "w", encoding="utf-8") as f:
+         json.dump(mem, f, ensure_ascii=False, indent=2)
+
+ def get_session(state: dict) -> str:
+     sid = state.get("session_id")
+     if not sid:
+         sid = str(uuid.uuid4())
+         state["session_id"] = sid
+     mem = load_memory()
+     if sid not in mem:
+         mem[sid] = {"prefs": {}, "docs": []}
+         save_memory(mem)
+     return sid
+
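+ # memory.json shape produced by the helpers above (illustrative):
+ #   { "<session-uuid>": { "prefs": {"name": "..."}, "docs": ["<extracted text>", ...] } }
+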
+ # ---------------------------
+ # PDF reading
+ # ---------------------------
+ def extract_text_from_pdf(path: str) -> str:
+     try:
+         text = []
+         with open(path, "rb") as f:
+             reader = PyPDF2.PdfReader(f)
+             for page in reader.pages:
+                 page_text = page.extract_text() or ""
+                 text.append(page_text)
+         return "\n".join(text)
+     except Exception as e:
+         print("PDF read error:", e)
+         return ""
+
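+ # Note: PyPDF2 extracts only embedded text; scanned/image-only PDFs come back empty.
+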
+ # ---------------------------
+ # Tools
+ # ---------------------------
+ ALLOWED_MATH = {k: getattr(math, k) for k in dir(math) if not k.startswith("__")}
+ ALLOWED_MATH.update({"abs": abs, "round": round})
+
+ def safe_eval(expr: str):
+     try:
+         node = ast.parse(expr, mode="eval")
+         for n in ast.walk(node):
+             if isinstance(n, (ast.Attribute, ast.Lambda, ast.Import, ast.ImportFrom)):
+                 raise ValueError("Expression not allowed.")
+         code = compile(node, "<string>", "eval")
+         # names resolve only against the math whitelist; builtins are emptied
+         return eval(code, {"__builtins__": {}}, ALLOWED_MATH)
+     except Exception as e:
+         return f"Error: {e}"
+
+ def define_word(word: str) -> str:
+     synsets = wordnet.synsets(word)
+     if not synsets:
+         return f"No definition found for '{word}'."
+     out = []
+     for s in synsets[:3]:
+         out.append(f"- ({s.lexname()}) {s.definition()}")
+     return "\n".join(out)
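+
+ # Illustrative behaviour of the sandbox above:
+ #   safe_eval("sqrt(16) + 2")      -> 6.0
+ #   safe_eval("__import__('os')")  -> "Error: name '__import__' is not defined"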
+
+ # ---------------------------
+ # Prompt building & generation
+ # ---------------------------
+ def build_context_prompt(session_id: str, user_message: str) -> str:
+     mem = load_memory()
+     entry = mem.get(session_id, {})
+     prefs = entry.get("prefs", {})
+     docs = entry.get("docs", [])
+     parts = []
+     if prefs:
+         pref_text = "; ".join(f"{k}: {v}" for k, v in prefs.items() if v)
+         if pref_text:
+             parts.append(f"User preferences: {pref_text}")
+     if docs:
+         # include only the two most recent docs, capped to keep the prompt small
+         doc_text = "\n\n".join(docs[-2:])
+         parts.append("User documents (context):\n" + doc_text[:3000])
+     parts.append("You are a helpful assistant. Answer concisely and clearly. If the user asks to 'summarize', 'translate', 'define' or 'calculate', perform that action.")
+     parts.append(f"User question: {user_message}")
+     return "\n\n".join(parts)
+
+ def generate_response(prompt: str, max_new_tokens: int = 256, temperature: float = 0.7) -> str:
+     try:
+         inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=max_new_tokens,
+             do_sample=True,
+             temperature=temperature,
+             top_p=0.95,
+             pad_token_id=tokenizer.eos_token_id
+         )
+         txt = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         # strip prompt if echoed
+         if prompt in txt:
+             txt = txt.split(prompt, 1)[-1].strip()
+         return txt.strip()
+     except Exception as e:
+         print("Generation error:", e)
+         traceback.print_exc()
+         return "Sorry — generation failed."
+
+ # ---------------------------
+ # Gradio functions
+ # ---------------------------
+ def handle_submit(chat_history, message, state):
+     if not message:
+         return chat_history
+     sid = get_session(state)
+     lower = message.strip().lower()
+
+     # tool shortcuts
+     if lower.startswith("calc:") or lower.startswith("calculate "):
+         expr = message.split(":", 1)[-1] if ":" in message else message.split(None, 1)[1]
+         res = safe_eval(expr.strip())
+         bot = f"Result: {res}"
+         chat_history.append((message, bot))
+         return chat_history
+
+     if lower.startswith("define ") or lower.startswith("define:"):
+         word = message.split(":", 1)[-1] if ":" in message else message.split(None, 1)[1]
+         bot = define_word(word.strip())
+         chat_history.append((message, bot))
+         return chat_history
+
+     if lower.startswith("summarize:") or "summarize my docs" in lower:
+         if "summarize my docs" in lower:
+             mem = load_memory()
+             docs = mem.get(sid, {}).get("docs", [])
+             if not docs:
+                 bot = "No uploaded documents to summarize."
+                 chat_history.append((message, bot))
+                 return chat_history
+             text = "\n\n".join(docs)
+         else:
+             text = message.split(":", 1)[-1]
+         # ask the main model to summarize (no extra summarization model needed)
+         prompt = f"Summarize the following text concisely:\n\n{text[:3000]}"
+         summary = generate_response(prompt, max_new_tokens=200, temperature=0.3)
+         bot = "Summary:\n" + summary
+         chat_history.append((message, bot))
+         return chat_history
+
+     if lower.startswith("translate"):
+         # use the model to translate; simple parse: "translate to <lang>: text"
+         parts = message.split(":", 1)
+         if len(parts) == 2 and "to " in parts[0].lower():
+             tgt = parts[0].lower().split("to", 1)[-1].strip()
+             text = parts[1].strip()
+             prompt = f"Translate the following text to {tgt}:\n\n{text}"
+         else:
+             # fallback: translate the whole message to English
+             text = message.split(":", 1)[-1] if ":" in message else message
+             prompt = f"Translate the following text to English:\n\n{text}"
+         translated = generate_response(prompt, max_new_tokens=200, temperature=0.3)
+         bot = "Translation:\n" + translated
+         chat_history.append((message, bot))
+         return chat_history
+
+     # standard conversational flow
+     system_prompt = build_context_prompt(sid, message)
+     reply = generate_response(system_prompt, max_new_tokens=300, temperature=0.7)
+
+     # light memory heuristics: save "my name is X" or "i prefer X"
+     try:
+         low = message.lower()
+         mem = load_memory()
+         if "my name is " in low:
+             # slice on the lowercased copy so "My name is Ada" also matches
+             name = message[low.index("my name is ") + len("my name is "):].strip().split()[0]
+             mem[sid]["prefs"]["name"] = name
+             save_memory(mem)
+         if any(k in low for k in ["i prefer", "i like", "i'm a", "i am a"]):
+             pref_key = f"pref_{len(mem[sid].get('prefs', {})) + 1}"
+             mem[sid]["prefs"][pref_key] = message
+             save_memory(mem)
+     except Exception as e:
+         print("Memory write failed:", e)
+
+     chat_history.append((message, reply))
+     return chat_history
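+
+ # Routing summary: "calc:"/"calculate" -> safe_eval, "define" -> WordNet,
+ # "summarize:"/"summarize my docs" -> model summary, "translate ..." -> model
+ # translation, everything else -> memory-aware chat via build_context_prompt().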
+
+ def upload_pdf(file, state):
+     if not file:
+         return "No file uploaded."
+     sid = get_session(state)
+     # Gradio hands over a tempfile wrapper with .name, or a plain path string
+     path = file.name if hasattr(file, "name") else file
+     text = extract_text_from_pdf(path)
+     mem = load_memory()
+     mem[sid]["docs"].append(text[:20000])  # cap stored text per document
+     save_memory(mem)
+     return "PDF uploaded and indexed into session memory."
+
+ def show_memory(state):
+     sid = get_session(state)
+     mem = load_memory()
+     return json.dumps(mem.get(sid, {}), ensure_ascii=False, indent=2)
+
+ def reset_memory(state):
+     sid = get_session(state)
+     mem = load_memory()
+     mem[sid] = {"prefs": {}, "docs": []}
+     save_memory(mem)
+     return "Session memory reset."
+
+ # ---------------------------
+ # UI (creative but lightweight)
+ # ---------------------------
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="blue")) as demo:
+     gr.Markdown(f"# 🤖 GPT-Lite Assistant — {used_model}\nLightweight CPU-ready assistant with memory, PDF reading & tools.")
+     with gr.Row():
+         with gr.Column(scale=3):
+             chatbot = gr.Chatbot(label="Assistant", height=520)
+             with gr.Row():
+                 txt = gr.Textbox(show_label=False, placeholder="Ask anything (or use commands: calc:, define:, summarize:, translate: )")
+                 send = gr.Button("Send")
+             with gr.Row():
+                 pdf_file = gr.File(label="Upload PDF (optional)", file_types=[".pdf"])
+                 upload_btn = gr.Button("Upload PDF")
+             with gr.Row():
+                 show_mem_btn = gr.Button("Show session memory")
+                 reset_mem_btn = gr.Button("Reset memory")
+             status = gr.Textbox(label="Status", interactive=False)
+         with gr.Column(scale=1):
+             gr.Markdown("### Quick examples\n- Explain photosynthesis\n- calc: 12/3 + 4\n- define: gravity\n- translate to es: How are you?\n- summarize my docs")
+             gr.Markdown("### Notes\n- Model runs on CPU. If the Space hits memory limits, switch PRIMARY_MODEL to a smaller model.")
+     state = gr.State({})
+
+     send.click(handle_submit, [chatbot, txt, state], chatbot)
+     txt.submit(handle_submit, [chatbot, txt, state], chatbot)
+     # clear the input box after sending, matching the previous version's behaviour
+     send.click(lambda: "", None, txt)
+     txt.submit(lambda: "", None, txt)
+     # reuse one status box instead of creating a new Textbox per event
+     upload_btn.click(upload_pdf, [pdf_file, state], status)
+     show_mem_btn.click(show_memory, [state], status)
+     reset_mem_btn.click(reset_memory, [state], status)
+
+ demo.launch(server_name="0.0.0.0", server_port=7860)
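
Note: the updated app pulls in PyPDF2 and nltk on top of the gradio, transformers, and torch stack already needed for generation. A minimal requirements.txt sketch for the Space (unpinned; an assumption, not part of this commit):

  gradio
  transformers
  torch
  PyPDF2
  nltk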