# vismem/ai_engine.py
# (source: broadfield-dev — "Update ai_engine.py", commit eca7b5d, verified)
import os
import json
import requests
import re
import torch
from threading import Thread
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TextIteratorStreamer,
StoppingCriteria,
StoppingCriteriaList
)
from huggingface_hub import login, hf_hub_download
from sentence_transformers import SentenceTransformer
# OpenRouter credentials and model id, configurable via environment variables.
API_KEY = os.getenv("OPENROUTER_API_KEY")
MODEL = os.getenv("OPENROUTER_MODEL", "google/gemma-2-9b-it:free")
# Shared sentence-transformer used by get_embedding(); loaded once at import time.
_embed_model = SentenceTransformer('all-MiniLM-L6-v2')
class LocalModelHandler:
    """Wraps a local Hugging Face causal LM for streaming chat generation.

    Mirrors the interface of the API-based ``chat_stream`` function: a
    generator of text chunks, yielding a single " [Error: ...]" chunk on
    failure instead of raising.
    """

    def __init__(self, repo_id, device=None, use_quantization=False):
        """
        Initializes the model and tokenizer.

        Args:
            repo_id (str): Hugging Face model repository id.
            device (str | None): "cuda"/"cpu"; auto-detected when None.
            use_quantization (bool): load the model in 4-bit (requires
                bitsandbytes; device placement is then delegated to accelerate).

        On failure self.model/self.tokenizer are set to None rather than
        raising, so chat_stream can report the error as a yielded chunk.
        """
        self.repo_id = repo_id
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading local model: {repo_id} on {self.device}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(repo_id)
            # Model loading arguments: bf16 on GPU for memory savings,
            # fp32 on CPU for compatibility.
            load_kwargs = {
                "torch_dtype": torch.bfloat16 if self.device == "cuda" else torch.float32,
                "low_cpu_mem_usage": True,
                "trust_remote_code": True,
            }
            if use_quantization:
                # Optional 4-bit quantization (requires bitsandbytes).
                # Bug fix: also set device_map so accelerate places the
                # quantized shards — the original never moved the quantized
                # model off the CPU (only the non-quantized branch called .to()).
                load_kwargs["load_in_4bit"] = True
                load_kwargs["device_map"] = "auto"
            self.model = AutoModelForCausalLM.from_pretrained(
                repo_id,
                **load_kwargs
            )
            # Move to device only when not quantized (quantization uses device_map).
            if not use_quantization:
                self.model.to(self.device)
            print("✅ Model loaded successfully.")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            self.model = None
            self.tokenizer = None

    def chat_stream(self, messages, max_new_tokens=512, temperature=0.5):
        """
        Streams response exactly like the API-based chat_stream function.

        Args:
            messages (list): List of dicts [{'role': 'user', 'content': '...'}, ...]
            max_new_tokens (int): generation length cap.
            temperature (float): sampling temperature; 0 selects greedy decoding.

        Yields:
            str: incremental text chunks; a single error chunk on failure.
        """
        if not self.model or not self.tokenizer:
            yield " [Error: Model not loaded]"
            return
        try:
            # 1. Build the prompt: use the tokenizer's chat template when
            # available, otherwise fall back to simple role-prefixed lines.
            if getattr(self.tokenizer, "chat_template", None):
                prompt = self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
            else:
                prompt = "".join(
                    f"{msg['role'].capitalize()}: {msg['content']}\n" for msg in messages
                ) + "Assistant:"
            # 2. Tokenize on the model's device (supports device_map placement).
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            # 3. Streamer yields decoded text pieces as generate() produces tokens.
            streamer = TextIteratorStreamer(
                self.tokenizer,
                skip_prompt=True,
                skip_special_tokens=True
            )
            # 4. Generation arguments. Bug fix: only pass temperature when
            # actually sampling — passing it with do_sample=False triggers a
            # transformers warning and is ignored.
            generation_kwargs = dict(
                inputs,
                streamer=streamer,
                max_new_tokens=max_new_tokens,
                do_sample=temperature > 0,
                pad_token_id=self.tokenizer.eos_token_id
            )
            if temperature > 0:
                generation_kwargs["temperature"] = temperature
            # 5. Run generation in a worker thread so we can consume the streamer.
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()
            # 6. Yield tokens as they arrive. Bug fix: join the worker thread
            # even if the consumer abandons the generator early, so generate()
            # is not left running unobserved.
            try:
                for new_text in streamer:
                    yield new_text
            finally:
                thread.join()
        except Exception as e:
            yield f" [Error generating response: {str(e)}]"
def get_embedding(text):
    """Encode *text* with the module-level sentence-transformer model and
    return the embedding vector as a plain Python list of floats."""
    vector = _embed_model.encode(text)
    return vector.tolist()
def chat_stream(messages):
    """
    Stream assistant text chunks from the OpenRouter chat-completions API.

    Args:
        messages (list): List of dicts [{'role': ..., 'content': ...}, ...]

    Yields:
        str: incremental content deltas parsed from the SSE stream; a single
        " [Error: ...]" chunk if the request or transport fails.
    """
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
        "HTTP-Referer": "http://localhost:5000",
        "X-Title": "VisuMem AI"
    }
    payload = {"model": MODEL, "messages": messages, "stream": True}
    try:
        # Bug fix: add a timeout (the original request could hang forever)
        # and close the streamed response deterministically via `with`.
        with requests.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers=headers,
            json=payload,
            stream=True,
            timeout=(10, 120),  # (connect, read) seconds
        ) as resp:
            resp.raise_for_status()
            for line in resp.iter_lines():
                if not line:
                    continue
                decoded = line.decode('utf-8')
                # SSE frames look like "data: {...}"; "data: [DONE]" ends the stream.
                if decoded.startswith("data: ") and decoded != "data: [DONE]":
                    try:
                        data = json.loads(decoded[6:])
                        if "choices" in data:
                            content = data["choices"][0].get("delta", {}).get("content", "")
                            if content:
                                yield content
                    # Narrowed from a bare `except:` which also swallowed
                    # KeyboardInterrupt/SystemExit; malformed frames are skipped.
                    except (json.JSONDecodeError, KeyError, IndexError):
                        pass
    except Exception as e:
        yield f" [Error: {str(e)}]"