File size: 10,917 Bytes
e94cf6c
 
 
 
 
 
 
 
 
 
a1813c2
e94cf6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
# app.py
import ast
import json
import math
import os
import re
import tempfile
import traceback
import uuid
from typing import List, Tuple, Dict, Any

import gradio as gr
import nltk
import PyPDF2
from transformers import AutoTokenizer, AutoModelForCausalLM

def _ensure_nltk_resource(resource_path: str, package: str) -> None:
    """Download the NLTK data *package* unless *resource_path* can be found.

    Any error raised by the lookup (not only "not found") triggers the
    download attempt, so startup keeps going on a best-effort basis.
    """
    try:
        nltk.data.find(resource_path)
    except Exception:
        nltk.download(package)


# WordNet backs define_word(); check for it before importing the corpus reader.
_ensure_nltk_resource("corpora/wordnet", "wordnet")
from nltk.corpus import wordnet

# ---------------------------
# Config
# ---------------------------
PRIMARY_MODEL = "microsoft/Phi-3-mini-4k-instruct"  # instruction-tuned model, tried first
# Only used when PRIMARY_MODEL fails to load.
# NOTE(review): BlenderBot is an encoder-decoder checkpoint; confirm that loading
# it through AutoModelForCausalLM actually produces usable replies.
FALLBACK_MODEL = "facebook/blenderbot-400M-distill"
# JSON object on disk: session id -> {"prefs": {...}, "docs": [...]}
MEMORY_FILE = "memory.json"

# Seed the store with an empty JSON object on first start so the file exists
# before the first read.
if not os.path.exists(MEMORY_FILE):
    with open(MEMORY_FILE, "w", encoding="utf-8") as handle:
        handle.write(json.dumps({}))

# ---------------------------
# Safe model load with fallback
# ---------------------------
def safe_load(model_name):
    """Load a tokenizer/model pair with ``from_pretrained``.

    Returns:
        ``(tokenizer, model, model_name)`` when both loads succeed, otherwise
        ``(None, None, None)``; the failure is printed rather than raised.
    """
    try:
        loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
        loaded_model = AutoModelForCausalLM.from_pretrained(model_name)
    except Exception as err:
        print(f"Could not load {model_name}: {err}")
        return None, None, None
    return loaded_tokenizer, loaded_model, model_name

# Try the primary checkpoint, then the fallback; abort startup if neither loads.
for _model_name in (PRIMARY_MODEL, FALLBACK_MODEL):
    tokenizer, model, used_model = safe_load(_model_name)
    if tokenizer is not None:
        break
else:
    raise RuntimeError("Failed to load both primary and fallback models. Try switching model names or memory limits.")

