Zynara committed on
Commit
ffb81a0
·
verified ·
1 Parent(s): 6c4f151

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +370 -96
main.py CHANGED
@@ -1,137 +1,411 @@
1
 
 
 
 
 
2
  import torch
3
- from fastapi import FastAPI
4
- from pydantic import BaseModel
5
- from duckduckgo_search import ddg
6
- import chromadb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  from sentence_transformers import SentenceTransformer
8
- from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  # ===============================
11
- # 1️⃣ Load Model (Llama-3-8B-Instruct)
12
  # ===============================
13
- MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
 
14
 
15
  print("🚀 Loading Billy AI model...")
16
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
17
- model = AutoModelForCausalLM.from_pretrained(
18
- MODEL_ID,
19
- torch_dtype=torch.float32, # CPU-friendly
20
- device_map="auto"
21
- )
22
 
23
- def generate_text(prompt: str, max_tokens: int = 512) -> str:
24
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
25
- output = model.generate(
26
- **inputs,
27
- max_new_tokens=max_tokens,
28
- do_sample=True,
29
- temperature=0.7,
30
- top_p=0.9
31
- )
32
- return tokenizer.decode(output[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  # ===============================
35
- # 2️⃣ Setup RAG (Memory + Search)
36
  # ===============================
37
- db = chromadb.PersistentClient(path="./billy_rag_db")
38
  try:
39
- collection = db.get_collection("billy_rag")
40
- except:
41
- collection = db.create_collection("billy_rag")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- embedder = SentenceTransformer("all-MiniLM-L6-v2")
 
44
 
45
- def search_web(query: str):
 
46
  try:
47
- results = ddg(query, max_results=3)
48
- return [r.get("body") or r.get("snippet") or "" for r in results if r]
49
- except:
50
- return []
 
 
 
 
 
 
 
 
 
51
 
52
- def store_knowledge(text: str):
53
- vec = embedder.encode(text).tolist()
54
  try:
55
- collection.add(documents=[text], embeddings=[vec], ids=[str(abs(hash(text)))])
56
- except:
 
 
 
 
 
 
 
 
 
57
  pass
58
 
59
- def retrieve_knowledge(query: str) -> str:
60
- vec = embedder.encode(query).tolist()
61
- results = collection.query(query_embeddings=[vec], n_results=3)
62
- return " ".join(results["documents"][0]) if results and results["documents"] else ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  # ===============================
65
- # 3️⃣ Tool Functions
66
  # ===============================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  def summarize_text(text: str) -> str:
68
- prompt = f"Summarize the following text in simple terms:\n\n{text}\n\nSummary:"
69
- return generate_text(prompt, max_tokens=200)
 
 
70
 
71
  def translate_text(text: str, lang: str) -> str:
72
- prompt = f"Translate the following text to {lang}:\n\n{text}\n\nTranslation:"
73
- return generate_text(prompt, max_tokens=200)
 
 
74
 
75
  def explain_code(code: str) -> str:
76
- prompt = f"Explain the following code in simple terms:\n\n```{code}```\n\nExplanation:"
77
- return generate_text(prompt, max_tokens=300)
 
 
 
 
78
 
79
  # ===============================
80
- # 4️⃣ FastAPI App
81
  # ===============================
82
- app = FastAPI(title="Billy AI - Free Chatbot")
 
 
 
 
 
 
83
 
84
- class Query(BaseModel):
85
- message: str
86
- user_id: str = "anonymous"
 
 
87
 
88
- @app.post("/chat")
89
- def chat(req: Query):
90
- user_msg = req.message.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
- # --- Special Commands ---
93
- if user_msg.lower().startswith("/summarize "):
94
- return {"response": summarize_text(user_msg[11:])}
 
95
 
96
- if user_msg.lower().startswith("/translate "):
97
- try:
98
- lang, text = user_msg[10:].split(" ", 1)
99
- return {"response": translate_text(text, lang)}
100
- except:
101
- return {"response": "Format: /translate <language> <text>"}
102
-
103
- if user_msg.lower().startswith("/explaincode "):
104
- return {"response": explain_code(user_msg[13:])}
105
-
106
- # --- Search & RAG ---
107
- local_knowledge = retrieve_knowledge(user_msg)
108
-
109
- if not local_knowledge:
110
- web_results = search_web(user_msg)
111
- for r in web_results:
112
- if r.strip():
113
- store_knowledge(r)
114
- local_knowledge = " ".join(web_results)
115
-
116
- # --- Personality & Context ---
117
- context = (
118
- "You are Billy AI — a helpful, witty, and slightly funny AI assistant. "
119
- "You are a bit smarter than GPT-3.5, but not too advanced. "
120
- "When answering, be friendly, concise, and give useful info. "
121
- f"Use this info if helpful: {local_knowledge}\n\n"
122
- f"User: {user_msg}\nAssistant:"
123
- )
124
 
125
- reply = generate_text(context)
126
- return {"response": reply.strip()}
 
127
 
128
- @app.get("/")
129
- def home():
130
- return {"message": "Billy AI is running and ready to chat!"}
131
 
132
  # ===============================
133
- # 5️⃣ Run
134
  # ===============================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  if __name__ == "__main__":
136
- import uvicorn
137
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
 
2
import hashlib
import os  # BUGFIX: `os.getenv` is used below (MODEL_ID / HF_TOKEN) but `os` was never imported
import time
from typing import List, Dict, Any, Tuple, Optional

import torch
import gradio as gr

# Optional deps (web search + vector store)
ddg = None
DDGS = None
try:
    # Legacy module-level function (older duckduckgo_search releases)
    from duckduckgo_search import ddg as _ddg
    ddg = _ddg
except Exception:
    try:
        # Newer client-style API
        from duckduckgo_search import DDGS as _DDGS
        DDGS = _DDGS
    except Exception:
        ddg = None
        DDGS = None

try:
    import chromadb
except Exception:
    chromadb = None

from sentence_transformers import SentenceTransformer

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)

# Optional quantization (4-bit on GPU)
BITSANDBYTES_AVAILABLE = False
try:
    from transformers import BitsAndBytesConfig
    BITSANDBYTES_AVAILABLE = True
except Exception:
    BITSANDBYTES_AVAILABLE = False
 
43
# ===============================
# 1) Model Setup (Llama-3.1-8B-Instruct)
# ===============================
# NOTE(review): these calls require `import os` at module top — verify it is present.
# MODEL_ID / HF_TOKEN are overridable via environment variables.
MODEL_ID = os.getenv("MODEL_ID", "meta-llama/Meta-Llama-3.1-8B-Instruct")
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")

print("🚀 Loading Billy AI model...")

# Tokenizer: newer transformers takes `token=`, older takes `use_auth_token=`;
# a TypeError on the first form triggers the legacy fallback.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
except TypeError:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_auth_token=HF_TOKEN)

