File size: 9,098 Bytes
3c28fa4
 
36b6bbe
 
 
edc66aa
3c28fa4
36b6bbe
 
edc66aa
3c28fa4
36b6bbe
3c28fa4
36b6bbe
 
 
 
 
 
 
3c28fa4
 
 
 
 
 
 
36b6bbe
3c28fa4
 
36b6bbe
3c28fa4
36b6bbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c28fa4
36b6bbe
 
 
 
 
 
 
 
 
 
 
 
 
 
3c28fa4
36b6bbe
 
 
 
 
 
edc66aa
36b6bbe
 
 
 
 
 
 
 
 
 
 
edc66aa
3c28fa4
36b6bbe
 
 
 
 
 
 
 
 
 
 
 
 
 
3c28fa4
36b6bbe
 
 
 
 
 
3c28fa4
36b6bbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c28fa4
36b6bbe
3c28fa4
36b6bbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ce95a4
edc66aa
3c28fa4
 
 
 
edc66aa
36b6bbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
edc66aa
3c28fa4
edc66aa
 
 
 
3c28fa4
36b6bbe
 
 
3c28fa4
 
edc66aa
 
 
3c28fa4
 
edc66aa
 
3c28fa4
 
 
edc66aa
 
 
8ce95a4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
# app.py — full version with memory + web search + datasets

import os
import json
import threading
import gradio as gr
from huggingface_hub import InferenceClient, snapshot_download
from datasets import load_dataset
from duckduckgo_search import DDGS


# ---------------- CONFIG ----------------
MODEL_ID = "openai/gpt-oss-120b"   # or granite
# Use /data when present (HF Spaces persistent storage); otherwise a local ./data dir.
DATA_DIR = "/data" if os.path.isdir("/data") else "./data"
os.makedirs(DATA_DIR, exist_ok=True)

SHORT_TERM_LIMIT = 10  # max messages kept verbatim; older ones get summarized
SUMMARY_MAX_TOKENS = 150  # token budget for the history summarizer
MEMORY_LOCK = threading.Lock()  # serializes access to per-user memory JSON files

# ---------------- dataset loading ----------------
# ⚠️ Heavy startup, comment out if running on free HF Space
# Downloads only the 10BT sample shard of FineWeb (allow_patterns filters the repo).
folder = snapshot_download(
    "HuggingFaceFW/fineweb",
    repo_type="dataset",
    local_dir="./fineweb/",
    allow_patterns="sample/10BT/*",
)
# NOTE(review): folder, ds1 and ds2 are never referenced elsewhere in this file —
# presumably staged for a future feature; confirm before deleting or keeping.
ds1 = load_dataset("HuggingFaceH4/ultrachat_200k")
ds2 = load_dataset("Anthropic/hh-rlhf")

# ---------------- helpers: memory ----------------
def get_user_id(hf_token: gr.OAuthToken | None):
    """Derive a stable per-user id from an OAuth token, or "anon" when absent.

    NOTE(review): the id embeds the first 12 chars of the raw token (and ends
    up in the memory file name) — consider hashing instead; changing the scheme
    would orphan existing memory files.
    """
    token_value = getattr(hf_token, "token", None) if hf_token else None
    if token_value:
        return "user_" + token_value[:12]
    return "anon"

def memory_file_path(user_id: str):
    """Return the per-user memory JSON path inside DATA_DIR."""
    filename = f"memory_{user_id}.json"
    return os.path.join(DATA_DIR, filename)

def load_memory(user_id: str):
    """Load the persisted memory dict for *user_id*.

    Returns a dict with keys "short_term" (list of chat messages) and
    "long_term" (summary string). Falls back to an empty structure when the
    file is missing, unreadable, or has an unexpected shape.

    Fix: reads now hold MEMORY_LOCK, matching save_memory, so a concurrent
    write cannot be observed half-flushed.
    """
    p = memory_file_path(user_id)
    if os.path.exists(p):
        try:
            with MEMORY_LOCK:
                with open(p, "r", encoding="utf-8") as f:
                    mem = json.load(f)
            # Only trust files with the exact structure this app writes.
            if isinstance(mem, dict) and "short_term" in mem and "long_term" in mem:
                return mem
        except Exception as e:
            print("load_memory error:", e)
    return {"short_term": [], "long_term": ""}

