# Hugging Face Spaces page header captured by the scrape (not application code):
# "Spaces: Sleeping"
| import os | |
| import json | |
| import requests | |
| import re | |
| import torch | |
| from threading import Thread | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForCausalLM, | |
| TextIteratorStreamer, | |
| StoppingCriteria, | |
| StoppingCriteriaList | |
| ) | |
| from huggingface_hub import login, hf_hub_download | |
| from sentence_transformers import SentenceTransformer | |
# OpenRouter API configuration — the key is read from the environment; never hard-code it.
API_KEY = os.getenv("OPENROUTER_API_KEY")
# Model slug sent to the OpenRouter chat-completions endpoint; overridable via env var.
MODEL = os.getenv("OPENROUTER_MODEL", "google/gemma-2-9b-it:free")
# Shared sentence-embedding model, loaded once at import time and reused by get_embedding().
_embed_model = SentenceTransformer('all-MiniLM-L6-v2')
class LocalModelHandler:
    """Loads a local Hugging Face causal LM and streams chat completions.

    Mirrors the interface of the API-based ``chat_stream`` function so callers
    can swap between the local and remote backends.
    """

    def __init__(self, repo_id, device=None, use_quantization=False):
        """Load the tokenizer and model, leaving both ``None`` on failure.

        Args:
            repo_id (str): Hugging Face model repository id.
            device (str | None): "cuda" or "cpu"; auto-detected when ``None``.
            use_quantization (bool): Load weights in 4-bit (requires
                bitsandbytes + accelerate to be installed).
        """
        self.repo_id = repo_id
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading local model: {repo_id} on {self.device}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(repo_id)
            load_kwargs = {
                # bf16 on GPU; fp32 on CPU where bf16 support varies.
                "torch_dtype": torch.bfloat16 if self.device == "cuda" else torch.float32,
                "low_cpu_mem_usage": True,
                "trust_remote_code": True,
            }
            if use_quantization:
                load_kwargs["load_in_4bit"] = True
                # FIX: 4-bit loading requires an explicit device_map so
                # accelerate actually places the quantized weights; the old
                # code relied on this happening implicitly, but never set it.
                load_kwargs["device_map"] = "auto"
            self.model = AutoModelForCausalLM.from_pretrained(repo_id, **load_kwargs)
            # Quantized loads are already placed by device_map="auto".
            if not use_quantization:
                self.model.to(self.device)
            print("✅ Model loaded successfully.")
        except Exception as e:
            # Best-effort: record the failure and leave the handler in a
            # "not loaded" state rather than crashing the hosting app.
            print(f"❌ Error loading model: {e}")
            self.model = None
            self.tokenizer = None

    def chat_stream(self, messages, max_new_tokens=512, temperature=0.5):
        """Yield response text chunks, like the API-based ``chat_stream``.

        Args:
            messages (list[dict]): [{'role': 'user', 'content': '...'}, ...]
            max_new_tokens (int): Cap on generated tokens.
            temperature (float): Sampling temperature; 0 means greedy decode.

        Yields:
            str: Incremental pieces of the generated reply, or a single
            " [Error: ...]" string on failure.
        """
        if not self.model or not self.tokenizer:
            yield " [Error: Model not loaded]"
            return
        try:
            # 1. Build the prompt: prefer the tokenizer's chat template,
            #    otherwise fall back to simple "Role: content" formatting.
            if getattr(self.tokenizer, "chat_template", None):
                prompt = self.tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
            else:
                prompt = "".join(
                    f"{msg['role'].capitalize()}: {msg['content']}\n" for msg in messages
                ) + "Assistant:"
            # 2. Tokenize directly onto the model's device.
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            # 3. The streamer hands decoded text back to this thread.
            streamer = TextIteratorStreamer(
                self.tokenizer, skip_prompt=True, skip_special_tokens=True
            )
            do_sample = temperature > 0
            generation_kwargs = dict(
                inputs,
                streamer=streamer,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            # FIX: only pass temperature when sampling — temperature=0 with
            # greedy decoding triggers a transformers validation warning.
            if do_sample:
                generation_kwargs["temperature"] = temperature
            # 4. Generate on a worker thread so tokens can be yielded live.
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()
            try:
                # 5. Yield tokens as they arrive.
                for new_text in streamer:
                    yield new_text
            finally:
                # FIX: join the worker so the thread is not leaked when the
                # consumer stops iterating early.
                thread.join()
        except Exception as e:
            yield f" [Error generating response: {str(e)}]"
def get_embedding(text):
    """Encode *text* with the shared embedding model and return a plain list."""
    vector = _embed_model.encode(text)
    return vector.tolist()
def chat_stream(messages):
    """Stream assistant text chunks from the OpenRouter chat-completions API.

    Args:
        messages (list[dict]): OpenAI-style [{'role': ..., 'content': ...}].

    Yields:
        str: Content deltas as they arrive; on failure a single
        " [Error: ...]" string so the UI can surface the problem.
    """
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
        "HTTP-Referer": "http://localhost:5000",
        "X-Title": "VisuMem AI",
    }
    payload = {"model": MODEL, "messages": messages, "stream": True}
    try:
        # FIX: add a timeout (connect, per-chunk read) — the original call
        # could hang forever on a stalled connection — and close the
        # streamed response deterministically via the context manager.
        with requests.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers=headers,
            json=payload,
            stream=True,
            timeout=(10, 60),
        ) as resp:
            resp.raise_for_status()
            for line in resp.iter_lines():
                if not line:
                    continue
                decoded = line.decode('utf-8')
                # SSE frames look like "data: {...}"; "[DONE]" ends the stream.
                if not decoded.startswith("data: ") or decoded == "data: [DONE]":
                    continue
                try:
                    data = json.loads(decoded[6:])
                except json.JSONDecodeError:
                    # FIX: skip malformed/partial frames explicitly instead of
                    # a bare `except:` that also swallowed KeyboardInterrupt.
                    continue
                # Guard against an empty choices list before indexing.
                if data.get("choices"):
                    content = data["choices"][0].get("delta", {}).get("content", "")
                    if content:
                        yield content
    except Exception as e:
        yield f" [Error: {str(e)}]"