# ---------------------------
# Helpers: memory
# ---------------------------
def load_memory() -> Dict[str, Any]:
    """Read the whole memory store from MEMORY_FILE.

    Returns:
        The stored mapping of session id -> session entry. Returns ``{}``
        when the file is missing, unreadable, not valid JSON, or does not
        contain a JSON object (callers index the result like a dict).
    """
    try:
        with open(MEMORY_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return {}
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print("Memory read error:", e)
        return {}
    if not isinstance(data, dict):
        print("Memory read error: top-level JSON value is not an object")
        return {}
    return data

def save_memory(mem: Dict[str, Any]):
    """Persist the whole memory store to MEMORY_FILE atomically.

    The JSON is first written to a temporary file in the same directory and
    then moved over MEMORY_FILE with ``os.replace``. A crash or serialization
    error mid-write therefore leaves the previous file intact instead of a
    truncated one (which load_memory reads back as an empty store, dropping
    every session).

    Raises:
        TypeError: if *mem* holds values json cannot serialize; the existing
            file is left untouched.
        OSError: if the temporary file cannot be created or moved into place.
    """
    target_dir = os.path.dirname(os.path.abspath(MEMORY_FILE))
    # NOTE: mkstemp creates the file with owner-only (0600) permissions.
    fd, tmp_path = tempfile.mkstemp(prefix=".memory-", suffix=".tmp", dir=target_dir)
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(mem, f, ensure_ascii=False, indent=2)
        # Same directory => same filesystem, so the replace is atomic.
        os.replace(tmp_path, MEMORY_FILE)
    except BaseException:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise

def get_session(state: dict) -> str:
    """Return the session id held in the Gradio ``state`` dict, creating it if needed.

    A new uuid4 string is stored under ``state["session_id"]`` when the key is
    missing or empty. If the memory store has no entry for the id yet, an
    empty ``{"prefs": {}, "docs": []}`` entry is added and saved.
    """
    session_id = state.get("session_id")
    if not session_id:
        session_id = state["session_id"] = str(uuid.uuid4())
    store = load_memory()
    if session_id not in store:
        store[session_id] = {"prefs": {}, "docs": []}
        save_memory(store)
    return session_id

# ---------------------------
# PDF reading
# ---------------------------
def extract_text_from_pdf(path: str) -> str:
    """Return the text of every page of the PDF at *path*, newline-joined.

    Pages whose extraction yields ``None`` contribute an empty string. Any
    error (missing file, unreadable PDF, ...) is printed and ``""`` returned.
    """
    try:
        with open(path, "rb") as handle:
            pdf = PyPDF2.PdfReader(handle)
            pages = [(page.extract_text() or "") for page in pdf.pages]
        return "\n".join(pages)
    except Exception as err:
        print("PDF read error:", err)
        return ""

# ---------------------------
# Tools
# ---------------------------
ALLOWED_MATH = {k: getattr(math, k) for k in dir(math) if not k.startswith("__")}
ALLOWED_MATH.update({"abs": abs, "round": round})

# Integer results wider than this are rejected. Python 3.11+ refuses to turn
# ints of more than ~4300 decimal digits into str, so larger values could not
# be shown to the user anyway.
_MAX_INT_BITS = 10_000
# Functions whose running time grows with the magnitude of an int argument.
_INT_COST_FUNCS = frozenset({"factorial", "comb", "perm"})
_MAX_COSTLY_INT_ARG = 10_000

_BINARY_OPS = {
    ast.Add: lambda a, b: a + b,
    ast.Sub: lambda a, b: a - b,
    ast.Mult: lambda a, b: a * b,
    ast.Div: lambda a, b: a / b,
    ast.FloorDiv: lambda a, b: a // b,
    ast.Mod: lambda a, b: a % b,
}
_UNARY_OPS = {
    ast.UAdd: lambda a: +a,
    ast.USub: lambda a: -a,
}


def _check_int_size(value):
    """Return *value*, raising ValueError if it is an int wider than _MAX_INT_BITS."""
    if isinstance(value, int) and value.bit_length() > _MAX_INT_BITS:
        raise ValueError("result too large")
    return value


def _checked_pow(base, exponent):
    """Return ``base ** exponent``, refusing integer powers that would be enormous."""
    if isinstance(base, int) and isinstance(exponent, int) and exponent > 0:
        # base ** exponent has at least (bit_length(base) - 1) * exponent bits,
        # so oversized results are rejected before any time is spent on them.
        if (base.bit_length() - 1) * exponent > _MAX_INT_BITS:
            raise ValueError("result too large")
    return base ** exponent


def _eval_call(node: ast.Call):
    """Evaluate a call to an ALLOWED_MATH function with size-checked arguments."""
    if not isinstance(node.func, ast.Name):
        raise ValueError("Expression not allowed.")
    if any(kw.arg is None for kw in node.keywords):  # reject **mapping unpacking
        raise ValueError("Expression not allowed.")
    func = _eval_node(node.func)
    name = node.func.id
    args = [_eval_node(arg) for arg in node.args]
    kwargs = {kw.arg: _eval_node(kw.value) for kw in node.keywords}
    if name == "round":
        # CPython's int rounding builds 10 ** -ndigits, so cap ndigits only.
        limited = args[1:] + [kwargs.get("ndigits")]
    elif name in _INT_COST_FUNCS:
        limited = args + list(kwargs.values())
    else:
        limited = []
    if any(isinstance(v, int) and abs(v) > _MAX_COSTLY_INT_ARG for v in limited):
        raise ValueError("argument too large")
    return _check_int_size(func(*args, **kwargs))


def _eval_node(node: ast.AST):
    """Evaluate one node of a parsed arithmetic expression.

    Accepted: numeric literals, names from ALLOWED_MATH, unary +/-, the
    operators in _BINARY_OPS plus ``**``, and calls to ALLOWED_MATH
    functions. Anything else (attributes, strings, containers, comparisons,
    comprehensions, starred args, ...) raises ValueError.
    """
    if isinstance(node, ast.Expression):
        return _eval_node(node.body)
    if isinstance(node, ast.Constant):
        if isinstance(node.value, (int, float, complex)):
            return _check_int_size(node.value)
        raise ValueError("Expression not allowed.")
    if isinstance(node, ast.Name):
        if node.id in ALLOWED_MATH:
            return ALLOWED_MATH[node.id]
        raise ValueError(f"Unknown name: {node.id}")
    if isinstance(node, ast.UnaryOp) and type(node.op) in _UNARY_OPS:
        return _UNARY_OPS[type(node.op)](_eval_node(node.operand))
    if isinstance(node, ast.BinOp):
        if isinstance(node.op, ast.Pow):
            return _check_int_size(_checked_pow(_eval_node(node.left), _eval_node(node.right)))
        if type(node.op) in _BINARY_OPS:
            op = _BINARY_OPS[type(node.op)]
            return _check_int_size(op(_eval_node(node.left), _eval_node(node.right)))
    if isinstance(node, ast.Call):
        return _eval_call(node)
    raise ValueError("Expression not allowed.")


def safe_eval(expr: str):
    """Evaluate a user-supplied arithmetic expression without ``eval``.

    The expression is parsed with ``ast`` and walked by a whitelist evaluator,
    so only arithmetic on numbers and ALLOWED_MATH names/functions is possible.
    Very large integer results and costly arguments are refused. This keeps
    inputs such as ``9**9**9`` or ``factorial(10**8)`` from tying up the server.

    Returns:
        The numeric result, or a string starting with ``"Error: "`` on any failure.
    """
    try:
        return _eval_node(ast.parse(expr, mode="eval"))
    except Exception as e:
        return f"Error: {e}"

def define_word(word: str, max_senses: int = 3) -> str:
    """Look up *word* in WordNet and list up to *max_senses* definitions.

    Whitespace-separated phrases ("ice cream") are joined with underscores,
    the form WordNet uses for multi-word lemmas ("ice_cream"); without this,
    multi-word lookups never matched.

    Args:
        word: Word or phrase to define.
        max_senses: Maximum number of senses to list (default 3).

    Returns:
        One ``"- (<lexname>) <definition>"`` line per sense, or a
        "No definition found" message quoting the original input.
    """
    lemma = "_".join(word.split())
    synsets = wordnet.synsets(lemma)
    if not synsets:
        return f"No definition found for '{word}'."
    return "\n".join(f"- ({s.lexname()}) {s.definition()}" for s in synsets[:max_senses])

# ---------------------------
# Prompt building & generation
# ---------------------------
def build_context_prompt(session_id: str, user_message: str) -> str:
    """Assemble the free-chat prompt from this session's stored memory.

    Sections, separated by blank lines, in order:
      1. non-empty stored preferences as ``key: value`` pairs (if any),
      2. the last two uploaded documents, truncated to 3000 characters (if any),
      3. the user's question,
      4. a fixed instruction line.
    """
    session = load_memory().get(session_id, {})
    preferences = session.get("prefs", {})
    documents = session.get("docs", [])

    sections = []
    if preferences:
        joined_prefs = "; ".join(f"{key}: {value}" for key, value in preferences.items() if value)
        if joined_prefs:
            sections.append(f"User preferences: {joined_prefs}")
    if documents:
        # Only the two most recent documents, capped in length, go into the prompt.
        recent_docs = "\n\n".join(documents[-2:])
        sections.append("User documents (context):\n" + recent_docs[:3000])
    sections += [
        f"User question: {user_message}",
        "You are a helpful assistant. Answer concisely and clearly. If user asks to 'summarize', 'translate', 'define' or 'calculate', perform that action.",
    ]
    return "\n\n".join(sections)

def generate_response(prompt: str, max_new_tokens: int = 256, temperature: float = 0.7) -> str:
    """Generate a completion for *prompt* with the module-level tokenizer/model.

    Args:
        prompt: Plain-text prompt; tokenized and truncated to 4096 tokens.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature (with top_p=0.95). Values <= 0
            switch to greedy decoding instead of failing.

    Returns:
        Only the newly generated text, stripped. Returns
        "Sorry — generation failed." if tokenization or generation raises;
        the traceback is printed.
    """
    try:
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
        if temperature > 0:
            sampling = {"do_sample": True, "temperature": temperature, "top_p": 0.95}
        else:
            sampling = {"do_sample": False}
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            **sampling,
        )
        generated = outputs[0]
        # A decoder-only model returns prompt + continuation. Slice off the prompt
        # tokens rather than searching for the prompt text; the decoded prompt
        # rarely matches the original string exactly, so the old substring strip
        # leaked the prompt into replies.
        if not getattr(model.config, "is_encoder_decoder", False):
            generated = generated[inputs["input_ids"].shape[-1]:]
        return tokenizer.decode(generated, skip_special_tokens=True).strip()
    except Exception as e:
        print("Generation error:", e)
        traceback.print_exc()
        return "Sorry — generation failed."

