File size: 14,568 Bytes
6c4f151
ffb81a0
 
 
 
6c4f151
ffb81a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c4f151
ffb81a0
 
 
 
 
 
 
 
 
 
 
 
 
6c4f151
 
ffb81a0
6c4f151
ffb81a0
 
6c4f151
 
 
ffb81a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c4f151
 
ffb81a0
6c4f151
 
ffb81a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c4f151
ffb81a0
 
6c4f151
ffb81a0
 
6c4f151
ffb81a0
 
 
 
 
 
 
 
 
 
 
 
 
6c4f151
ffb81a0
6c4f151
ffb81a0
 
 
 
 
 
 
 
 
 
 
6c4f151
 
ffb81a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c4f151
 
ffb81a0
6c4f151
ffb81a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c4f151
ffb81a0
 
 
 
6c4f151
 
ffb81a0
 
 
 
6c4f151
 
ffb81a0
 
 
 
 
 
6c4f151
 
ffb81a0
6c4f151
ffb81a0
 
 
 
 
 
 
6c4f151
ffb81a0
 
 
 
 
6c4f151
ffb81a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c4f151
ffb81a0
 
 
 
6c4f151
ffb81a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c4f151
ffb81a0
 
 
6c4f151
ffb81a0
 
 
6c4f151
 
ffb81a0
6c4f151
ffb81a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c4f151
ffb81a0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412

import hashlib
import os
import time
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import torch

# Optional deps (web search + vector store)
# duckduckgo_search has shipped two incompatible APIs over time: the legacy
# `ddg(...)` function and the newer `DDGS` client class. At most one of the
# two globals below ends up non-None (both stay None when the package is
# missing entirely); `search_web` copes with every combination.
ddg = None
DDGS = None
try:
    from duckduckgo_search import ddg as _ddg
    ddg = _ddg
except Exception:
    try:
        from duckduckgo_search import DDGS as _DDGS
        DDGS = _DDGS
    except Exception:
        ddg = None
        DDGS = None

# chromadb is optional; a None marker routes storage to the in-memory fallback.
try:
    import chromadb
except Exception:
    chromadb = None

from sentence_transformers import SentenceTransformer

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)

# Optional quantization (4-bit on GPU)
# BitsAndBytesConfig import can fail on older transformers or CPU-only installs.
BITSANDBYTES_AVAILABLE = False
try:
    from transformers import BitsAndBytesConfig
    BITSANDBYTES_AVAILABLE = True
except Exception:
    BITSANDBYTES_AVAILABLE = False

# ===============================
# 1) Model Setup (Llama-3.1-8B-Instruct)
# ===============================
# Both values come from the environment so deployments can swap the model or
# supply gated-model credentials without code changes.
MODEL_ID = os.getenv("MODEL_ID", "meta-llama/Meta-Llama-3.1-8B-Instruct")
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")

print("🚀 Loading Billy AI model...")

# Tokenizer
# Older transformers releases accept `use_auth_token=` instead of `token=`;
# the TypeError fallback keeps both versions working.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
except TypeError:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_auth_token=HF_TOKEN)

if tokenizer.pad_token_id is None:
    # Fallback to eos as pad if not set
    tokenizer.pad_token_id = tokenizer.eos_token_id

def _gpu_bf16_supported() -> bool:
    try:
        return torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    except Exception:
        return False

def _model_device(m) -> torch.device:
    try:
        return next(m.parameters()).device
    except Exception:
        return torch.device("cpu")

# Build the from_pretrained kwargs based on available hardware:
#  - GPU + bitsandbytes: 4-bit NF4 quantization (bf16 compute when supported)
#  - GPU without bitsandbytes: plain half precision via device_map="auto"
#  - CPU only: float32 (slow fallback)
load_kwargs: Dict[str, Any] = {}
if torch.cuda.is_available():
    if BITSANDBYTES_AVAILABLE:
        print("⚙️ Using 4-bit quantization (bitsandbytes).")
        compute_dtype = torch.bfloat16 if _gpu_bf16_supported() else torch.float16
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=compute_dtype,
        )
        load_kwargs.update(dict(device_map="auto", quantization_config=bnb_config, token=HF_TOKEN))
    else:
        print("⚙️ No bitsandbytes: loading in half precision on GPU.")
        load_kwargs.update(dict(device_map="auto",
                                torch_dtype=torch.bfloat16 if _gpu_bf16_supported() else torch.float16,
                                token=HF_TOKEN))
else:
    print("⚠️ No GPU detected: CPU load (slow). Consider a smaller model or enable GPU runtime.")
    load_kwargs.update(dict(torch_dtype=torch.float32, token=HF_TOKEN))

