# app.py
import os
import json
import uuid
import ast
import math
import traceback
from typing import Dict, Any

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import PyPDF2
import nltk

# Ensure the WordNet corpus is available before importing it.
try:
    nltk.data.find("corpora/wordnet")
except Exception:
    nltk.download("wordnet")
from nltk.corpus import wordnet
# ---------------------------
# Config
# ---------------------------
PRIMARY_MODEL = "microsoft/Phi-3-mini-4k-instruct"  # CPU-friendly instruction-tuned model
# Small decoder-only fallback. The original "facebook/blenderbot-400M-distill"
# is a seq2seq checkpoint and loads poorly under AutoModelForCausalLM.
FALLBACK_MODEL = "distilgpt2"
MEMORY_FILE = "memory.json"

# Ensure the memory file exists.
if not os.path.exists(MEMORY_FILE):
    with open(MEMORY_FILE, "w", encoding="utf-8") as f:
        json.dump({}, f)
# ---------------------------
# Safe model load with fallback
# ---------------------------
def safe_load(model_name):
    try:
        # trust_remote_code is required for Phi-3 on older transformers releases.
        tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
        return tok, model, model_name
    except Exception as e:
        print(f"Could not load {model_name}: {e}")
        return None, None, None

tokenizer, model, used_model = safe_load(PRIMARY_MODEL)
if tokenizer is None:
    tokenizer, model, used_model = safe_load(FALLBACK_MODEL)
if tokenizer is None:
    raise RuntimeError("Failed to load both primary and fallback models. Try switching model names or memory limits.")
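# Expected startup behavior (illustrative, not guaranteed): either the primary
# model loads and `used_model` is the Phi-3 id, or the console prints
# "Could not load microsoft/Phi-3-mini-4k-instruct: ..." and the fallback is
# used instead. The UI banner below reports whichever model actually loaded.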
# ---------------------------
# Helpers: memory
# ---------------------------
def load_memory() -> Dict[str, Any]:
    try:
        with open(MEMORY_FILE, "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception:
        return {}

def save_memory(mem: Dict[str, Any]):
    with open(MEMORY_FILE, "w", encoding="utf-8") as f:
        json.dump(mem, f, ensure_ascii=False, indent=2)

def get_session(state: dict) -> str:
    sid = state.get("session_id")
    if not sid:
        sid = str(uuid.uuid4())
        state["session_id"] = sid
    mem = load_memory()
    if sid not in mem:
        mem[sid] = {"prefs": {}, "docs": []}
        save_memory(mem)
    return sid
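# Shape of memory.json after one session (hypothetical UUID, for illustration):
# {
#   "8d3c1f2a-...": {
#     "prefs": {"name": "Ada", "pref_1": "i prefer short answers"},
#     "docs": ["<first 20000 chars of each uploaded PDF>"]
#   }
# }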
# ---------------------------
# PDF reading
# ---------------------------
def extract_text_from_pdf(path: str) -> str:
    try:
        text = []
        with open(path, "rb") as f:
            reader = PyPDF2.PdfReader(f)
            for page in reader.pages:
                page_text = page.extract_text() or ""
                text.append(page_text)
        return "\n".join(text)
    except Exception as e:
        print("PDF read error:", e)
        return ""
# ---------------------------
# Tools
# ---------------------------
# Whitelisted names for the calculator: everything in math, plus abs/round.
ALLOWED_MATH = {k: getattr(math, k) for k in dir(math) if not k.startswith("__")}
ALLOWED_MATH.update({"abs": abs, "round": round})

def safe_eval(expr: str):
    try:
        node = ast.parse(expr, mode="eval")
        # Reject attribute access, lambdas, imports and function definitions
        # before evaluating with empty builtins.
        for n in ast.walk(node):
            if isinstance(n, (ast.Attribute, ast.Lambda, ast.FunctionDef, ast.Import, ast.ImportFrom)):
                raise ValueError("Expression not allowed.")
        code = compile(node, "<string>", "eval")
        return eval(code, {"__builtins__": {}}, ALLOWED_MATH)
    except Exception as e:
        return f"Error: {e}"
def define_word(word: str) -> str:
    synsets = wordnet.synsets(word)
    if not synsets:
        return f"No definition found for '{word}'."
    out = []
    for s in synsets[:3]:
        out.append(f"- ({s.lexname()}) {s.definition()}")
    return "\n".join(out)
# ---------------------------
# Prompt building & generation
# ---------------------------
def build_context_prompt(session_id: str, user_message: str) -> str:
    mem = load_memory()
    entry = mem.get(session_id, {})
    prefs = entry.get("prefs", {})
    docs = entry.get("docs", [])
    parts = []
    if prefs:
        pref_text = "; ".join(f"{k}: {v}" for k, v in prefs.items() if v)
        if pref_text:
            parts.append(f"User preferences: {pref_text}")
    if docs:
        # Include only the two most recent documents, capped at 3000 chars.
        doc_text = "\n\n".join(docs[-2:])
        parts.append("User documents (context):\n" + doc_text[:3000])
    parts.append(f"User question: {user_message}")
    parts.append("You are a helpful assistant. Answer concisely and clearly. If user asks to 'summarize', 'translate', 'define' or 'calculate', perform that action.")
    return "\n\n".join(parts)
def generate_response(prompt: str, max_new_tokens: int = 256, temperature: float = 0.7) -> str:
    try:
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        )
        txt = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Strip the prompt if the model echoed it back.
        if prompt in txt:
            txt = txt.split(prompt, 1)[-1].strip()
        return txt.strip()
    except Exception as e:
        print("Generation error:", e)
        traceback.print_exc()
        return "Sorry — generation failed."
# ---------------------------
# Gradio functions
# ---------------------------
def handle_submit(chat_history, message, state):
    if not message:
        return chat_history
    sid = get_session(state)
    lower = message.strip().lower()
    # Tool shortcuts: calc / define / summarize / translate.
    if lower.startswith("calc:") or lower.startswith("calculate "):
        expr = message.split(":", 1)[-1] if ":" in message else message.split(None, 1)[1]
        res = safe_eval(expr.strip())
        bot = f"Result: {res}"
        chat_history.append((message, bot))
        return chat_history
    if lower.startswith("define ") or lower.startswith("define:"):
        word = message.split(":", 1)[-1] if ":" in message else message.split(None, 1)[1]
        bot = define_word(word.strip())
        chat_history.append((message, bot))
        return chat_history
    if lower.startswith("summarize:") or "summarize my docs" in lower:
        if "summarize my docs" in lower:
            mem = load_memory()
            docs = mem.get(sid, {}).get("docs", [])
            if not docs:
                bot = "No uploaded documents to summarize."
                chat_history.append((message, bot))
                return chat_history
            text = "\n\n".join(docs)
        else:
            text = message.split(":", 1)[-1]
        # Ask the main model to summarize (no separate summarization model).
        prompt = f"Summarize the following text concisely:\n\n{text[:3000]}"
        summary = generate_response(prompt, max_new_tokens=200, temperature=0.3)
        bot = "Summary:\n" + summary
        chat_history.append((message, bot))
        return chat_history
    if lower.startswith("translate"):
        # Use the model to translate; simple parse: "translate to <lang>: text".
        parts = message.split(":", 1)
        if len(parts) == 2 and "to " in parts[0].lower():
            tgt = parts[0].lower().split("to", 1)[-1].strip()
            text = parts[1].strip()
            prompt = f"Translate the following text to {tgt}:\n\n{text}"
        else:
            # Fallback: translate the whole message to English.
            text = message.split(":", 1)[-1] if ":" in message else message
            prompt = f"Translate the following text to English:\n\n{text}"
        translated = generate_response(prompt, max_new_tokens=200, temperature=0.3)
        bot = "Translation:\n" + translated
        chat_history.append((message, bot))
        return chat_history
    # Standard conversational flow.
    system_prompt = build_context_prompt(sid, message)
    reply = generate_response(system_prompt, max_new_tokens=300, temperature=0.7)
    # Light memory heuristics: remember "my name is X" and stated preferences.
    try:
        low = message.lower()
        mem = load_memory()
        if "my name is " in low:
            # Slice the original message so the saved name keeps its casing;
            # splitting `message` on the lowercase phrase would miss "My name is ...".
            idx = low.index("my name is ") + len("my name is ")
            name = message[idx:].strip().split()[0]
            mem[sid]["prefs"]["name"] = name
            save_memory(mem)
        if any(k in low for k in ["i prefer", "i like", "i'm a", "i am a"]):
            pref_key = f"pref_{len(mem[sid].get('prefs', {})) + 1}"
            mem[sid]["prefs"][pref_key] = message
            save_memory(mem)
    except Exception as e:
        print("Memory write failed:", e)
    chat_history.append((message, reply))
    return chat_history
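# Routing examples (illustrative):
#   "calc: 12/3 + 4"                -> safe_eval, replies "Result: 8.0"
#   "define: gravity"               -> WordNet glosses
#   "translate to es: How are you?" -> model translation
#   "summarize my docs"             -> summarizes uploaded PDFs from session memory
#   anything else                   -> build_context_prompt + generate_response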
def upload_pdf(file, state):
    if not file:
        return "No file uploaded."
    sid = get_session(state)
    # Gradio may pass a tempfile path string or an object with a .name attribute.
    path = file.name if hasattr(file, "name") else file
    text = extract_text_from_pdf(path)
    mem = load_memory()
    mem[sid]["docs"].append(text[:20000])
    save_memory(mem)
    return "PDF uploaded and indexed into session memory."
def show_memory(state):
    sid = get_session(state)
    mem = load_memory()
    return json.dumps(mem.get(sid, {}), ensure_ascii=False, indent=2)

def reset_memory(state):
    sid = get_session(state)
    mem = load_memory()
    mem[sid] = {"prefs": {}, "docs": []}
    save_memory(mem)
    return "Session memory reset."
# ---------------------------
# UI (creative but lightweight)
# ---------------------------
with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="blue")) as demo:
    gr.Markdown(f"# 🤖 GPT-Lite Assistant — {used_model}\nLightweight CPU-ready assistant with memory, PDF reading & tools.")
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Assistant", height=520)
            with gr.Row():
                txt = gr.Textbox(show_label=False, placeholder="Ask anything (or use commands: calc:, define:, summarize:, translate:)")
                send = gr.Button("Send")
            with gr.Row():
                pdf_file = gr.File(label="Upload PDF (optional)", file_types=[".pdf"])
                upload_btn = gr.Button("Upload PDF")
            with gr.Row():
                show_mem_btn = gr.Button("Show session memory")
                reset_mem_btn = gr.Button("Reset memory")
        with gr.Column(scale=1):
            gr.Markdown("### Quick examples\n- Explain photosynthesis\n- calc: 12/3 + 4\n- define: gravity\n- translate to es: How are you?\n- summarize my docs")
            gr.Markdown("### Notes\n- Model runs on CPU. If the Space hits memory limits, switch PRIMARY_MODEL to a smaller model.")
            status = gr.Textbox(label="Status / memory", interactive=False)

    state = gr.State({})
    send.click(handle_submit, [chatbot, txt, state], chatbot)
    txt.submit(handle_submit, [chatbot, txt, state], chatbot)
    # Route tool outputs to the named status box; creating gr.Textbox() inline
    # as an output would render stray unlabeled boxes at the bottom of the page.
    upload_btn.click(upload_pdf, [pdf_file, state], status)
    show_mem_btn.click(show_memory, [state], status)
    reset_mem_btn.click(reset_memory, [state], status)

demo.launch(server_name="0.0.0.0", server_port=7860)