# ---------------------------
# Gradio functions
# ---------------------------
# Case-insensitive "my name is <word>"; the captured word excludes trailing punctuation.
_NAME_PATTERN = re.compile(r"\bmy name is\s+([\w'-]+)", re.IGNORECASE)
# Phrases that mark a message as a stated preference. The \b anchors avoid
# matches inside other words (e.g. "hi like", "i am asking").
_PREFERENCE_PATTERN = re.compile(r"\b(?:i prefer|i like|i'm a|i am a)\b", re.IGNORECASE)


def _command_argument(message: str) -> str:
    """Return a tool command's argument.

    The argument is the text after the first ':' when present, otherwise
    the text after the first whitespace-separated word.
    """
    if ":" in message:
        return message.split(":", 1)[1]
    return message.split(None, 1)[1]


def _extract_name(message: str):
    """Return the word following "my name is" (any letter case), or None."""
    match = _NAME_PATTERN.search(message)
    return match.group(1) if match else None


def _remember_user_facts(session_id: str, message: str) -> None:
    """Best-effort: store a stated name and/or preference in session memory."""
    name = _extract_name(message)
    is_preference = _PREFERENCE_PATTERN.search(message) is not None
    if name is None and not is_preference:
        return
    try:
        mem = load_memory()
        entry = mem.setdefault(session_id, {"prefs": {}, "docs": []})
        prefs = entry.setdefault("prefs", {})
        if name is not None:
            prefs["name"] = name
        if is_preference:
            # Numbered by the current count of stored prefs (the name included).
            prefs[f"pref_{len(prefs) + 1}"] = message
        save_memory(mem)
    except Exception as e:
        print("Memory write failed:", e)


