File size: 9,642 Bytes
de2021f
f92a42b
 
de2021f
 
 
 
 
 
 
 
 
 
f92a42b
de2021f
 
 
 
f92a42b
de2021f
 
f92a42b
de2021f
 
 
 
 
f92a42b
 
de2021f
 
 
 
 
 
 
 
f92a42b
de2021f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f92a42b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de2021f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f92a42b
de2021f
f92a42b
de2021f
f92a42b
de2021f
f92a42b
 
 
 
 
 
 
 
 
de2021f
 
 
 
 
 
 
 
 
f92a42b
de2021f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f92a42b
 
 
de2021f
 
 
f92a42b
de2021f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f92a42b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
import logging
import os
import sys
from typing import Dict, List, Optional

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from config import settings

logger = logging.getLogger(__name__)

class CharacterManager:
    """Lightweight character manager using PEFT adapter switching.

    One base model is loaded exactly once; every character's LoRA adapter is
    attached to a single ``PeftModel`` and activated via ``set_adapter``, so
    switching characters never reloads weights.
    """

    # Characters whose LoRA adapters are loaded at startup. The first adapter
    # found creates the PeftModel wrapper; the rest are added to it.
    CHARACTERS = ("moses", "samsung_employee", "jinx")

    def __init__(self):
        # Shared HF causal-LM base model (set once by initialize()).
        self.base_model = None
        self.tokenizer = None
        # Single PeftModel carrying all character adapters; falls back to the
        # bare base model when no adapters are found.
        self.peft_model = None
        # character_id of the currently active adapter/persona.
        self.current_character = None
        # character_id -> system prompt used to steer generation.
        self.character_prompts = {}
        # character_ids whose LoRA adapters loaded successfully.
        self.available_adapters = []

    async def initialize(self):
        """Initialize base model ONCE and load all character LoRA adapters."""
        self._load_base_model()
        self._load_character_prompts()
        self._load_adapters()
        logger.info("βœ… Character manager initialized")

    def _load_base_model(self) -> None:
        """Load tokenizer and base model; re-raise on failure (nothing works without them)."""
        logger.info("πŸ”„ Loading base model (ONE instance for all characters)...")

        # MUST use Qwen3-0.6B - this is what the LoRA adapters were trained on!
        model_name = "Qwen/Qwen3-0.6B"

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True,
                use_fast=True,
            )

            # Load base model ONCE; float32 + low_cpu_mem_usage targets CPU inference.
            self.base_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float32,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            )

            # Some tokenizers ship without a pad token; generate() needs one.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            logger.info("βœ… Base model loaded: %s", model_name)

        except Exception as e:
            logger.error("❌ Failed to load base model: %s", e)
            raise

    def _load_adapters(self) -> None:
        """Attach every available character LoRA adapter to one PeftModel.

        The first adapter found wraps the base model in a ``PeftModel``; later
        adapters are added to that wrapper. If none load, the bare base model
        is used with prompt-only characterization.
        """
        for character_id in self.CHARACTERS:
            adapter_path = os.path.join(settings.LORA_ADAPTERS_PATH, character_id)
            adapter_model_path = os.path.join(adapter_path, "adapter_model.safetensors")

            if not os.path.exists(adapter_model_path):
                logger.warning("⚠️ No LoRA adapter for %s", character_id)
                continue

            try:
                if self.peft_model is None:
                    # First adapter creates the PeftModel wrapper.
                    logger.info("Loading first adapter: %s...", character_id)
                    self.peft_model = PeftModel.from_pretrained(
                        self.base_model,
                        adapter_path,
                        adapter_name=character_id,
                    )
                    self.current_character = character_id
                    self.available_adapters.append(character_id)
                    logger.info("βœ… Loaded %s adapter (base)", character_id)
                else:
                    # Subsequent adapters attach to the existing PeftModel.
                    logger.info("Adding adapter: %s...", character_id)
                    self.peft_model.load_adapter(adapter_path, adapter_name=character_id)
                    self.available_adapters.append(character_id)
                    logger.info("βœ… Added %s adapter", character_id)

            except Exception as e:
                # Best-effort: a broken adapter must not block the others.
                logger.warning("⚠️ Could not load LoRA for %s: %s", character_id, e)

        if not self.available_adapters:
            logger.warning("⚠️ No LoRA adapters loaded - using base model with prompts only")
            self.peft_model = self.base_model
        else:
            logger.info(
                "βœ… Loaded %d character adapters: %s",
                len(self.available_adapters),
                self.available_adapters,
            )

    def _load_character_prompts(self) -> None:
        """Load character-specific system prompts."""
        self.character_prompts = {
            "moses": """You are Moses, the biblical prophet and lawgiver who received the Ten Commandments. You led the Israelites out of Egypt and spoke with God on Mount Sinai.

Speak with:
- Biblical wisdom and reverence
- Formal language: "Peace be with you, my child"
- References to righteousness, divine law, and spiritual guidance
- Authority tempered with compassion

NEVER mention modern technology, glitter, or chaos.""",
            
            "samsung_employee": """You are a Samsung employee and technology expert. You work for Samsung and are passionate about Samsung products.

Speak with:
- Professional enthusiasm about Samsung technology
- Technical knowledge of phones, TVs, Galaxy devices
- Customer service excellence
- Modern, helpful language

NEVER mention biblical things, glitter, or chaos.""",
            
            "jinx": """You are Jinx from Arcane/League of Legends - the chaotic, brilliant inventor from Zaun.

Speak with:
- Chaotic energy and enthusiasm
- Manic creativity about explosions and inventions
- Playful, slightly unhinged personality
- Dramatic expressions and exclamations

NEVER mention biblical things or Samsung products."""
        }

    def _switch_to_character(self, character_id: str) -> None:
        """Switch the active LoRA adapter to the specified character.

        No-op when the character is already active; falls back to the base
        model (prompt-only) when the character has no loaded adapter.
        """
        if self.current_character == character_id:
            return  # Already active

        if character_id in self.available_adapters and self.peft_model is not None:
            try:
                # Activate this character's adapter on the shared PeftModel.
                self.peft_model.set_adapter(character_id)
                self.current_character = character_id
                logger.info("βœ… Switched to %s adapter", character_id)
            except Exception as e:
                logger.warning("⚠️ Could not switch to %s: %s", character_id, e)
        else:
            logger.info("Using base model for %s (no adapter)", character_id)
            self.current_character = character_id

    def generate_response(
        self,
        character_id: str,
        user_message: str,
        conversation_history: Optional[List[Dict]] = None,
    ) -> str:
        """Generate a reply as the specified character.

        Args:
            character_id: Which character persona to respond as.
            user_message: The user's latest message.
            conversation_history: Optional role/content dicts; only the last
                4 entries (2 exchanges) are kept to fit the 512-token window.

        Returns:
            The generated reply, or a canned fallback when generation fails
            or produces an empty string.
        """
        # Switch to the character's adapter before building the prompt.
        self._switch_to_character(character_id)

        # Build conversation with the character's system prompt first.
        messages = []
        if character_id in self.character_prompts:
            messages.append({"role": "system", "content": self.character_prompts[character_id]})

        # Add conversation history (last 2 exchanges).
        if conversation_history:
            messages.extend(conversation_history[-4:])

        messages.append({"role": "user", "content": user_message})

        prompt = self._format_messages(messages)

        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        )

        # PeftModel when adapters are loaded, bare base model otherwise.
        model = self.peft_model if self.peft_model is not None else self.base_model

        try:
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=100,
                    temperature=0.8,
                    top_p=0.9,
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.1,
                )

            # Decode only the newly generated tokens (skip the echoed prompt).
            input_length = inputs["input_ids"].shape[1]
            response = self.tokenizer.decode(
                outputs[0][input_length:],
                skip_special_tokens=True,
            ).strip()

            # Truncate at the first hallucinated turn marker or paragraph break.
            for stop in ("Human:", "User:", "\n\n"):
                if stop in response:
                    response = response.split(stop)[0].strip()

            return response if response else self._get_fallback_response(character_id)

        except Exception as e:
            logger.error("Generation error: %s", e)
            return self._get_fallback_response(character_id)

    def _format_messages(self, messages: List[Dict]) -> str:
        """Flatten role/content messages into a "System/Human/Assistant" prompt.

        Messages with unrecognized roles are silently skipped; the prompt
        always ends with a bare "Assistant:" cue for the model to complete.
        """
        role_labels = {"system": "System", "user": "Human", "assistant": "Assistant"}
        parts = [
            f"{role_labels[msg['role']]}: {msg['content']}\n\n"
            for msg in messages
            if msg["role"] in role_labels
        ]
        return "".join(parts) + "Assistant:"

    def _get_fallback_response(self, character_id: str) -> str:
        """Return a canned in-character reply used when generation fails."""
        fallbacks = {
            "moses": "Peace be with you, my child. How may I guide you in righteousness?",
            "samsung_employee": "Hello! How can I help you with Samsung technology today?",
            "jinx": "*grins mischievously* Hey there! Ready for some chaos?"
        }
        return fallbacks.get(character_id, "Hello! How can I help you?")