if tokenizer.pad_token_id is None:
    # Fallback to eos as pad if not set
    tokenizer.pad_token_id = tokenizer.eos_token_id
60
+
61
+ def _gpu_bf16_supported() -> bool:
62
+ try:
63
+ return torch.cuda.is_available() and torch.cuda.is_bf16_supported()
64
+ except Exception:
65
+ return False
66
+
67
+ def _model_device(m) -> torch.device:
68
+ try:
69
+ return next(m.parameters()).device
70
+ except Exception:
71
+ return torch.device("cpu")
72
+
73
# Build from_pretrained kwargs by hardware tier: 4-bit NF4 when a GPU and
# bitsandbytes are available, plain half precision on GPU otherwise, fp32 on CPU.
load_kwargs: Dict[str, Any] = {}
if torch.cuda.is_available():
    if BITSANDBYTES_AVAILABLE:
        print("⚙️ Using 4-bit quantization (bitsandbytes).")
        compute_dtype = torch.bfloat16 if _gpu_bf16_supported() else torch.float16
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=compute_dtype,
        )
        load_kwargs.update(dict(device_map="auto", quantization_config=bnb_config, token=HF_TOKEN))
    else:
        print("⚙️ No bitsandbytes: loading in half precision on GPU.")
        load_kwargs.update(dict(device_map="auto",
                                torch_dtype=torch.bfloat16 if _gpu_bf16_supported() else torch.float16,
                                token=HF_TOKEN))