def _summarize(session_id: str, message: str, lower: str) -> str:
    """Summarize the session's uploaded docs, or the text after "summarize:"."""
    if "summarize my docs" in lower:
        docs = load_memory().get(session_id, {}).get("docs", [])
        if not docs:
            return "No uploaded documents to summarize."
        text = "\n\n".join(docs)
    else:
        text = message.split(":", 1)[-1]
    prompt = f"Summarize the following text concisely:\n\n{text[:3000]}"
    return "Summary:\n" + generate_response(prompt, max_new_tokens=200, temperature=0.3)


def _translate(message: str) -> str:
    """Translate "translate to <lang>: <text>"; otherwise translate to English."""
    head, sep, body = message.partition(":")
    if sep and "to " in head.lower():
        target = head.lower().split("to", 1)[-1].strip()
        prompt = f"Translate the following text to {target}:\n\n{body.strip()}"
    else:
        text = body if sep else message
        prompt = f"Translate the following text to English:\n\n{text}"
    return "Translation:\n" + generate_response(prompt, max_new_tokens=200, temperature=0.3)


def handle_submit(chat_history, message, state):
    """Handle one chat turn and return the updated history.

    Messages starting with a tool command (``calc:``/``calculate``,
    ``define``/``define:``, ``summarize:`` or containing "summarize my docs",
    ``translate``) are answered directly. Anything else is sent to the model
    with the session's stored preferences and documents as context. Afterwards
    a stated name or preference is remembered.

    Args:
        chat_history: List of (user, bot) pairs from the Chatbot; None is
            treated as an empty history.
        message: The raw user message.
        state: Gradio state dict that holds the session id.

    Returns:
        ``chat_history`` with the new (message, reply) pair appended
        (unchanged when *message* is empty).
    """
    if chat_history is None:
        chat_history = []
    if not message:
        return chat_history
    sid = get_session(state)
    lower = message.strip().lower()

    if lower.startswith(("calc:", "calculate ")):
        bot = f"Result: {safe_eval(_command_argument(message).strip())}"
    elif lower.startswith(("define ", "define:")):
        bot = define_word(_command_argument(message).strip())
    elif lower.startswith("summarize:") or "summarize my docs" in lower:
        bot = _summarize(sid, message, lower)
    elif lower.startswith("translate"):
        bot = _translate(message)
    else:
        bot = generate_response(build_context_prompt(sid, message), max_new_tokens=300, temperature=0.7)
        _remember_user_facts(sid, message)

    chat_history.append((message, bot))
    return chat_history

