Update ai_engine.py
ai_engine.py  +43 -0
ai_engine.py
CHANGED
@@ -2,10 +2,53 @@ import os
 import json
 import requests
 import re
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
+from huggingface_hub import login, hf_hub_download
+
 
 API_KEY = os.getenv("OPENROUTER_API_KEY")
 MODEL = os.getenv("OPENROUTER_MODEL", "google/gemma-2-9b-it:free")
 
+def load_model(repo_id):
+    if not repo_id:
+        yield "Please enter a repo ID."
+        return
+
+    yield "Loading model...", state, gr.update(visible=False)
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(repo_id)
+        model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
+        state.update({"model": model, "tokenizer": tokenizer, "stopping_criteria": StoppingCriteriaList([StopOnNewline(tokenizer)])})
+    except Exception as e:
+        yield f"❌ Error loading model: {e}", state, gr.update(visible=False)
+        return
+
+    for status_update in knowledge_base.build_or_load(repo_id):
+        yield status_update, state, gr.update(visible=False)
+
+    final_status = "✅ Model and KB are ready."
+    yield final_status, state, gr.update(visible=True)
+
+def respond(state, message, history, max_len, temp):
+    model, tokenizer, stopping_criteria = state["model"], state["tokenizer"], state["stopping_criteria"]
+    if not model:
+        history.append((message, "Model not loaded.")); return history
+
+    context = knowledge_base.search(message, k=5)
+    prompt = f"Context:\n{context}\n\nQuestion: {message}\n\nAnswer:"
+    inputs = tokenizer(prompt, return_tensors="pt")
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": int(max_len), "temperature": float(temp), "do_sample": True, "stopping_criteria": stopping_criteria}
+
+    Thread(target=model.generate, kwargs=generation_kwargs).start()
+    history.append((message, ""))
+    for new_text in streamer:
+        history[-1] = (message, history[-1][1] + new_text)
+        yield history
+
+
 # Singleton for embedding model
 _embed_model = None
 
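Note that `load_model` builds its stopping criteria from a `StopOnNewline(tokenizer)` class that does not appear in this hunk, so it presumably lives elsewhere in ai_engine.py. A minimal sketch of what such a criterion could look like, assuming it simply halts generation at the first newline token (the class body below is an assumption, not part of this commit):

# Hypothetical sketch only; the real StopOnNewline is defined outside this diff.
import torch
from transformers import StoppingCriteria

class StopOnNewline(StoppingCriteria):
    """Stop generation as soon as the last generated token decodes to a newline."""

    def __init__(self, tokenizer):
        self.newline_ids = set(tokenizer.encode("\n", add_special_tokens=False))

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # input_ids has shape (batch, seq_len); inspect the most recent token.
        return input_ids[0, -1].item() in self.newline_ids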
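The new functions also lean on a `knowledge_base` helper (`build_or_load(repo_id)` yielding status strings, `search(message, k=5)` returning a context string) and on `Thread`, which needs `from threading import Thread`; none of these appear in the hunk, so they are presumably imported or defined elsewhere in the Space. A rough sketch of the interface the new code expects, assuming a sentence-transformers embedding index (module layout and model name are assumptions):

# knowledge_base.py - hypothetical interface matching the calls in ai_engine.py.
import numpy as np
from sentence_transformers import SentenceTransformer

_chunks = []        # text chunks extracted from the target repo
_embeddings = None  # (num_chunks, dim) array of normalized chunk embeddings
_embedder = None

def build_or_load(repo_id):
    """Generator: yield status strings while (re)building the index for repo_id."""
    global _embedder, _embeddings
    yield f"Building knowledge base for {repo_id}..."
    if _embedder is None:
        _embedder = SentenceTransformer("all-MiniLM-L6-v2")
    # ...chunking of the repo's files would happen here...
    if _chunks:
        _embeddings = _embedder.encode(_chunks, normalize_embeddings=True)
    yield "Knowledge base ready."

def search(query, k=5):
    """Return the k most similar chunks joined into a single context string."""
    if _embeddings is None:
        return ""
    q = _embedder.encode([query], normalize_embeddings=True)
    scores = (_embeddings @ q.T).ravel()  # cosine similarity on normalized vectors
    top = np.argsort(scores)[::-1][:k]
    return "\n\n".join(_chunks[i] for i in top)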
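The tuples these generators yield (status string, state dict, `gr.update(...)` visibility toggle, streamed chat history) look like Gradio event-handler outputs, but the UI wiring is not part of this commit; note also that, as committed, `load_model` reads `state` and `gr` as globals rather than taking them as parameters. A hedged sketch of how an app.py could hook the handlers up, assuming a Blocks layout where the chatbot stays hidden until loading finishes (all component names below are made up):

# Hypothetical app.py wiring (not part of this commit).
import gradio as gr
from ai_engine import load_model, respond

with gr.Blocks() as demo:
    state = gr.State({"model": None, "tokenizer": None, "stopping_criteria": None})
    repo_box = gr.Textbox(label="Repo ID")
    load_btn = gr.Button("Load model")
    status = gr.Markdown()
    chatbot = gr.Chatbot(visible=False)
    msg = gr.Textbox(label="Message")
    max_len = gr.Slider(16, 1024, value=256, label="Max new tokens")
    temp = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature")

    # load_model yields (status, state, visibility-update) tuples as it progresses.
    load_btn.click(load_model, inputs=[repo_box], outputs=[status, state, chatbot])
    # respond streams the growing chat history back into the chatbot.
    msg.submit(respond, inputs=[state, msg, chatbot, max_len, temp], outputs=[chatbot])

demo.launch()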