# Load model with fallbacks for auth kwarg differences
# (mirrors the tokenizer load above: `token=` on current transformers,
# `use_auth_token=` on older releases).
try:
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
except TypeError:
    load_kwargs.pop("token", None)
    try:
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
    except TypeError:
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, use_auth_token=HF_TOKEN, **load_kwargs)

# Cache the device once; generate_text moves input tensors to it per call.
MODEL_DEVICE = _model_device(model)
print(f"✅ Model loaded on: {MODEL_DEVICE}")

# ===============================
# 2) Lightweight RAG (Embeddings + Optional Chroma + In-Memory Fallback)
# ===============================
# Sentence embeddings are mandatory for the RAG features, so a load failure
# here is fatal (unlike the optional Chroma/web-search deps).
try:
    embedder = SentenceTransformer("all-MiniLM-L6-v2")
    print("✅ Embedding model loaded.")
except Exception as e:
    raise RuntimeError(f"Embedding model load failed: {e}")

# Optional Chroma persistent store; fallback to in-memory store if unavailable.
chroma_client = None
collection = None
if chromadb is not None:
    try:
        chroma_client = chromadb.PersistentClient(path="./billy_rag_db")
        # get-then-create: reuse an existing collection across restarts.
        try:
            collection = chroma_client.get_collection("billy_rag")
        except Exception:
            collection = chroma_client.create_collection("billy_rag")
        print("✅ ChromaDB ready.")
    except Exception as e:
        print(f"⚠️ ChromaDB init failed: {e}; falling back to in-memory store.")

# In-memory store: list of dicts {text, embedding}
memory_store: List[Dict[str, Any]] = []

def _stable_id(text: str) -> str:
    return hashlib.sha1(text.encode("utf-8")).hexdigest()

def search_web(query: str, max_results: int = 3) -> List[str]:
    """Fetch up to *max_results* text snippets for *query* via DuckDuckGo.

    Tries the legacy ``ddg`` function first, then the modern ``DDGS`` client.
    Returns an empty list when neither backend is available or both fail.
    """
    # Legacy function-style API (older duckduckgo_search releases).
    if ddg is not None:
        try:
            try:
                raw = ddg(query, max_results=max_results)
            except TypeError:
                # Some versions require the keyword form.
                raw = ddg(keywords=query, max_results=max_results)
            texts = [
                (r.get("body") or r.get("snippet") or r.get("title") or "")
                for r in (raw or [])
                if r
            ]
            return [t for t in texts if t and t.strip()]
        except Exception:
            pass

    # Modern client-style API.
    if DDGS is not None:
        try:
            with DDGS() as client:
                raw = list(client.text(query, max_results=max_results))
            texts = [
                # r keys differ slightly in DDGS()
                (r.get("body") or r.get("snippet") or r.get("title") or r.get("href") or "")
                for r in (raw or [])
                if r
            ]
            return [t for t in texts if t and t.strip()]
        except Exception:
            pass

    return []

def store_knowledge(text: str):
    """Embed *text* and persist it — Chroma when available, else in memory.

    Blank input and embedding failures are silently skipped so callers can
    ingest best-effort without wrapping every call in try/except.
    """
    if not (text and text.strip()):
        return
    try:
        embedding = embedder.encode(text).tolist()
    except Exception:
        # Embedding failed: drop the document rather than crash the caller.
        return

    stored = False
    if collection is not None:
        try:
            collection.add(
                documents=[text],
                embeddings=[embedding],
                ids=[_stable_id(text)],
                metadatas=[{"source": "web_or_local"}],
            )
            stored = True
        except Exception:
            stored = False
    if not stored:
        # Fallback: in-memory list of {text, embedding} dicts.
        memory_store.append({"text": text, "embedding": embedding})

def _cosine(a: List[float], b: List[float]) -> float:
    s = 0.0
    na = 0.0
    nb = 0.0
    for x, y in zip(a, b):
        s += x * y
        na += x * x
        nb += y * y
    na = na ** 0.5 or 1.0
    nb = nb ** 0.5 or 1.0
    return s / (na * nb)

def retrieve_knowledge(query: str, k: int = 5) -> str:
    """Return up to *k* stored snippets most relevant to *query*, space-joined.

    Prefers the Chroma collection when present; otherwise ranks the in-memory
    store by cosine similarity. Returns "" when nothing is retrievable.
    """
    try:
        query_vec = embedder.encode(query).tolist()
    except Exception:
        return ""

    # Chroma path, when the persistent collection initialized successfully.
    if collection is not None:
        try:
            res = collection.query(query_embeddings=[query_vec], n_results=k)
            docs = res.get("documents", [])
            if docs and docs[0]:
                return " ".join(docs[0])
        except Exception:
            pass

    # In-memory path: rank every stored snippet by similarity, keep top-k.
    if not memory_store:
        return ""
    ranked = sorted(
        memory_store,
        key=lambda item: _cosine(query_vec, item["embedding"]),
        reverse=True,
    )
    return " ".join(item["text"] for item in ranked[:k])

