import os
import json
import re
from threading import Thread

import requests
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
    StoppingCriteria,
    StoppingCriteriaList,
)
from huggingface_hub import login, hf_hub_download
from sentence_transformers import SentenceTransformer

# OpenRouter credentials / model selection, read once at import time.
API_KEY = os.getenv("OPENROUTER_API_KEY")
MODEL = os.getenv("OPENROUTER_MODEL", "google/gemma-2-9b-it:free")

# Shared sentence-embedding model, loaded once as a module-level singleton.
_embed_model = SentenceTransformer('all-MiniLM-L6-v2')


class LocalModelHandler:
    """Wraps a local Hugging Face causal LM for streamed chat generation.

    On load failure, ``self.model`` / ``self.tokenizer`` are set to ``None``
    and ``chat_stream`` yields an error marker instead of raising.
    """

    def __init__(self, repo_id, device=None, use_quantization=False):
        """Load the tokenizer and model.

        Args:
            repo_id: Hugging Face repo id (e.g. ``"google/gemma-2-9b-it"``).
            device: Target device; defaults to CUDA when available, else CPU.
            use_quantization: If True, load in 4-bit (requires bitsandbytes;
                device placement is then handled by the loader, not ``.to()``).
        """
        self.repo_id = repo_id
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading local model: {repo_id} on {self.device}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(repo_id)

            # bfloat16 on GPU saves memory; CPUs generally lack bf16 support.
            load_kwargs = {
                "torch_dtype": torch.bfloat16 if self.device == "cuda" else torch.float32,
                "low_cpu_mem_usage": True,
                "trust_remote_code": True,
            }

            # Optional: 4-bit quantization if bitsandbytes is installed.
            if use_quantization:
                load_kwargs["load_in_4bit"] = True

            self.model = AutoModelForCausalLM.from_pretrained(repo_id, **load_kwargs)

            # Quantized loading places weights itself; otherwise move explicitly.
            if not use_quantization:
                self.model.to(self.device)

            print("✅ Model loaded successfully.")
        except Exception as e:
            # Degrade gracefully: callers check self.model before generating.
            print(f"❌ Error loading model: {e}")
            self.model = None
            self.tokenizer = None

    def chat_stream(self, messages, max_new_tokens=512, temperature=0.5):
        """Yield generated text fragments for a chat conversation.

        Mirrors the streaming contract of the module-level API-based
        ``chat_stream``: yields plain-text deltas, or a single
        ``" [Error: ...]"`` string on failure.

        Args:
            messages: List of dicts ``[{'role': 'user', 'content': '...'}, ...]``.
            max_new_tokens: Generation cap.
            temperature: Sampling temperature; 0 disables sampling (greedy).
        """
        if not self.model or not self.tokenizer:
            yield " [Error: Model not loaded]"
            return

        try:
            # 1. Build the prompt. Prefer the tokenizer's chat template;
            # fall back to simple "Role: content" concatenation otherwise.
            if getattr(self.tokenizer, "chat_template", None):
                prompt = self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
            else:
                prompt = ""
                for msg in messages:
                    prompt += f"{msg['role'].capitalize()}: {msg['content']}\n"
                prompt += "Assistant:"

            # 2. Tokenize onto the model's device.
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

            # 3. Streamer: skip the echoed prompt and special tokens.
            streamer = TextIteratorStreamer(
                self.tokenizer,
                skip_prompt=True,
                skip_special_tokens=True,
            )

            # 4. Generation arguments. Only pass `temperature` when sampling —
            # transformers warns if it is set alongside do_sample=False.
            do_sample = temperature > 0
            generation_kwargs = dict(
                inputs,
                streamer=streamer,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            if do_sample:
                generation_kwargs["temperature"] = temperature

            # 5. Run generation in a background thread so we can consume
            # the streamer concurrently.
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()

            # 6. Yield tokens as they arrive.
            for new_text in streamer:
                yield new_text

        except Exception as e:
            yield f" [Error generating response: {str(e)}]"


def get_embedding(text):
    """Return the sentence embedding of ``text`` as a plain list of floats."""
    return _embed_model.encode(text).tolist()


def chat_stream(messages):
    """Stream chat completions from the OpenRouter API as text deltas.

    Yields content fragments parsed from the SSE stream, or a single
    ``" [Error: ...]"`` string if the request fails.

    Args:
        messages: OpenAI-style message list for the chat completion.
    """
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
        "HTTP-Referer": "http://localhost:5000",
        "X-Title": "VisuMem AI",
    }
    payload = {"model": MODEL, "messages": messages, "stream": True}
    try:
        # (connect, read) timeout: read timeout applies between chunks of the
        # stream, so long generations are unaffected while dead connections
        # no longer hang forever. `with` ensures the connection is released.
        with requests.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers=headers,
            json=payload,
            stream=True,
            timeout=(10, 300),
        ) as resp:
            resp.raise_for_status()
            for line in resp.iter_lines():
                if not line:
                    continue
                decoded = line.decode('utf-8')
                if decoded.startswith("data: ") and decoded != "data: [DONE]":
                    # Best-effort: skip malformed/partial SSE payloads only,
                    # rather than swallowing every exception via bare except.
                    try:
                        data = json.loads(decoded[6:])
                    except json.JSONDecodeError:
                        continue
                    if "choices" in data:
                        content = data["choices"][0].get("delta", {}).get("content", "")
                        if content:
                            yield content
    except Exception as e:
        yield f" [Error: {str(e)}]"