else:
    print("⚠️ No GPU detected: CPU load (slow). Consider a smaller model or enable GPU runtime.")
    load_kwargs.update(dict(torch_dtype=torch.float32, token=HF_TOKEN))

# Load model with fallbacks for auth kwarg differences:
# try `token=` first, then no auth kwarg, then legacy `use_auth_token=`.
try:
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
except TypeError:
    load_kwargs.pop("token", None)
    try:
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
    except TypeError:
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, use_auth_token=HF_TOKEN, **load_kwargs)

# Cache the resolved device once; generate_text() moves inputs to it.
MODEL_DEVICE = _model_device(model)
print(f"✅ Model loaded on: {MODEL_DEVICE}")
105
 
106
# ===============================
# 2) Lightweight RAG (Embeddings + Optional Chroma + In-Memory Fallback)
# ===============================
# The embedder is mandatory (RAG cannot work without it); Chroma is optional.
try:
    embedder = SentenceTransformer("all-MiniLM-L6-v2")
    print("✅ Embedding model loaded.")
except Exception as e:
    raise RuntimeError(f"Embedding model load failed: {e}")

# Optional Chroma persistent store; fallback to in-memory store if unavailable.
chroma_client = None
collection = None
if chromadb is not None:
    try:
        chroma_client = chromadb.PersistentClient(path="./billy_rag_db")
        try:
            # Reuse the collection when it already exists on disk.
            collection = chroma_client.get_collection("billy_rag")
        except Exception:
            collection = chroma_client.create_collection("billy_rag")
        print("✅ ChromaDB ready.")
    except Exception as e:
        # `collection` stays None; store/retrieve fall back to memory_store.
        print(f"⚠️ ChromaDB init failed: {e}; falling back to in-memory store.")

# In-memory store: list of dicts {text, embedding}
memory_store: List[Dict[str, Any]] = []
131
 
132
+ def _stable_id(text: str) -> str:
133
+ return hashlib.sha1(text.encode("utf-8")).hexdigest()
134
 
135
def search_web(query: str, max_results: int = 3) -> List[str]:
    """Return up to *max_results* DuckDuckGo result snippets; [] when search is unavailable.

    Tries the legacy module-level ``ddg()`` function first, then the newer
    ``DDGS`` client. Both paths are best-effort: any failure falls through.
    """
    # Try legacy ddg function
    try:
        if ddg is not None:
            try:
                results = ddg(query, max_results=max_results)
            except TypeError:
                # Releases disagree on positional vs keyword query argument.
                results = ddg(keywords=query, max_results=max_results)
            snippets = []
            for r in results or []:
                if not r:
                    continue
                snippets.append(r.get("body") or r.get("snippet") or r.get("title") or "")
            return [s for s in snippets if s and s.strip()]
    except Exception:
        pass  # fall through to the DDGS client

    # Try modern DDGS client
    try:
        if DDGS is not None:
            with DDGS() as d:
                results = list(d.text(query, max_results=max_results))
            snippets = []
            for r in results or []:
                if not r:
                    continue
                # r keys differ slightly in DDGS()
                snippets.append(r.get("body") or r.get("snippet") or r.get("title") or r.get("href") or "")
            return [s for s in snippets if s and s.strip()]
    except Exception:
        pass

    return []