def save_memory(user_id: str, memory: dict):
    """Persist *memory* as pretty-printed UTF-8 JSON for *user_id*.

    Write failures are logged and swallowed so a bad disk never breaks a
    chat turn.
    """
    path = memory_file_path(user_id)
    try:
        with MEMORY_LOCK, open(path, "w", encoding="utf-8") as f:
            json.dump(memory, f, ensure_ascii=False, indent=2)
    except Exception as e:
        print("save_memory error:", e)

# ---------------- normalize history ----------------
def normalize_history(history):
    """Coerce Gradio chat history into a flat list of {"role", "content"} dicts.

    Accepts dict messages, (user, assistant) pairs, or bare strings (treated
    as user turns); unrecognized entries are silently dropped.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict) and "role" in turn and "content" in turn:
            messages.append({"role": turn["role"], "content": str(turn["content"])})
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            user_msg, assistant_msg = turn
            messages.append({"role": "user", "content": str(user_msg)})
            messages.append({"role": "assistant", "content": str(assistant_msg)})
        elif isinstance(turn, str):
            messages.append({"role": "user", "content": turn})
    return messages

# ---------------- sync completion ----------------
def _get_chat_response_sync(client: InferenceClient, messages, max_tokens=SUMMARY_MAX_TOKENS, temperature=0.3, top_p=0.9):
    """Run a non-streaming chat completion and return the first choice's text.

    Tolerates both dict-style and attribute-style response objects from the
    client; returns "" on any API or parsing failure.
    """
    try:
        resp = client.chat_completion(messages, max_tokens=max_tokens, temperature=temperature, top_p=top_p, stream=False)
    except Exception as e:
        print("sync chat_completion error:", e)
        return ""

    def _field(obj, name):
        # Uniform accessor: dict lookup or attribute access, None when absent.
        return obj.get(name) if isinstance(obj, dict) else getattr(obj, name, None)

    try:
        choices = _field(resp, "choices")
        if choices:
            msg = _field(choices[0], "message")
            if isinstance(msg, dict):
                return msg.get("content", "")
            return getattr(msg, "content", "") or str(msg or "")
    except Exception:
        pass
    return ""

# ---------------- web search ----------------
def web_search(query, num_results=3):
    """DuckDuckGo text search formatted as a readable context block.

    Returns a "🔍 Web Search Results" string with title/snippet/source per hit;
    on any failure returns an "❌ Search error: ..." string instead of raising.
    """
    try:
        with DDGS() as ddgs:
            hits = list(ddgs.text(query, max_results=num_results))
        parts = ["🔍 Web Search Results:\n\n"]
        for idx, hit in enumerate(hits, 1):
            title = hit.get("title", "")[:200]
            snippet = hit.get("body", "")[:200].replace("\n", " ")
            href = hit.get("href", "")
            parts.append(f"{idx}. {title}\n{snippet}...\nSource: {href}\n\n")
        return "".join(parts)
    except Exception as e:
        return f"❌ Search error: {str(e)}"

# ---------------- summarization ----------------
def summarize_old_messages(client: InferenceClient, old_messages):
    """Condense overflow chat history into a short (<=150 word) summary."""
    transcript = "\n".join(f"{m['role']}: {m['content']}" for m in old_messages)
    prompt = [
        {"role": "system", "content": "You are a summarizer. Summarize <=150 words."},
        {"role": "user", "content": transcript},
    ]
    return _get_chat_response_sync(client, prompt)

# ---------------- memory tools ----------------
def show_memory(hf_token: gr.OAuthToken | None = None):
    """Return the raw memory JSON for the current user, or an info message."""
    user = get_user_id(hf_token)
    path = memory_file_path(user)
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as f:
            return f.read()
    return "ℹ️ No memory file found for user: " + user

def clear_memory(hf_token: gr.OAuthToken | None = None):
    """Delete the current user's memory file and report the outcome."""
    user = get_user_id(hf_token)
    path = memory_file_path(user)
    if not os.path.exists(path):
        return "ℹ️ No memory to clear."
    os.remove(path)
    return f"✅ Memory cleared for {user}"

