broadfield-dev committed on
Commit dbabe41 · verified · 1 Parent(s): 3eb9ffa

Update ai_engine.py

Files changed (1)
  1. ai_engine.py +105 -38
ai_engine.py CHANGED
@@ -2,51 +2,118 @@ import os
  import json
  import requests
  import re
- import torcch
- from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
+ import torch
+ from threading import Thread
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     TextIteratorStreamer,
+     StoppingCriteria,
+     StoppingCriteriaList
+ )
  from huggingface_hub import login, hf_hub_download


  API_KEY = os.getenv("OPENROUTER_API_KEY")
  MODEL = os.getenv("OPENROUTER_MODEL", "google/gemma-2-9b-it:free")

- def load_model(repo_id):
-     if not repo_id:
-         yield "Please enter a repo ID."
-         return
-
-     yield "Loading model...", state, gr.update(visible=False)
-     try:
-         tokenizer = AutoTokenizer.from_pretrained(repo_id)
-         model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
-         state.update({"model": model, "tokenizer": tokenizer, "stopping_criteria": StoppingCriteriaList([StopOnNewline(tokenizer)])})
-     except Exception as e:
-         yield f"❌ Error loading model: {e}", state, gr.update(visible=False)
-         return
-
-     for status_update in knowledge_base.build_or_load(repo_id):
-         yield status_update, state, gr.update(visible=False)
-
-     final_status = "✅ Model and KB are ready."
-     yield final_status, state, gr.update(visible=True)

- def respond(state, message, history, max_len, temp):
-     model, tokenizer, stopping_criteria = state["model"], state["tokenizer"], state["stopping_criteria"]
-     if not model:
-         history.append((message, "Model not loaded.")); return history
-
-     context = knowledge_base.search(message, k=5)
-     prompt = f"Context:\n{context}\n\nQuestion: {message}\n\nAnswer:"
-     inputs = tokenizer(prompt, return_tensors="pt")
-     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-
-     generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": int(max_len), "temperature": float(temp), "do_sample": True, "stopping_criteria": stopping_criteria}
-
-     Thread(target=model.generate, kwargs=generation_kwargs).start()
-     history.append((message, ""))
-     for new_text in streamer:
-         history[-1] = (message, history[-1][1] + new_text)
-         yield history
+ class LocalModelHandler:
+     def __init__(self, repo_id, device=None, use_quantization=False):
+         """
+         Initializes the model and tokenizer.
+         """
+         self.repo_id = repo_id
+         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+
+         print(f"Loading local model: {repo_id} on {self.device}...")
+
+         try:
+             self.tokenizer = AutoTokenizer.from_pretrained(repo_id)
+
+             # Model loading arguments
+             load_kwargs = {
+                 "torch_dtype": torch.bfloat16 if self.device == "cuda" else torch.float32,
+                 "low_cpu_mem_usage": True,
+                 "trust_remote_code": True
+             }
+
+             # Optional: 4-bit quantization if bitsandbytes is installed
+             if use_quantization:
+                 load_kwargs["load_in_4bit"] = True
+
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 repo_id,
+                 **load_kwargs
+             )
+
+             # Move to device if not using quantization (quantization handles device placement automatically)
+             if not use_quantization:
+                 self.model.to(self.device)
+
+             print("✅ Model loaded successfully.")
+
+         except Exception as e:
+             print(f"❌ Error loading model: {e}")
+             self.model = None
+             self.tokenizer = None
+
+     def chat_stream(self, messages, max_new_tokens=512, temperature=0.7):
+         """
+         Streams a response exactly like the API-based chat_stream function.
+         Args:
+             messages (list): List of dicts [{'role': 'user', 'content': '...'}, ...]
+         """
+         if not self.model or not self.tokenizer:
+             yield " [Error: Model not loaded]"
+             return
+
+         try:
+             # 1. Apply the chat template (converts the list of messages to a prompt string).
+             # Use it only if the model ships one; otherwise fall back to simple concatenation.
+             if getattr(self.tokenizer, "chat_template", None):
+                 prompt = self.tokenizer.apply_chat_template(
+                     messages,
+                     tokenize=False,
+                     add_generation_prompt=True
+                 )
+             else:
+                 # Fallback for models without templates (basic formatting)
+                 prompt = ""
+                 for msg in messages:
+                     prompt += f"{msg['role'].capitalize()}: {msg['content']}\n"
+                 prompt += "Assistant:"
+
+             # 2. Tokenize
+             inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+
+             # 3. Set up the streamer
+             streamer = TextIteratorStreamer(
+                 self.tokenizer,
+                 skip_prompt=True,
+                 skip_special_tokens=True
+             )
+
+             # 4. Generation arguments
+             generation_kwargs = dict(
+                 inputs,
+                 streamer=streamer,
+                 max_new_tokens=max_new_tokens,
+                 temperature=temperature,
+                 do_sample=True if temperature > 0 else False,
+                 pad_token_id=self.tokenizer.eos_token_id
+             )
+
+             # 5. Run generation in a separate thread so tokens can be streamed
+             thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+             thread.start()
+
+             # 6. Yield tokens as they arrive
+             for new_text in streamer:
+                 yield new_text
+
+         except Exception as e:
+             yield f" [Error generating response: {str(e)}]"


  # Singleton for embedding model
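
For reference, a minimal usage sketch of the new LocalModelHandler added in this commit. It assumes ai_engine.py is importable from the calling code; the repo ID is only an illustrative placeholder, and any causal LM on the Hub with a chat template should behave similarly.

    from ai_engine import LocalModelHandler

    # Hypothetical example repo ID, not prescribed by this commit.
    handler = LocalModelHandler("google/gemma-2-2b-it", use_quantization=False)

    messages = [{"role": "user", "content": "Summarize what this repo does."}]

    # chat_stream is a generator, so tokens can be printed as they arrive.
    for token in handler.chat_stream(messages, max_new_tokens=128, temperature=0.7):
        print(token, end="", flush=True)

Because chat_stream yields plain text chunks, the same loop works whether the UI consumes them directly or accumulates them into a chat history.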