168
+
169
def store_knowledge(text: str):
    """Embed *text* and persist it — Chroma when available, else the in-memory store.

    Best-effort: blank text or an embedding failure is silently skipped.
    """
    if not text or not text.strip():
        return
    try:
        vec = embedder.encode(text).tolist()
    except Exception:
        return  # embedding failed; skip rather than crash the ingest path
    if collection is not None:
        try:
            collection.add(
                documents=[text],
                embeddings=[vec],
                ids=[_stable_id(text)],  # content-addressed id dedupes identical texts
                metadatas=[{"source": "web_or_local"}],
            )
            return
        except Exception:
            pass  # Chroma write failed; use the in-memory fallback below
    # Fallback: in-memory
    memory_store.append({"text": text, "embedding": vec})
189
+
190
+ def _cosine(a: List[float], b: List[float]) -> float:
191
+ s = 0.0
192
+ na = 0.0
193
+ nb = 0.0
194
+ for x, y in zip(a, b):
195
+ s += x * y
196
+ na += x * x
197
+ nb += y * y
198
+ na = na ** 0.5 or 1.0
199
+ nb = nb ** 0.5 or 1.0
200
+ return s / (na * nb)
201
+
202
def retrieve_knowledge(query: str, k: int = 5) -> str:
    """Return up to *k* stored documents relevant to *query*, space-joined.

    Prefers the Chroma collection when available; otherwise ranks the
    in-memory store by cosine similarity. Returns "" when nothing is stored
    or the query cannot be embedded.
    """
    try:
        qvec = embedder.encode(query).tolist()
    except Exception:
        return ""
    # Prefer Chroma if available
    if collection is not None:
        try:
            res = collection.query(query_embeddings=[qvec], n_results=k)
            docs = res.get("documents", [])
            if docs and docs[0]:
                return " ".join(docs[0])
        except Exception:
            pass  # Chroma query failed; fall back to the in-memory scan
    # In-memory cosine top-k
    if not memory_store:
        return ""
    scored: List[Tuple[str, float]] = []
    for item in memory_store:
        scored.append((item["text"], _cosine(qvec, item["embedding"])))
    scored.sort(key=lambda x: x[1], reverse=True)
    return " ".join([t for t, _ in scored[:k]])
224
 
225
  # ===============================
226
+ # 3) Generation Utilities
227
  # ===============================
228
def build_messages(system_prompt: str, chat_history: List[Tuple[str, str]], user_prompt: str) -> List[Dict[str, str]]:
    """Assemble an OpenAI-style message list.

    Layout: one system message, then the (user, assistant) history turns in
    order (empty sides skipped), then the new user prompt.
    """
    history_msgs: List[Dict[str, str]] = []
    for user_text, assistant_text in chat_history or []:
        if user_text:
            history_msgs.append({"role": "user", "content": user_text})
        if assistant_text:
            history_msgs.append({"role": "assistant", "content": assistant_text})
    return (
        [{"role": "system", "content": system_prompt}]
        + history_msgs
        + [{"role": "user", "content": user_prompt}]
    )
238
+
239
def apply_chat_template_from_messages(messages: List[Dict[str, str]]) -> str:
    """Render *messages* to a prompt string using the tokenizer's chat template.

    When the tokenizer has no template, falls back to a minimal
    "system\n\nUser: ...\nAssistant:" layout built from the LAST system and
    user messages only (history is dropped in the fallback).
    """
    try:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception:
        # Fallback to simple instruct style if no template provided
        sys = ""  # NOTE: local name shadows the stdlib `sys` module; harmless here
        user = ""
        # Extract the last system and user message for a minimal fallback
        for m in messages:
            if m["role"] == "system":
                sys = m["content"]
            elif m["role"] == "user":
                user = m["content"]
        sys = (sys or "").strip()
        user = (user or "").strip()
        prefix = f"{sys}\n\n" if sys else ""
        return f"{prefix}User: {user}\nAssistant:"
