# NOTE(review): the lines previously here were Hugging Face Spaces page-scrape
# artifacts (status text, commit hashes, line-number runs) that made this file
# invalid Python; they carried no program content and were removed.
import os
import json
import requests
import re
import torch
from threading import Thread
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TextIteratorStreamer,
StoppingCriteria,
StoppingCriteriaList
)
from huggingface_hub import login, hf_hub_download
from sentence_transformers import SentenceTransformer
API_KEY = os.getenv("OPENROUTER_API_KEY")
MODEL = os.getenv("OPENROUTER_MODEL", "google/gemma-2-9b-it:free")
_embed_model = SentenceTransformer('all-MiniLM-L6-v2')
class LocalModelHandler:
    """Wraps a locally loaded Hugging Face causal LM for streaming chat."""

    def __init__(self, repo_id, device=None, use_quantization=False):
        """
        Initializes the model and tokenizer.

        Args:
            repo_id (str): Hugging Face model repository id.
            device (str | None): "cuda" or "cpu"; auto-detected when None.
            use_quantization (bool): load weights in 4-bit (requires the
                `bitsandbytes` package to be installed).

        On failure, ``self.model`` and ``self.tokenizer`` are set to None and
        ``chat_stream`` reports the error instead of raising.
        """
        self.repo_id = repo_id
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading local model: {repo_id} on {self.device}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(repo_id)
            # Model-loading arguments; bfloat16 on GPU, float32 on CPU.
            load_kwargs = {
                "torch_dtype": torch.bfloat16 if self.device == "cuda" else torch.float32,
                "low_cpu_mem_usage": True,
                "trust_remote_code": True,
            }
            if use_quantization:
                # 4-bit quantization requires an automatic device map so
                # bitsandbytes can place the quantized weights itself;
                # without it the model is not moved to the accelerator.
                load_kwargs["load_in_4bit"] = True
                load_kwargs["device_map"] = "auto"
            self.model = AutoModelForCausalLM.from_pretrained(
                repo_id,
                **load_kwargs
            )
            # Move to device only when not quantizing (device_map="auto"
            # already handled placement above).
            if not use_quantization:
                self.model.to(self.device)
            print("✅ Model loaded successfully.")
        except Exception as e:
            # Best-effort: keep the handler constructible so callers get a
            # readable error from chat_stream instead of a crash here.
            print(f"❌ Error loading model: {e}")
            self.model = None
            self.tokenizer = None

    def chat_stream(self, messages, max_new_tokens=512, temperature=0.5):
        """
        Streams response exactly like the API-based chat_stream function.

        Args:
            messages (list): List of dicts [{'role': 'user', 'content': '...'}, ...]
            max_new_tokens (int): cap on generated tokens.
            temperature (float): sampling temperature; <= 0 means greedy decoding.

        Yields:
            str: decoded text fragments as they are generated.
        """
        if not self.model or not self.tokenizer:
            yield " [Error: Model not loaded]"
            return
        try:
            # 1. Build the prompt: prefer the model's chat template,
            # fall back to simple "Role: content" concatenation.
            if getattr(self.tokenizer, "chat_template", None):
                prompt = self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
            else:
                prompt = "".join(
                    f"{msg['role'].capitalize()}: {msg['content']}\n"
                    for msg in messages
                )
                prompt += "Assistant:"
            # 2. Tokenize on the model's device (matters with device_map="auto").
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            # 3. Streamer yields decoded text pieces; skip_prompt hides the input.
            streamer = TextIteratorStreamer(
                self.tokenizer,
                skip_prompt=True,
                skip_special_tokens=True
            )
            do_sample = temperature > 0
            generation_kwargs = dict(
                inputs,
                streamer=streamer,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.eos_token_id
            )
            # Only pass temperature when actually sampling: temperature=0
            # alongside do_sample=False triggers a generation-config
            # validation warning/error in recent transformers releases.
            if do_sample:
                generation_kwargs["temperature"] = temperature
            # 4. Run generation in a daemon thread so we can consume tokens
            # as they arrive; daemon=True keeps a hung generate() from
            # blocking interpreter shutdown.
            thread = Thread(
                target=self.model.generate, kwargs=generation_kwargs, daemon=True
            )
            thread.start()
            try:
                # 5. Yield tokens as they arrive.
                for new_text in streamer:
                    yield new_text
            finally:
                # Don't leave the generation thread dangling once the
                # streamer is exhausted (or the consumer stops early).
                thread.join()
        except Exception as e:
            yield f" [Error generating response: {str(e)}]"
def get_embedding(text):
    """Return the sentence embedding of *text* as a plain list of floats."""
    embedding = _embed_model.encode(text)
    return embedding.tolist()
def chat_stream(messages):
    """
    Streams assistant text from the OpenRouter chat-completions API.

    Args:
        messages (list): List of dicts [{'role': ..., 'content': ...}, ...]

    Yields:
        str: content deltas as they arrive; a single " [Error: ...]" string
        if the request fails.
    """
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
        "HTTP-Referer": "http://localhost:5000",
        "X-Title": "VisuMem AI"
    }
    payload = {"model": MODEL, "messages": messages, "stream": True}
    try:
        # timeout=(connect, read): without it a dead endpoint hangs this
        # generator forever; the read timeout is generous because the
        # stream is long-lived. `with` guarantees the connection is closed.
        with requests.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers=headers,
            json=payload,
            stream=True,
            timeout=(10, 120),
        ) as resp:
            resp.raise_for_status()
            for line in resp.iter_lines():
                if not line:
                    continue  # keep-alive blank lines between SSE frames
                decoded = line.decode('utf-8')
                # SSE frames look like "data: {...}"; "data: [DONE]" ends the stream.
                if decoded.startswith("data: ") and decoded != "data: [DONE]":
                    try:
                        data = json.loads(decoded[6:])
                        if "choices" in data:
                            content = data["choices"][0].get("delta", {}).get("content", "")
                            if content:
                                yield content
                    except (json.JSONDecodeError, KeyError, IndexError):
                        # Skip a malformed frame instead of aborting the stream
                        # (narrowed from a bare `except` that also swallowed
                        # KeyboardInterrupt/SystemExit).
                        pass
    except Exception as e:
        yield f" [Error: {str(e)}]"