# app.py - full version with memory + web search + datasets
import os
import json
import threading

import gradio as gr
from huggingface_hub import InferenceClient, snapshot_download
from datasets import load_dataset
from duckduckgo_search import DDGS

# ---------------- CONFIG ----------------
MODEL_ID = "openai/gpt-oss-120b"  # or granite
# /data is the persistent-storage mount on Hugging Face Spaces; fall back to a
# local folder when it does not exist.
DATA_DIR = "/data" if os.path.isdir("/data") else "./data"
os.makedirs(DATA_DIR, exist_ok=True)
SHORT_TERM_LIMIT = 10      # max turns kept verbatim before summarization kicks in
SUMMARY_MAX_TOKENS = 150   # token budget for the long-term-memory summary
MEMORY_LOCK = threading.Lock()

# ---------------- dataset loading ----------------
# ⚠️ Heavy startup: comment this block out if running on a free HF Space.
folder = snapshot_download(
    "HuggingFaceFW/fineweb",
    repo_type="dataset",
    local_dir="./fineweb/",
    allow_patterns="sample/10BT/*",
)
ds1 = load_dataset("HuggingFaceH4/ultrachat_200k")
ds2 = load_dataset("Anthropic/hh-rlhf")

# ---------------- helpers: memory ----------------
def get_user_id(hf_token: gr.OAuthToken | None):
    if hf_token and getattr(hf_token, "token", None):
        return "user_" + hf_token.token[:12]
    return "anon"


def memory_file_path(user_id: str):
    return os.path.join(DATA_DIR, f"memory_{user_id}.json")


def load_memory(user_id: str):
    p = memory_file_path(user_id)
    if os.path.exists(p):
        try:
            with open(p, "r", encoding="utf-8") as f:
                mem = json.load(f)
            if isinstance(mem, dict) and "short_term" in mem and "long_term" in mem:
                return mem
        except Exception as e:
            print("load_memory error:", e)
    return {"short_term": [], "long_term": ""}


def save_memory(user_id: str, memory: dict):
    p = memory_file_path(user_id)
    try:
        with MEMORY_LOCK:
            with open(p, "w", encoding="utf-8") as f:
                json.dump(memory, f, ensure_ascii=False, indent=2)
    except Exception as e:
        print("save_memory error:", e)


# ---------------- normalize history ----------------
def normalize_history(history):
    """Accept messages-style dicts, legacy (user, assistant) pairs, or bare strings."""
    out = []
    if not history:
        return out
    for turn in history:
        if isinstance(turn, dict) and "role" in turn and "content" in turn:
            out.append({"role": turn["role"], "content": str(turn["content"])})
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            user_msg, assistant_msg = turn
            out.append({"role": "user", "content": str(user_msg)})
            out.append({"role": "assistant", "content": str(assistant_msg)})
        elif isinstance(turn, str):
            out.append({"role": "user", "content": turn})
    return out


# ---------------- sync completion ----------------
def _get_chat_response_sync(client: InferenceClient, messages,
                            max_tokens=SUMMARY_MAX_TOKENS, temperature=0.3, top_p=0.9):
    try:
        resp = client.chat_completion(messages, max_tokens=max_tokens,
                                      temperature=temperature, top_p=top_p, stream=False)
    except Exception as e:
        print("sync chat_completion error:", e)
        return ""
    # The client may return either a plain dict or a dataclass-like object,
    # so probe both shapes defensively.
    try:
        choices = resp.get("choices") if isinstance(resp, dict) else getattr(resp, "choices", None)
        if choices:
            c0 = choices[0]
            msg = c0.get("message") if isinstance(c0, dict) else getattr(c0, "message", None)
            if isinstance(msg, dict):
                return msg.get("content", "")
            return getattr(msg, "content", "") or str(msg or "")
    except Exception:
        pass
    return ""


# ---------------- web search ----------------
def web_search(query, num_results=3):
    try:
        with DDGS() as ddgs:
            results = list(ddgs.text(query, max_results=num_results))
        search_context = "🔍 Web Search Results:\n\n"
        for i, r in enumerate(results, 1):
            title = r.get("title", "")[:200]
            body = r.get("body", "")[:200].replace("\n", " ")
            href = r.get("href", "")
            search_context += f"{i}. {title}\n{body}...\nSource: {href}\n\n"
        return search_context
    except Exception as e:
        return f"❌ Search error: {str(e)}"
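
# Optional smoke test for web_search(); the helper name _demo_search is mine,
# not part of the original app. Importing this module also triggers the heavy
# dataset downloads above, so call this from a REPL where app.py is already
# loaded. Live results vary with the installed duckduckgo_search version.
def _demo_search():
    print(web_search("gradio chat interface", num_results=2))
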
# ---------------- summarization ----------------
def summarize_old_messages(client: InferenceClient, old_messages):
    text = "\n".join([f"{m['role']}: {m['content']}" for m in old_messages])
    system = {"role": "system", "content": "You are a summarizer. Summarize the following conversation in at most 150 words."}
    user = {"role": "user", "content": text}
    return _get_chat_response_sync(client, [system, user])


# ---------------- memory tools ----------------
def show_memory(hf_token: gr.OAuthToken | None = None):
    user = get_user_id(hf_token)
    p = memory_file_path(user)
    if not os.path.exists(p):
        return "ℹ️ No memory file found for user: " + user
    with open(p, "r", encoding="utf-8") as f:
        return f.read()


def clear_memory(hf_token: gr.OAuthToken | None = None):
    user = get_user_id(hf_token)
    p = memory_file_path(user)
    if os.path.exists(p):
        os.remove(p)
        return f"✅ Memory cleared for {user}"
    return "ℹ️ No memory to clear."


# ---------------- main chat ----------------
def respond(message, history: list, system_message, max_tokens, temperature,
            top_p, enable_search, enable_persistent_memory,
            hf_token: gr.OAuthToken = None):
    client = InferenceClient(token=(hf_token.token if hf_token else None), model=MODEL_ID)
    user_id = get_user_id(hf_token)
    memory = load_memory(user_id) if enable_persistent_memory else {"short_term": [], "long_term": ""}

    # Fold anything beyond the short-term window into the long-term summary.
    session_history = normalize_history(history)
    combined = memory.get("short_term", []) + session_history
    if len(combined) > SHORT_TERM_LIMIT:
        to_summarize = combined[:len(combined) - SHORT_TERM_LIMIT]
        summary = summarize_old_messages(client, to_summarize)
        if summary:
            memory["long_term"] = (memory.get("long_term", "") + "\n" + summary).strip()
        combined = combined[-SHORT_TERM_LIMIT:]

    combined.append({"role": "user", "content": message})
    memory["short_term"] = combined
    if enable_persistent_memory:
        save_memory(user_id, memory)

    messages = [{"role": "system", "content": system_message}]
    if memory.get("long_term"):
        messages.append({"role": "system", "content": "Long-term memory:\n" + memory["long_term"]})
    messages.extend(memory["short_term"])

    # Crude keyword trigger: only search when the message looks like a lookup
    # request ("tin tức" is Vietnamese for "news").
    if enable_search and any(k in message.lower() for k in ["search", "google", "tin tức", "news", "what is"]):
        sr = web_search(message)
        messages.append({"role": "user", "content": f"{sr}\n\nBased on search results, answer: {message}"})

    response = ""
    try:
        for chunk in client.chat_completion(messages, max_tokens=int(max_tokens), stream=True,
                                            temperature=float(temperature), top_p=float(top_p)):
            # Stream chunks, like sync responses, may be dicts or objects.
            choices = chunk.get("choices") if isinstance(chunk, dict) else getattr(chunk, "choices", None)
            if not choices:
                continue
            c0 = choices[0]
            delta = c0.get("delta") if isinstance(c0, dict) else getattr(c0, "delta", None)
            token = None
            if delta and (delta.get("content") if isinstance(delta, dict) else getattr(delta, "content", None)):
                token = delta.get("content") if isinstance(delta, dict) else getattr(delta, "content", None)
            else:
                msg = c0.get("message") if isinstance(c0, dict) else getattr(c0, "message", None)
                if isinstance(msg, dict):
                    token = msg.get("content", "")
                else:
                    token = getattr(msg, "content", None) or str(msg or "")
            if token:
                response += token
                yield response
    except Exception as e:
        yield f"⚠️ Inference error: {e}"
        return

    # Persist the assistant's reply, trimmed to the short-term window.
    memory["short_term"].append({"role": "assistant", "content": response})
    memory["short_term"] = memory["short_term"][-SHORT_TERM_LIMIT:]
    if enable_persistent_memory:
        save_memory(user_id, memory)
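
# Illustrative helper (mine, not in the original app): rebuild the prompt that
# respond() would assemble for a given user without calling the Inference API.
# Handy for checking what ended up in long-term vs. short-term memory; the
# default system_message mirrors the UI default below.
def _debug_prompt(user_id: str, system_message: str = "You are a helpful AI assistant."):
    memory = load_memory(user_id)
    messages = [{"role": "system", "content": system_message}]
    if memory.get("long_term"):
        messages.append({"role": "system", "content": "Long-term memory:\n" + memory["long_term"]})
    messages.extend(memory.get("short_term", []))
    return messages
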
# ---------------- Gradio UI ----------------
chatbot = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=[
        gr.Textbox(value="You are a helpful AI assistant.", label="System message"),
        gr.Slider(1, 2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
        gr.Checkbox(value=True, label="Enable Web Search 🔍"),
        gr.Checkbox(value=True, label="Enable Persistent Memory"),
    ],
)

with gr.Blocks(title="AI Chatbot (full version)") as demo:
    gr.Markdown("# 🤖 AI Chatbot with Memory + Web Search + Datasets")
    with gr.Sidebar():
        gr.LoginButton()
        gr.Markdown("### Memory Tools")
        gr.Button("👀 Show Memory").click(show_memory, inputs=None,
                                          outputs=gr.Textbox(label="Memory"))
        gr.Button("🗑️ Clear Memory").click(clear_memory, inputs=None,
                                           outputs=gr.Textbox(label="Status"))
    chatbot.render()

if __name__ == "__main__":
    demo.launch()
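
# To run locally (a sketch; pin package versions as appropriate):
#   pip install gradio huggingface_hub datasets duckduckgo_search
#   python app.py
# On a Hugging Face Space, /data survives restarts only when persistent
# storage is enabled; otherwise DATA_DIR falls back to ./data and memory
# files disappear with the container.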