broadfield-dev committed
Commit 3eb9ffa · verified · 1 Parent(s): d0addd7

Update ai_engine.py

Files changed (1)
1. ai_engine.py +43 -0
ai_engine.py CHANGED
@@ -2,10 +2,53 @@ import os
 import json
 import requests
 import re
+import torch
+from threading import Thread
+import gradio as gr
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
+from huggingface_hub import login, hf_hub_download
+
 
 API_KEY = os.getenv("OPENROUTER_API_KEY")
 MODEL = os.getenv("OPENROUTER_MODEL", "google/gemma-2-9b-it:free")
 
+def load_model(repo_id, state):
+    if not repo_id:
+        yield "Please enter a repo ID.", state, gr.update(visible=False)
+        return
+
+    yield "Loading model...", state, gr.update(visible=False)
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(repo_id)
+        model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
+        state.update({"model": model, "tokenizer": tokenizer, "stopping_criteria": StoppingCriteriaList([StopOnNewline(tokenizer)])})
+    except Exception as e:
+        yield f"❌ Error loading model: {e}", state, gr.update(visible=False)
+        return
+
+    # Build (or load a cached) knowledge base for this repo before revealing the chat UI.
+    for status_update in knowledge_base.build_or_load(repo_id):
+        yield status_update, state, gr.update(visible=False)
+
+    final_status = "✅ Model and KB are ready."
+    yield final_status, state, gr.update(visible=True)
+
+def respond(state, message, history, max_len, temp):
+    model, tokenizer, stopping_criteria = state.get("model"), state.get("tokenizer"), state.get("stopping_criteria")
+    if not model:
+        history.append((message, "Model not loaded."))
+        yield history
+        return
+
+    # Retrieve the top-5 matching chunks and build a RAG-style prompt.
+    context = knowledge_base.search(message, k=5)
+    prompt = f"Context:\n{context}\n\nQuestion: {message}\n\nAnswer:"
+    inputs = tokenizer(prompt, return_tensors="pt")
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": int(max_len), "temperature": float(temp), "do_sample": True, "stopping_criteria": stopping_criteria}
+
+    # Generate in a background thread and stream tokens into the last chat turn.
+    Thread(target=model.generate, kwargs=generation_kwargs).start()
+    history.append((message, ""))
+    for new_text in streamer:
+        history[-1] = (message, history[-1][1] + new_text)
+        yield history
+
+
 # Singleton for embedding model
 _embed_model = None
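
Note: StopOnNewline is referenced by load_model but is not defined anywhere in this diff, so it is presumably added elsewhere in ai_engine.py. A minimal sketch of what such a stopping criterion could look like; the class body below is an assumption, not the committed code:

import torch
from transformers import StoppingCriteria

class StopOnNewline(StoppingCriteria):
    # Hypothetical reconstruction: ends generation once the model emits a newline.
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        # Decode only the newest token; any newline in it stops generation.
        last_token = self.tokenizer.decode(input_ids[0, -1])
        return "\n" in last_token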
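
Both new functions are generators that yield Gradio updates ((status, state, gr.update(...)) tuples from load_model, a growing chat history from respond), so they are presumably driven by a Blocks app elsewhere in the repo. A sketch of that wiring, assuming hypothetical component names (repo_id, chat_col, and the rest are my own, not the app's):

import gradio as gr

with gr.Blocks() as demo:
    state = gr.State({})  # holds model, tokenizer, stopping_criteria
    repo_id = gr.Textbox(label="Repo ID")
    load_btn = gr.Button("Load model")
    status = gr.Markdown()
    with gr.Column(visible=False) as chat_col:  # revealed by gr.update(visible=True)
        chatbot = gr.Chatbot()
        message = gr.Textbox(label="Message")
        max_len = gr.Slider(16, 1024, value=256, label="Max new tokens")
        temp = gr.Slider(0.1, 2.0, value=0.7, label="Temperature")

    # load_model streams (status, state, visibility) updates while it works.
    load_btn.click(load_model, [repo_id, state], [status, state, chat_col])
    # respond streams the updated history back into the Chatbot component.
    message.submit(respond, [state, message, chatbot, max_len, temp], [chatbot])

demo.launch()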