# app.py
import os
import json
import uuid
import ast
import math
import traceback
from typing import Dict, Any
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import PyPDF2
import nltk
# Ensure the WordNet corpus is available (downloaded on first run)
try:
nltk.data.find("corpora/wordnet")
except Exception:
nltk.download("wordnet")
from nltk.corpus import wordnet
# ---------------------------
# Config
# ---------------------------
PRIMARY_MODEL = "microsoft/Phi-3-mini-4k-instruct" # 3.8B-parameter instruction-tuned model; runs on CPU, though slowly
FALLBACK_MODEL = "facebook/blenderbot-400M-distill" # small fallback (AutoModelForCausalLM loads only its decoder)
MEMORY_FILE = "memory.json"
# Ensure memory file
if not os.path.exists(MEMORY_FILE):
with open(MEMORY_FILE, "w", encoding="utf-8") as f:
json.dump({}, f)
# ---------------------------
# Safe model load with fallback
# ---------------------------
def safe_load(model_name):
try:
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
return tok, model, model_name
except Exception as e:
print(f"Could not load {model_name}: {e}")
return None, None, None
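# Load order (as wired below): try the primary model first, then fall back to the
# smaller model if loading fails (e.g. out-of-memory on a CPU Space).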
tokenizer, model, used_model = safe_load(PRIMARY_MODEL)
if tokenizer is None:
tokenizer, model, used_model = safe_load(FALLBACK_MODEL)
if tokenizer is None:
    raise RuntimeError("Failed to load both the primary and fallback models. Try a smaller model name or raise the memory limit.")
# ---------------------------
# Helpers: memory
# ---------------------------
def load_memory() -> Dict[str, Any]:
try:
with open(MEMORY_FILE, "r", encoding="utf-8") as f:
return json.load(f)
except Exception:
return {}
def save_memory(mem: Dict[str, Any]):
with open(MEMORY_FILE, "w", encoding="utf-8") as f:
json.dump(mem, f, ensure_ascii=False, indent=2)
def get_session(state: dict) -> str:
sid = state.get("session_id")
if not sid:
sid = str(uuid.uuid4())
state["session_id"] = sid
mem = load_memory()
if sid not in mem:
mem[sid] = {"prefs": {}, "docs": []}
save_memory(mem)
return sid
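# Illustrative memory.json layout, one entry per session id:
# {
#   "0f8a...-uuid": {"prefs": {"name": "Ada"}, "docs": ["<extracted PDF text>"]}
# }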
# ---------------------------
# PDF reading
# ---------------------------
def extract_text_from_pdf(path: str) -> str:
try:
text = []
with open(path, "rb") as f:
reader = PyPDF2.PdfReader(f)
for page in reader.pages:
page_text = page.extract_text() or ""
text.append(page_text)
return "\n".join(text)
except Exception as e:
print("PDF read error:", e)
return ""
# ---------------------------
# Tools
# ---------------------------
ALLOWED_MATH = {k: getattr(math, k) for k in dir(math) if not k.startswith("__")}
ALLOWED_MATH.update({"abs": abs, "round": round})
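# ALLOWED_MATH whitelists every public name in the math module (sin, pi, sqrt, ...)
# plus abs/round; safe_eval below exposes only these names to eval().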
def safe_eval(expr: str):
try:
node = ast.parse(expr, mode="eval")
for n in ast.walk(node):
if isinstance(n, (ast.Attribute, ast.Lambda, ast.FunctionDef, ast.Import, ast.ImportFrom)):
raise ValueError("Expression not allowed.")
code = compile(node, "<string>", "eval")
return eval(code, {"__builtins__": {}}, ALLOWED_MATH)
except Exception as e:
return f"Error: {e}"
def define_word(word: str) -> str:
synsets = wordnet.synsets(word)
if not synsets:
return f"No definition found for '{word}'."
out = []
for s in synsets[:3]:
out.append(f"- ({s.lexname()}) {s.definition()}")
return "\n".join(out)
# ---------------------------
# Prompt building & generation
# ---------------------------
def build_context_prompt(session_id: str, user_message: str) -> str:
mem = load_memory()
entry = mem.get(session_id, {})
prefs = entry.get("prefs", {})
docs = entry.get("docs", [])
parts = []
if prefs:
pref_text = "; ".join(f"{k}: {v}" for k, v in prefs.items() if v)
if pref_text:
parts.append(f"User preferences: {pref_text}")
if docs:
# include limited doc content
doc_text = "\n\n".join(docs[-2:])
parts.append("User documents (context):\n" + doc_text[:3000])
parts.append(f"User question: {user_message}")
parts.append("You are a helpful assistant. Answer concisely and clearly. If user asks to 'summarize', 'translate', 'define' or 'calculate', perform that action.")
return "\n\n".join(parts)
def generate_response(prompt: str, max_new_tokens: int = 256, temperature: float = 0.7) -> str:
try:
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=temperature,
top_p=0.95,
pad_token_id=tokenizer.eos_token_id
)
txt = tokenizer.decode(outputs[0], skip_special_tokens=True)
# strip prompt if echoed
if prompt in txt:
txt = txt.split(prompt, 1)[-1].strip()
return txt.strip()
except Exception as e:
print("Generation error:", e)
traceback.print_exc()
return "Sorry — generation failed."
# ---------------------------
# Gradio functions
# ---------------------------
def handle_submit(chat_history, message, state):
if not message:
return chat_history
sid = get_session(state)
lower = message.strip().lower()
# tool shortcuts
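    # Recognized command forms (illustrative):
    #   calc: 12/3 + 4            define: gravity
    #   summarize: <text>         summarize my docs
    #   translate to es: How are you?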
if lower.startswith("calc:") or lower.startswith("calculate "):
        parts = message.split(":", 1) if ":" in message else message.split(None, 1)
        expr = parts[1] if len(parts) > 1 else ""
res = safe_eval(expr.strip())
bot = f"Result: {res}"
chat_history.append((message, bot))
return chat_history
if lower.startswith("define ") or lower.startswith("define:"):
        parts = message.split(":", 1) if ":" in message else message.split(None, 1)
        word = parts[1] if len(parts) > 1 else ""
bot = define_word(word.strip())
chat_history.append((message, bot))
return chat_history
if lower.startswith("summarize:") or "summarize my docs" in lower:
if "summarize my docs" in lower:
mem = load_memory()
docs = mem.get(sid, {}).get("docs", [])
if not docs:
bot = "No uploaded documents to summarize."
chat_history.append((message, bot))
return chat_history
text = "\n\n".join(docs)
else:
text = message.split(":",1)[-1]
# ask the model to summarize (no extra model)
prompt = f"Summarize the following text concisely:\n\n{text[:3000]}"
summary = generate_response(prompt, max_new_tokens=200, temperature=0.3)
bot = "Summary:\n" + summary
chat_history.append((message, bot))
return chat_history
if lower.startswith("translate"):
# use model to translate; simple parse: "translate to <lang>: text"
parts = message.split(":",1)
if len(parts) == 2 and "to " in parts[0].lower():
tgt = parts[0].lower().split("to",1)[-1].strip()
text = parts[1].strip()
prompt = f"Translate the following text to {tgt}:\n\n{text}"
else:
# fallback translate whole message to English
text = message.split(":",1)[-1] if ":" in message else message
prompt = f"Translate the following text to English:\n\n{text}"
translated = generate_response(prompt, max_new_tokens=200, temperature=0.3)
bot = "Translation:\n" + translated
chat_history.append((message, bot))
return chat_history
# standard conversational flow
system_prompt = build_context_prompt(sid, message)
reply = generate_response(system_prompt, max_new_tokens=300, temperature=0.7)
# light memory heuristics: save "my name is X" or "i prefer X"
try:
low = message.lower()
mem = load_memory()
if "my name is " in low:
name = message.split("my name is",1)[1].strip().split()[0]
mem[sid]["prefs"]["name"] = name
save_memory(mem)
if any(k in low for k in ["i prefer", "i like", "i'm a", "i am a"]):
pref_key = f"pref_{len(mem[sid].get('prefs',{}))+1}"
mem[sid]["prefs"][pref_key] = message
save_memory(mem)
except Exception as e:
print("Memory write failed:", e)
chat_history.append((message, reply))
return chat_history
def upload_pdf(file, state):
if not file:
return "No file uploaded."
sid = get_session(state)
    # Gradio passes either a tempfile-like object with a .name attribute or a plain path string
    path = file.name if hasattr(file, "name") else file
text = extract_text_from_pdf(path)
mem = load_memory()
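    # cap stored text at 20k characters so memory.json stays small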
mem[sid]["docs"].append(text[:20000])
save_memory(mem)
return "PDF uploaded and indexed into session memory."
def show_memory(state):
sid = get_session(state)
mem = load_memory()
return json.dumps(mem.get(sid, {}), ensure_ascii=False, indent=2)
def reset_memory(state):
sid = get_session(state)
mem = load_memory()
mem[sid] = {"prefs": {}, "docs": []}
save_memory(mem)
return "Session memory reset."
# ---------------------------
# UI (creative but lightweight)
# ---------------------------
with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="blue")) as demo:
gr.Markdown(f"# 🤖 GPT-Lite Assistant — {used_model}\nLightweight CPU-ready assistant with memory, PDF reading & tools.")
with gr.Row():
with gr.Column(scale=3):
chatbot = gr.Chatbot(label="Assistant", height=520)
with gr.Row():
txt = gr.Textbox(show_label=False, placeholder="Ask anything (or use commands: calc:, define:, summarize:, translate: )")
send = gr.Button("Send")
with gr.Row():
pdf_file = gr.File(label="Upload PDF (optional)", file_types=[".pdf"])
upload_btn = gr.Button("Upload PDF")
with gr.Row():
show_mem_btn = gr.Button("Show session memory")
reset_mem_btn = gr.Button("Reset memory")
with gr.Column(scale=1):
gr.Markdown("### Quick examples\n- Explain photosynthesis\n- calc: 12/3 + 4\n- define: gravity\n- translate to es: How are you?\n- summarize my docs")
gr.Markdown("### Notes\n- Model runs on CPU. If Space hits memory limits, switch PRIMARY_MODEL to a smaller model.")
state = gr.State({})
send.click(handle_submit, [chatbot, txt, state], chatbot)
txt.submit(handle_submit, [chatbot, txt, state], chatbot)
    status_box = gr.Textbox(label="Status / session memory", interactive=False)
    upload_btn.click(upload_pdf, [pdf_file, state], status_box)
    show_mem_btn.click(show_memory, [state], status_box)
    reset_mem_btn.click(reset_memory, [state], status_box)
demo.launch(server_name="0.0.0.0", server_port=7860)
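# Run locally (illustrative): `python app.py`, then open http://localhost:7860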