# ===============================
# 3) Generation Utilities
# ===============================
def build_messages(system_prompt: str, chat_history: List[Tuple[str, str]], user_prompt: str) -> List[Dict[str, str]]:
    """Assemble an OpenAI-style message list: system, prior turns, new user turn.

    *chat_history* is a sequence of (user, assistant) pairs; an empty/None
    entry on either side of a pair is skipped.
    """
    history_msgs = [
        {"role": role, "content": content}
        for user_turn, assistant_turn in (chat_history or [])
        for role, content in (("user", user_turn), ("assistant", assistant_turn))
        if content
    ]
    return (
        [{"role": "system", "content": system_prompt}]
        + history_msgs
        + [{"role": "user", "content": user_prompt}]
    )

def apply_chat_template_from_messages(messages: List[Dict[str, str]]) -> str:
    """Render *messages* into a prompt string via the tokenizer's chat template.

    When the tokenizer has no usable template, falls back to a minimal
    instruct-style prompt built from the *last* system and user messages
    (assistant turns are dropped in the fallback).
    """
    try:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception:
        pass
    # Minimal fallback: keep only the most recent system and user content.
    last = {"system": "", "user": ""}
    for m in messages:
        if m["role"] in last:
            last[m["role"]] = m["content"]
    system_text = (last["system"] or "").strip()
    user_text = (last["user"] or "").strip()
    header = f"{system_text}\n\n" if system_text else ""
    return f"{header}User: {user_text}\nAssistant:"

def _get_eos_token_id():
    """Return a single eos token id for generate().

    Some tokenizers (e.g. Llama-3) expose a *list* of eos ids; the first one
    is used. A non-list value (including None) is passed through unchanged.
    """
    eos = getattr(tokenizer, "eos_token_id", None)
    if isinstance(eos, list) and len(eos) > 0:
        eos = eos[0]
    return eos

def generate_text(prompt_text: str,
                  max_tokens: int = 600,
                  temperature: float = 0.6,
                  top_p: float = 0.9) -> str:
    """Sample a completion for *prompt_text* and return it without the prompt.

    max_tokens is capped at 2048 new tokens. When the decoded output repeats
    the prompt verbatim (no chat template in play), the echo is stripped.
    """
    encoded = tokenizer(prompt_text, return_tensors="pt")
    encoded = {name: tensor.to(MODEL_DEVICE) for name, tensor in encoded.items()}
    generated = model.generate(
        **encoded,
        max_new_tokens=min(max_tokens, 2048),
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=_get_eos_token_id(),
    )
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    # Best-effort: drop the echoed prompt when the decode reproduced it.
    if decoded.startswith(prompt_text):
        decoded = decoded[len(prompt_text):]
    return decoded.strip()

def summarize_text(text: str) -> str:
    """Ask the model for a short (max 6 bullets) summary of *text*."""
    prompt = apply_chat_template_from_messages(
        build_messages(
            "You are Billy AI — a precise, helpful summarizer.",
            [],
            f"Summarize the following text in simple, clear bullet points (max 6 bullets):\n\n{text}",
        )
    )
    # Low temperature: summaries should be faithful, not creative.
    return generate_text(prompt, max_tokens=220, temperature=0.3, top_p=0.9)

def translate_text(text: str, lang: str) -> str:
    """Ask the model to translate *text* into *lang*."""
    prompt = apply_chat_template_from_messages(
        build_messages(
            "You are Billy AI — an expert translator.",
            [],
            f"Translate the following text to {lang} while preserving meaning and tone:\n\n{text}",
        )
    )
    # Low temperature: translations should be faithful, not creative.
    return generate_text(prompt, max_tokens=220, temperature=0.3, top_p=0.9)

def explain_code(code: str) -> str:
    """Ask the model for a step-by-step explanation of *code*."""
    request = (
        "Explain the following code step by step for a mid-level developer. "
        "Include what it does, complexity, pitfalls, and an improved version if relevant.\n\n"
        f"{code}"
    )
    prompt = apply_chat_template_from_messages(
        build_messages("You are Billy AI — an expert software engineer and teacher.", [], request)
    )
    return generate_text(prompt, max_tokens=400, temperature=0.5, top_p=0.9)

# ===============================
# 4) Chat Orchestration
# ===============================
def make_system_prompt(local_knowledge: str) -> str:
    """Build Billy's system prompt, appending at most 3000 chars of RAG context."""
    parts = [
        "You are Billy AI — a helpful, witty, and precise assistant. "
        "You tend to outperform GPT-3.5 on reasoning, explanation, and coding tasks. "
        "Be concise but thorough; use bullet points for clarity; cite assumptions; avoid hallucinations."
    ]
    if local_knowledge:
        # Truncate retrieved context so it cannot crowd out the conversation.
        parts.append(f"\nUseful context: {local_knowledge[:3000]}")
    return "".join(parts)

