# app.py — full version with memory + web search + datasets
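"""Gradio chat app: streams LLM replies via the HF Inference API, with
per-user persistent memory, optional DuckDuckGo web search, and (heavy,
optional) dataset pre-fetching at startup."""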
import os
import json
import hashlib
import threading
import gradio as gr
from huggingface_hub import InferenceClient, snapshot_download
from datasets import load_dataset
from duckduckgo_search import DDGS
# ---------------- CONFIG ----------------
MODEL_ID = "openai/gpt-oss-120b"  # or swap in another chat model (e.g. an IBM Granite variant)
DATA_DIR = "/data" if os.path.isdir("/data") else "./data"
os.makedirs(DATA_DIR, exist_ok=True)
SHORT_TERM_LIMIT = 10
SUMMARY_MAX_TOKENS = 150
MEMORY_LOCK = threading.Lock()
# ---------------- dataset loading ----------------
# ⚠️ Heavy at startup: downloads large dataset samples. Comment this block out
# on a free HF Space. Note that `folder`, `ds1`, and `ds2` are never used by
# the chat loop below; they are loaded here only to pre-fetch the data.
folder = snapshot_download(
    "HuggingFaceFW/fineweb",
    repo_type="dataset",
    local_dir="./fineweb/",
    allow_patterns="sample/10BT/*",
)
ds1 = load_dataset("HuggingFaceH4/ultrachat_200k")
ds2 = load_dataset("Anthropic/hh-rlhf")
# ---------------- helpers: memory ----------------
def get_user_id(hf_token: gr.OAuthToken | None):
    """Derive a stable per-user id; anonymous users share the "anon" memory file."""
    if hf_token and getattr(hf_token, "token", None):
        # Hash the OAuth token rather than writing a raw token prefix into a
        # filename, which would leak credential material to disk.
        return "user_" + hashlib.sha256(hf_token.token.encode()).hexdigest()[:12]
    return "anon"
def memory_file_path(user_id: str):
    return os.path.join(DATA_DIR, f"memory_{user_id}.json")
def load_memory(user_id: str):
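    """Load {"short_term": [...], "long_term": str} for a user; fall back to empty memory on any failure."""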
p = memory_file_path(user_id)
if os.path.exists(p):
try:
with open(p, "r", encoding="utf-8") as f:
mem = json.load(f)
if isinstance(mem, dict) and "short_term" in mem and "long_term" in mem:
return mem
except Exception as e:
print("load_memory error:", e)
return {"short_term": [], "long_term": ""}
def save_memory(user_id: str, memory: dict):
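    """Persist memory to disk; MEMORY_LOCK serializes concurrent writes."""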
p = memory_file_path(user_id)
try:
with MEMORY_LOCK:
with open(p, "w", encoding="utf-8") as f:
json.dump(memory, f, ensure_ascii=False, indent=2)
except Exception as e:
print("save_memory error:", e)
# ---------------- normalize history ----------------
def normalize_history(history):
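    """Coerce Gradio history into chat dicts; accepts message dicts, (user, assistant) pairs, or bare strings."""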
out = []
if not history: return out
for turn in history:
if isinstance(turn, dict) and "role" in turn and "content" in turn:
out.append({"role": turn["role"], "content": str(turn["content"])})
elif isinstance(turn, (list, tuple)) and len(turn) == 2:
user_msg, assistant_msg = turn
out.append({"role": "user", "content": str(user_msg)})
out.append({"role": "assistant", "content": str(assistant_msg)})
elif isinstance(turn, str):
out.append({"role": "user", "content": turn})
return out
# ---------------- sync completion ----------------
def _get_chat_response_sync(client: InferenceClient, messages, max_tokens=SUMMARY_MAX_TOKENS, temperature=0.3, top_p=0.9):
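    """Non-streaming completion; tolerates dict- and object-shaped responses and returns "" on error."""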
try:
resp = client.chat_completion(messages, max_tokens=max_tokens, temperature=temperature, top_p=top_p, stream=False)
except Exception as e:
print("sync chat_completion error:", e)
return ""
try:
choices = resp.get("choices") if isinstance(resp, dict) else getattr(resp, "choices", None)
if choices:
c0 = choices[0]
msg = c0.get("message") if isinstance(c0, dict) else getattr(c0, "message", None)
if isinstance(msg, dict):
return msg.get("content", "")
return getattr(msg, "content", "") or str(msg or "")
except Exception:
pass
return ""
# ---------------- web search ----------------
def web_search(query, num_results=3):
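    """Fetch up to num_results DuckDuckGo text hits and format them as a prompt-ready context block."""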
try:
with DDGS() as ddgs:
results = list(ddgs.text(query, max_results=num_results))
search_context = "🔍 Web Search Results:\n\n"
for i, r in enumerate(results, 1):
title = r.get("title", "")[:200]
body = r.get("body", "")[:200].replace("\n", " ")
href = r.get("href", "")
search_context += f"{i}. {title}\n{body}...\nSource: {href}\n\n"
return search_context
except Exception as e:
return f"❌ Search error: {str(e)}"
# ---------------- summarization ----------------
def summarize_old_messages(client: InferenceClient, old_messages):
    """Compress overflow turns into a short summary (output capped by SUMMARY_MAX_TOKENS)."""
    text = "\n".join(f"{m['role']}: {m['content']}" for m in old_messages)
    system = {"role": "system", "content": "You are a summarizer. Summarize the conversation below in at most 150 words."}
    user = {"role": "user", "content": text}
    return _get_chat_response_sync(client, [system, user])
# ---------------- memory tools ----------------
def show_memory(hf_token: gr.OAuthToken | None = None):
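    """Return the raw memory JSON for the (possibly anonymous) current user."""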
user = get_user_id(hf_token)
p = memory_file_path(user)
if not os.path.exists(p):
return "ℹ️ No memory file found for user: " + user
with open(p, "r", encoding="utf-8") as f:
return f.read()
def clear_memory(hf_token: gr.OAuthToken | None = None):
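    """Delete the current user's memory file, if any."""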
user = get_user_id(hf_token)
p = memory_file_path(user)
if os.path.exists(p):
os.remove(p)
return f"✅ Memory cleared for {user}"
return "ℹ️ No memory to clear."
# ---------------- main chat ----------------
def respond(message, history: list, system_message, max_tokens, temperature, top_p,
            enable_search, enable_persistent_memory, hf_token: gr.OAuthToken | None = None):
    """Stream a chat reply, folding persistent memory and optional web search into the prompt."""
    client = InferenceClient(token=(hf_token.token if hf_token else None), model=MODEL_ID)
user_id = get_user_id(hf_token)
memory = load_memory(user_id) if enable_persistent_memory else {"short_term": [], "long_term": ""}
session_history = normalize_history(history)
combined = memory.get("short_term", []) + session_history
if len(combined) > SHORT_TERM_LIMIT:
to_summarize = combined[:len(combined) - SHORT_TERM_LIMIT]
summary = summarize_old_messages(client, to_summarize)
if summary:
memory["long_term"] = (memory.get("long_term", "") + "\n" + summary).strip()
combined = combined[-SHORT_TERM_LIMIT:]
combined.append({"role": "user", "content": message})
memory["short_term"] = combined
if enable_persistent_memory:
save_memory(user_id, memory)
messages = [{"role": "system", "content": system_message}]
if memory.get("long_term"):
messages.append({"role": "system", "content": "Long-term memory:\n" + memory["long_term"]})
messages.extend(memory["short_term"])
    # Crude keyword trigger for when to search; "tin tức" is Vietnamese for "news".
    if enable_search and any(k in message.lower() for k in ["search", "google", "tin tức", "news", "what is"]):
sr = web_search(message)
messages.append({"role": "user", "content": f"{sr}\n\nBased on search results, answer: {message}"})
response = ""
try:
for chunk in client.chat_completion(messages, max_tokens=int(max_tokens),
stream=True, temperature=float(temperature), top_p=float(top_p)):
choices = chunk.get("choices") if isinstance(chunk, dict) else getattr(chunk, "choices", None)
if not choices: continue
c0 = choices[0]
delta = c0.get("delta") if isinstance(c0, dict) else getattr(c0, "delta", None)
            token = None
            if delta is not None:
                token = delta.get("content") if isinstance(delta, dict) else getattr(delta, "content", None)
            if not token:
                # Fall back to a full message payload (some backends send a
                # final non-delta chunk).
                msg = c0.get("message") if isinstance(c0, dict) else getattr(c0, "message", None)
                if isinstance(msg, dict):
                    token = msg.get("content", "")
                elif msg is not None:
                    token = getattr(msg, "content", None) or str(msg)
if token:
response += token
yield response
except Exception as e:
yield f"⚠️ Inference error: {e}"
return
memory["short_term"].append({"role": "assistant", "content": response})
memory["short_term"] = memory["short_term"][-SHORT_TERM_LIMIT:]
if enable_persistent_memory:
save_memory(user_id, memory)
# ---------------- Gradio UI ----------------
chatbot = gr.ChatInterface(
respond,
type="messages",
additional_inputs=[
gr.Textbox(value="You are a helpful AI assistant.", label="System message"),
gr.Slider(1, 2048, value=512, step=1, label="Max new tokens"),
gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature"),
gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
gr.Checkbox(value=True, label="Enable Web Search 🔍"),
gr.Checkbox(value=True, label="Enable Persistent Memory"),
],
)
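# NOTE: additional_inputs are passed positionally to respond() after
# (message, history), in the order listed above; Gradio fills the trailing
# gr.OAuthToken parameter automatically when the user is logged in.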
with gr.Blocks(title="AI Chatbot (full version)") as demo:
gr.Markdown("# 🤖 AI Chatbot with Memory + Web Search + Datasets")
    with gr.Sidebar():
        gr.LoginButton()
        gr.Markdown("### Memory Tools")
        memory_box = gr.Textbox(label="Memory")
        status_box = gr.Textbox(label="Status")
        gr.Button("👀 Show Memory").click(show_memory, inputs=None, outputs=memory_box)
        gr.Button("🗑️ Clear Memory").click(clear_memory, inputs=None, outputs=status_box)
    chatbot.render()
if __name__ == "__main__":
demo.launch() |