256
+
257
def _get_eos_token_id():
    """Normalize ``tokenizer.eos_token_id`` to a single id.

    Some tokenizers expose a list of eos ids; use the first. Returns the raw
    attribute (possibly None) otherwise.
    """
    eos = getattr(tokenizer, "eos_token_id", None)
    if isinstance(eos, list):
        return eos[0] if eos else eos
    return eos
262
+
263
def generate_text(prompt_text: str,
                  max_tokens: int = 600,
                  temperature: float = 0.6,
                  top_p: float = 0.9) -> str:
    """Sample a completion for *prompt_text* and return it without the prompt echo.

    Args:
        prompt_text: Fully rendered prompt (already chat-templated).
        max_tokens: Upper bound on new tokens (hard-capped at 2048).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cutoff.
    """
    inputs = tokenizer(prompt_text, return_tensors="pt")
    # Move every input tensor to the model's device before generation.
    inputs = {k: v.to(MODEL_DEVICE) for k, v in inputs.items()}
    output_ids = model.generate(
        **inputs,
        max_new_tokens=min(max_tokens, 2048),
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=_get_eos_token_id(),
    )
    text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # Best-effort: strip the prompt echo if present
    if text.startswith(prompt_text):
        return text[len(prompt_text):].strip()
    return text.strip()
283
+
284
def summarize_text(text: str) -> str:
    """Summarize *text* into bullet points via the chat model (low temperature for fidelity)."""
    system = "You are Billy AI a precise, helpful summarizer."
    user = f"Summarize the following text in simple, clear bullet points (max 6 bullets):\n\n{text}"
    messages = build_messages(system, [], user)
    return generate_text(apply_chat_template_from_messages(messages), max_tokens=220, temperature=0.3, top_p=0.9)
289
 
290
def translate_text(text: str, lang: str) -> str:
    """Translate *text* into *lang* via the chat model, preserving meaning and tone."""
    system = "You are Billy AI an expert translator."
    user = f"Translate the following text to {lang} while preserving meaning and tone:\n\n{text}"
    messages = build_messages(system, [], user)
    return generate_text(apply_chat_template_from_messages(messages), max_tokens=220, temperature=0.3, top_p=0.9)
295
 
296
def explain_code(code: str) -> str:
    """Explain *code* step by step for a mid-level developer via the chat model."""
    system = "You are Billy AI an expert software engineer and teacher."
    user = ("Explain the following code step by step for a mid-level developer. "
            "Include what it does, complexity, pitfalls, and an improved version if relevant.\n\n"
            f"{code}")
    messages = build_messages(system, [], user)
    return generate_text(apply_chat_template_from_messages(messages), max_tokens=400, temperature=0.5, top_p=0.9)
303
 
304
  # ===============================
305
+ # 4) Chat Orchestration
306
  # ===============================
307
def make_system_prompt(local_knowledge: str) -> str:
    """Build Billy's system prompt, appending at most 3000 chars of retrieved context."""
    persona = (
        "You are Billy AI — a helpful, witty, and precise assistant. "
        "You tend to outperform GPT-3.5 on reasoning, explanation, and coding tasks. "
        "Be concise but thorough; use bullet points for clarity; cite assumptions; avoid hallucinations."
    )
    if not local_knowledge:
        return persona
    # Truncate context so a huge retrieval cannot swamp the prompt budget.
    return persona + f"\nUseful context: {local_knowledge[:3000]}"
314
 
315
def _ingest_search(query: str, max_results: int = 3) -> int:
    """Web-search *query*, store every snippet in the knowledge base, return the count stored."""
    snips = search_web(query, max_results=max_results)
    for s in snips:
        store_knowledge(s)
    return len(snips)
320
 