# ---------------- main chat ----------------
def respond(message, history: list, system_message, max_tokens, temperature, top_p,
            enable_search, enable_persistent_memory, hf_token: gr.OAuthToken = None):
    """Streaming chat handler wired into gr.ChatInterface.

    Yields the assistant reply incrementally. Maintains per-user short-term
    (verbatim messages) and long-term (summarized) memory, optionally persisted
    to disk, and can prepend web-search results for search-like queries.

    NOTE(review): short_term is rebuilt from the full session history each
    turn and then persisted; if the client resends complete history, the
    stored short-term memory and session history overlap — confirm intended.
    """

    client = InferenceClient(token=(hf_token.token if hf_token else None), model=MODEL_ID)
    user_id = get_user_id(hf_token)
    memory = load_memory(user_id) if enable_persistent_memory else {"short_term": [], "long_term": ""}

    session_history = normalize_history(history)
    combined = memory.get("short_term", []) + session_history

    # Anything beyond SHORT_TERM_LIMIT is summarized into long-term memory
    # rather than dropped outright.
    if len(combined) > SHORT_TERM_LIMIT:
        to_summarize = combined[:len(combined) - SHORT_TERM_LIMIT]
        summary = summarize_old_messages(client, to_summarize)
        if summary:
            memory["long_term"] = (memory.get("long_term", "") + "\n" + summary).strip()
        combined = combined[-SHORT_TERM_LIMIT:]

    combined.append({"role": "user", "content": message})
    memory["short_term"] = combined
    # Saved before inference so the user's turn survives even if the model call fails.
    if enable_persistent_memory:
        save_memory(user_id, memory)

    messages = [{"role": "system", "content": system_message}]
    if memory.get("long_term"):
        messages.append({"role": "system", "content": "Long-term memory:\n" + memory["long_term"]})
    messages.extend(memory["short_term"])

    # Crude keyword trigger for web search ("tin tức" is Vietnamese for "news").
    if enable_search and any(k in message.lower() for k in ["search", "google", "tin tức", "news", "what is"]):
        sr = web_search(message)
        messages.append({"role": "user", "content": f"{sr}\n\nBased on search results, answer: {message}"})

    response = ""
    try:
        # Stream tokens, tolerating both dict-style and attribute-style chunks;
        # falls back to a full message payload when no delta content is present.
        for chunk in client.chat_completion(messages, max_tokens=int(max_tokens),
                                            stream=True, temperature=float(temperature), top_p=float(top_p)):
            choices = chunk.get("choices") if isinstance(chunk, dict) else getattr(chunk, "choices", None)
            if not choices: continue
            c0 = choices[0]
            delta = c0.get("delta") if isinstance(c0, dict) else getattr(c0, "delta", None)
            token = None
            if delta and (delta.get("content") if isinstance(delta, dict) else getattr(delta, "content", None)):
                token = delta.get("content") if isinstance(delta, dict) else getattr(delta, "content", None)
            else:
                msg = c0.get("message") if isinstance(c0, dict) else getattr(c0, "message", None)
                if isinstance(msg, dict):
                    token = msg.get("content", "")
                else:
                    token = getattr(msg, "content", None) or str(msg or "")
            if token:
                response += token
                # Yield the accumulated text so the UI shows a growing reply.
                yield response
    except Exception as e:
        yield f"⚠️ Inference error: {e}"
        return

    # Record the finished assistant reply and trim to the short-term window.
    memory["short_term"].append({"role": "assistant", "content": response})
    memory["short_term"] = memory["short_term"][-SHORT_TERM_LIMIT:]
    if enable_persistent_memory:
        save_memory(user_id, memory)

# ---------------- Gradio UI ----------------
# ChatInterface binds respond() to the chat box; additional_inputs map 1:1 onto
# respond's extra parameters (system_message ... enable_persistent_memory).
chatbot = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=[
        gr.Textbox(value="You are a helpful AI assistant.", label="System message"),
        gr.Slider(1, 2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
        gr.Checkbox(value=True, label="Enable Web Search 🔍"),
        gr.Checkbox(value=True, label="Enable Persistent Memory"),
    ],
)

with gr.Blocks(title="AI Chatbot (full version)") as demo:
    gr.Markdown("# 🤖 AI Chatbot with Memory + Web Search + Datasets")
    with gr.Sidebar():
        # Login enables OAuth so get_user_id can key memory files per user.
        gr.LoginButton()
        gr.Markdown("### Memory Tools")
        # NOTE(review): the output Textboxes are created inline in .click();
        # confirm they render where intended in the sidebar layout.
        gr.Button("👀 Show Memory").click(show_memory, inputs=None, outputs=gr.Textbox(label="Memory"))
        gr.Button("🗑️ Clear Memory").click(clear_memory, inputs=None, outputs=gr.Textbox(label="Status"))
    chatbot.render()

if __name__ == "__main__":
    demo.launch()