def _ingest_search(query: str, max_results: int = 3) -> int:
    """Web-search *query*, store every snippet, and return the snippet count."""
    count = 0
    for snippet in search_web(query, max_results=max_results):
        store_knowledge(snippet)
        count += 1
    return count

def _parse_translate_command(cmd: str) -> Tuple[Optional[str], Optional[str]]:
    # Supports patterns:
    # /translate <lang>: <text>
    # /translate <lang> | <text>
    # /translate <lang> <text>
    rest = cmd[len("/translate"):].strip()
    if not rest:
        return None, None
    # Try separators
    for sep in [":", "|"]:
        if sep in rest:
            lang, text = rest.split(sep, 1)
            return lang.strip(), text.strip()
    parts = rest.split(None, 1)
    if len(parts) == 2:
        return parts[0].strip(), parts[1].strip()
    return None, None

def handle_message(message: str, chat_history: List[Tuple[str, str]]) -> str:
    """Route one user message: slash commands first, then RAG-augmented chat.

    Commands (prefix-matched case-insensitively on the stripped message):
        /summarize <text>          — bullet-point summary
        /explain <code>            — code walkthrough
        /translate <lang>: <text>  — translation
        /search <query>            — web search, results ingested as context
        /remember <text>           — store text in the knowledge base
    Anything else is answered by the model with retrieved local context.

    Fixes vs. the original:
      * All payload slicing now uses the *stripped* ``msg`` — the prefix
        checks run on ``low = msg.lower()``, so slicing the raw ``message``
        mis-aligned the payload whenever the input had leading whitespace.
      * An empty "/summarize " payload now returns "Nothing to summarize."
        directly; before, that literal string was itself sent to the model
        to be summarized.
    """
    msg = (message or "").strip()
    if not msg:
        return "Please send a non-empty message."

    low = msg.lower()
    if low.startswith("/summarize "):
        text = msg[len("/summarize "):].strip()
        if not text:
            return "Nothing to summarize."
        return summarize_text(text)
    if low.startswith("/explain "):
        return explain_code(msg[len("/explain "):].strip())
    if low.startswith("/translate"):
        lang, txt = _parse_translate_command(msg)
        if not lang or not txt:
            return "Usage: /translate <lang>: <text>"
        return translate_text(txt, lang)
    if low.startswith("/search "):
        q = msg[len("/search "):].strip()
        if not q:
            return "Usage: /search <query>"
        n = _ingest_search(q, max_results=5)
        ctx = retrieve_knowledge(q, k=5)
        if n == 0 and not ctx:
            return "No results found or web search unavailable."
        return f"Ingested {n} snippet(s). Context now includes:\n\n{ctx[:1000]}"

    if low.startswith("/remember "):
        t = msg[len("/remember "):].strip()
        if not t:
            return "Usage: /remember <text>"
        store_knowledge(t)
        return "Saved to knowledge base."

    # Default path — RAG: retrieve related knowledge into the system prompt.
    local_knowledge = retrieve_knowledge(msg, k=5)
    system_prompt = make_system_prompt(local_knowledge)

    messages = build_messages(system_prompt, chat_history, msg)
    prompt = apply_chat_template_from_messages(messages)
    return generate_text(prompt, max_tokens=600, temperature=0.6, top_p=0.9)

# ===============================
# 5) Gradio UI
# ===============================
def respond(message, history):
    """Gradio callback: normalize history pairs and delegate to handle_message.

    *history* arrives as a list of [user, assistant] pairs; malformed entries
    are skipped and None values become "". Any exception from the pipeline is
    surfaced to the UI as an "Error: ..." string instead of crashing the app.
    """
    pairs: List[Tuple[str, str]] = []
    for turn in history or []:
        if not (isinstance(turn, (list, tuple)) and len(turn) == 2):
            continue
        user_text = "" if turn[0] is None else turn[0]
        bot_text = "" if turn[1] is None else turn[1]
        pairs.append((str(user_text), str(bot_text)))
    try:
        return handle_message(message, pairs)
    except Exception as e:
        return f"Error: {e}"

# Gradio UI: a Blocks page hosting a ChatInterface backed by `respond`.
with gr.Blocks(title="Billy AI") as demo:
    gr.Markdown("## Billy AI")
    # Quick reference for the slash commands understood by handle_message.
    gr.Markdown(
        "Commands: /summarize <text>, /explain <code>, /translate <lang>: <text>, /search <query>, /remember <text>"
    )
    chat = gr.ChatInterface(
        fn=respond,
        title="Billy AI",
        theme="soft",
        cache_examples=False,
    )

if __name__ == "__main__":
    # Share=False by default; set to True if you want a public link
    demo.launch()