def upload_pdf(file, state):
    """Extract text from an uploaded PDF and append it to the session's docs.

    Args:
        file: Value from the gr.File component. Either an object with a
            ``.name`` path attribute or the path itself, depending on the
            Gradio version (TODO confirm which one this Space runs).
        state: Gradio state dict that holds the session id.

    Returns:
        A status message for the UI. Nothing is stored when no text could be
        extracted, since an empty document would only add an empty context
        block to later prompts.
    """
    if not file:
        return "No file uploaded."
    path = file.name if hasattr(file, "name") else file
    text = extract_text_from_pdf(path)
    if not text.strip():
        # Read errors and PDFs without a text layer (e.g. scans) both end up here.
        return "Could not extract any text from this PDF (it may be scanned or unreadable)."
    sid = get_session(state)
    mem = load_memory()
    entry = mem.setdefault(sid, {"prefs": {}, "docs": []})
    entry.setdefault("docs", []).append(text[:20000])  # cap stored size per document
    save_memory(mem)
    return "PDF uploaded and indexed into session memory."

def show_memory(state):
    """Return the current session's memory entry as pretty-printed JSON text."""
    # Resolve (and if needed create) the session before reading the store.
    session_id = get_session(state)
    snapshot = load_memory().get(session_id, {})
    return json.dumps(snapshot, indent=2, ensure_ascii=False)

def reset_memory(state):
    """Replace the current session's memory entry with empty prefs and docs."""
    session_id = get_session(state)
    store = load_memory()
    store[session_id] = dict(prefs={}, docs=[])
    save_memory(store)
    return "Session memory reset."

# ---------------------------
# UI (creative but lightweight)
# ---------------------------
with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="blue")) as demo:
    gr.Markdown(f"# 🤖 GPT-Lite Assistant — {used_model}\nLightweight CPU-ready assistant with memory, PDF reading & tools.")
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Assistant", height=520)
            with gr.Row():
                txt = gr.Textbox(show_label=False, placeholder="Ask anything (or use commands: calc:, define:, summarize:, translate: )")
                send = gr.Button("Send")
            with gr.Row():
                pdf_file = gr.File(label="Upload PDF (optional)", file_types=[".pdf"])
                upload_btn = gr.Button("Upload PDF")
            with gr.Row():
                show_mem_btn = gr.Button("Show session memory")
                reset_mem_btn = gr.Button("Reset memory")
            # One labelled, read-only output for upload/memory/reset results, placed
            # under the buttons. Previously each listener created its own unlabelled
            # gr.Textbox() inline, and those rendered at the bottom of the page.
            status = gr.Textbox(label="Status / session memory", lines=6, interactive=False)
        with gr.Column(scale=1):
            gr.Markdown("### Quick examples\n- Explain photosynthesis\n- calc: 12/3 + 4\n- define: gravity\n- translate to es: How are you?\n- summarize my docs")
            gr.Markdown("### Notes\n- Model runs on CPU. If Space hits memory limits, switch PRIMARY_MODEL to a smaller model.")
    state = gr.State({})

    # After the reply is appended, clear the input box for the next message.
    send.click(handle_submit, [chatbot, txt, state], chatbot).then(lambda: "", None, txt)
    txt.submit(handle_submit, [chatbot, txt, state], chatbot).then(lambda: "", None, txt)
    upload_btn.click(upload_pdf, [pdf_file, state], status)
    show_mem_btn.click(show_memory, [state], status)
    reset_mem_btn.click(reset_memory, [state], status)

demo.launch(server_name="0.0.0.0", server_port=7860)