321
+ def _parse_translate_command(cmd: str) -> Tuple[Optional[str], Optional[str]]:
322
+ # Supports patterns:
323
+ # /translate <lang>: <text>
324
+ # /translate <lang> | <text>
325
+ # /translate <lang> <text>
326
+ rest = cmd[len("/translate"):].strip()
327
+ if not rest:
328
+ return None, None
329
+ # Try separators
330
+ for sep in [":", "|"]:
331
+ if sep in rest:
332
+ lang, text = rest.split(sep, 1)
333
+ return lang.strip(), text.strip()
334
+ parts = rest.split(None, 1)
335
+ if len(parts) == 2:
336
+ return parts[0].strip(), parts[1].strip()
337
+ return None, None
338
 
339
def handle_message(message: str, chat_history: List[Tuple[str, str]]) -> str:
    """Route one user message: slash commands first, otherwise RAG-augmented chat.

    Args:
        message: Raw user input (may start with a slash command).
        chat_history: Prior (user, assistant) turns, oldest first.

    Returns:
        The assistant's reply, or a usage hint for malformed commands.
    """
    msg = (message or "").strip()
    if not msg:
        return "Please send a non-empty message."

    # Slash commands. BUGFIX: slice `msg` (the stripped text the prefix was
    # matched against), not the raw `message` — leading whitespace previously
    # misaligned the payload slices for /explain, /search, /remember and
    # /translate.
    low = msg.lower()
    if low.startswith("/summarize "):
        return summarize_text(msg[len("/summarize "):].strip() or "Nothing to summarize.")
    if low.startswith("/explain "):
        return explain_code(msg[len("/explain "):].strip())
    if low.startswith("/translate"):
        lang, txt = _parse_translate_command(msg)
        if not lang or not txt:
            return "Usage: /translate <lang>: <text>"
        return translate_text(txt, lang)
    if low.startswith("/search "):
        q = msg[len("/search "):].strip()
        if not q:
            return "Usage: /search <query>"
        n = _ingest_search(q, max_results=5)
        ctx = retrieve_knowledge(q, k=5)
        if n == 0 and not ctx:
            return "No results found or web search unavailable."
        return f"Ingested {n} snippet(s). Context now includes:\n\n{ctx[:1000]}"

    if low.startswith("/remember "):
        t = msg[len("/remember "):].strip()
        if not t:
            return "Usage: /remember <text>"
        store_knowledge(t)
        return "Saved to knowledge base."

    # No command: retrieve related knowledge and answer with the LLM.
    local_knowledge = retrieve_knowledge(msg, k=5)
    system_prompt = make_system_prompt(local_knowledge)

    messages = build_messages(system_prompt, chat_history, msg)
    prompt = apply_chat_template_from_messages(messages)
    return generate_text(prompt, max_tokens=600, temperature=0.6, top_p=0.9)
379
 
380
  # ===============================
381
+ # 5) Gradio UI
382
  # ===============================
383
def respond(message, history):
    """Gradio callback: normalize chat history to (user, assistant) string tuples and delegate.

    Any exception from the pipeline is caught and surfaced as an error string
    so the UI never crashes.
    """
    normalized: List[Tuple[str, str]] = []
    for turn in history or []:
        if not isinstance(turn, (list, tuple)) or len(turn) != 2:
            continue  # skip malformed history entries
        user_part, assistant_part = turn
        normalized.append((
            str(user_part) if user_part is not None else "",
            str(assistant_part) if assistant_part is not None else "",
        ))
    try:
        return handle_message(message, normalized)
    except Exception as e:
        return f"Error: {e}"
396
+
397
# Build the Gradio app: a header, a command cheat-sheet, and a chat widget
# backed by `respond`.
with gr.Blocks(title="Billy AI") as demo:
    gr.Markdown("## Billy AI")
    gr.Markdown(
        "Commands: /summarize <text>, /explain <code>, /translate <lang>: <text>, /search <query>, /remember <text>"
    )
    chat = gr.ChatInterface(
        fn=respond,
        title="Billy AI",
        theme="soft",
        cache_examples=False,
    )

if __name__ == "__main__":
    # Share=False by default; set to True if you want a public link
    demo.launch()