File size: 5,227 Bytes
d0addd7
 
 
 
dbabe41
 
 
 
 
 
 
 
 
3eb9ffa
24cdd61
3eb9ffa
d0addd7
 
 
24cdd61
 
dbabe41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eca7b5d
dbabe41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3eb9ffa
 
d0addd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import os
import json
import requests
import re
import torch
from threading import Thread
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    TextIteratorStreamer, 
    StoppingCriteria, 
    StoppingCriteriaList
)
from huggingface_hub import login, hf_hub_download
from sentence_transformers import SentenceTransformer


# OpenRouter credentials and model selection, read from the environment.
# API_KEY may be None if OPENROUTER_API_KEY is unset; chat_stream() will
# then fail with an auth error from the API.
API_KEY = os.getenv("OPENROUTER_API_KEY")
MODEL = os.getenv("OPENROUTER_MODEL", "google/gemma-2-9b-it:free")

# Shared sentence-transformers encoder used by get_embedding(); constructed
# once at import time (NOTE: downloads model weights on first run).
_embed_model = SentenceTransformer('all-MiniLM-L6-v2')

class LocalModelHandler:
    """Wraps a local HuggingFace causal LM for streamed chat generation.

    Mirrors the interface of the API-based ``chat_stream`` generator so the
    two backends are interchangeable from the caller's point of view.
    """

    def __init__(self, repo_id, device=None, use_quantization=False):
        """
        Load the tokenizer and model for ``repo_id``.

        Args:
            repo_id (str): HuggingFace Hub repo id of the causal LM.
            device (str | None): target device; defaults to ``"cuda"`` when
                available, otherwise ``"cpu"``.
            use_quantization (bool): if True, request 4-bit loading
                (requires ``bitsandbytes`` to be installed).

        On any load failure the error is printed and ``self.model`` /
        ``self.tokenizer`` are set to None; ``chat_stream`` then yields an
        error marker instead of raising.
        """
        self.repo_id = repo_id
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        print(f"Loading local model: {repo_id} on {self.device}...")

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(repo_id)

            # bfloat16 halves memory on GPU; CPU stays in full float32.
            load_kwargs = {
                "torch_dtype": torch.bfloat16 if self.device == "cuda" else torch.float32,
                "low_cpu_mem_usage": True,
                "trust_remote_code": True
            }

            # Optional 4-bit quantization if bitsandbytes is installed.
            # NOTE(review): `load_in_4bit` is deprecated in newer transformers
            # in favor of BitsAndBytesConfig — confirm installed version.
            if use_quantization:
                load_kwargs["load_in_4bit"] = True

            self.model = AutoModelForCausalLM.from_pretrained(
                repo_id,
                **load_kwargs
            )

            # Quantized loads place weights themselves; otherwise move
            # the model to the chosen device explicitly.
            if not use_quantization:
                self.model.to(self.device)

            print("✅ Model loaded successfully.")

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            self.model = None
            self.tokenizer = None

    def chat_stream(self, messages, max_new_tokens=512, temperature=0.5):
        """
        Yield generated text chunks, mirroring the API-based ``chat_stream``.

        Args:
            messages (list): list of dicts ``[{'role': ..., 'content': ...}]``.
            max_new_tokens (int): generation budget.
            temperature (float): sampling temperature; ``<= 0`` selects
                greedy decoding.

        Yields:
            str: text fragments as they are produced; a single
            ``" [Error: ...]"`` marker on failure instead of raising.
        """
        if not self.model or not self.tokenizer:
            yield " [Error: Model not loaded]"
            return

        try:
            # 1. Build the prompt: use the tokenizer's chat template when it
            # has one, else fall back to simple "Role: content" concatenation.
            if getattr(self.tokenizer, "chat_template", None):
                prompt = self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
            else:
                prompt = ""
                for msg in messages:
                    prompt += f"{msg['role'].capitalize()}: {msg['content']}\n"
                prompt += "Assistant:"

            # 2. Tokenize on the model's own device (may differ from
            # self.device when quantization auto-placed the weights).
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

            # 3. Streamer hands tokens to us as generate() produces them.
            streamer = TextIteratorStreamer(
                self.tokenizer,
                skip_prompt=True,
                skip_special_tokens=True
            )

            # 4. Generation arguments. Only pass `temperature` when actually
            # sampling — greedy decode with a temperature set triggers
            # transformers sampling-flag warnings.
            do_sample = temperature > 0
            generation_kwargs = dict(
                inputs,
                streamer=streamer,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.eos_token_id
            )
            if do_sample:
                generation_kwargs["temperature"] = temperature

            # 5. Run generation in a worker thread so we can stream here.
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()

            # 6. Yield tokens as they arrive; always join the worker so the
            # generation thread does not linger if the consumer stops early.
            try:
                for new_text in streamer:
                    yield new_text
            finally:
                thread.join()

        except Exception as e:
            yield f" [Error generating response: {str(e)}]"


def get_embedding(text):
    """Return the embedding of *text* as a plain Python list of floats."""
    vector = _embed_model.encode(text)
    return vector.tolist()

def chat_stream(messages):
    """
    Stream chat-completion text chunks from the OpenRouter API (SSE).

    Args:
        messages (list): chat history as ``[{'role': ..., 'content': ...}]``.

    Yields:
        str: content deltas as they arrive; on failure a single
        ``" [Error: ...]"`` marker is yielded instead of raising.
    """
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
        "HTTP-Referer": "http://localhost:5000",
        "X-Title": "VisuMem AI"
    }
    payload = {"model": MODEL, "messages": messages, "stream": True}

    try:
        # Timeout (connect, read) prevents the generator hanging forever on a
        # stalled connection; `with` ensures the streaming response is closed.
        with requests.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers=headers,
            json=payload,
            stream=True,
            timeout=(10, 300),
        ) as resp:
            resp.raise_for_status()
            for line in resp.iter_lines():
                if not line:
                    continue
                decoded = line.decode('utf-8')
                # SSE data frames look like "data: {...}"; the terminal
                # frame is "data: [DONE]".
                if decoded.startswith("data: ") and decoded != "data: [DONE]":
                    try:
                        data = json.loads(decoded[6:])
                        content = data["choices"][0].get("delta", {}).get("content", "")
                    except (json.JSONDecodeError, KeyError, IndexError):
                        # Skip malformed / keep-alive payloads only; never
                        # swallow unrelated errors (old code used bare except).
                        continue
                    if content:
                        yield content
    except Exception as e:
        yield f" [Error: {str(e)}]"