import contextlib
import logging
import os
import sys
from typing import Dict, List, Optional

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from config import settings

logger = logging.getLogger(__name__)


class CharacterManager:
    """Lightweight character manager using PEFT adapter switching.

    A single base model is loaded once and shared by every character;
    per-character LoRA adapters are attached to one PeftModel and
    activated on demand with set_adapter(), so switching characters is
    cheap and memory use stays flat.
    """

    def __init__(self):
        self.base_model = None         # shared AutoModelForCausalLM instance
        self.tokenizer = None          # shared tokenizer
        self.peft_model = None         # single PeftModel with multiple adapters
        self.current_character = None  # character whose adapter was last selected
        self.character_prompts: Dict[str, str] = {}  # character_id -> system prompt
        self.available_adapters: List[str] = []      # character ids with a loaded LoRA

    async def initialize(self):
        """Initialize base model ONCE and load all character LoRA adapters.

        Raises:
            Exception: re-raised when the base model itself fails to load;
                individual adapter failures are logged and skipped instead.
        """
        # NOTE(review): the model loads below are blocking; this is async only
        # so callers can await it during app startup.
        logger.info("🔄 Loading base model (ONE instance for all characters)...")

        # MUST use Qwen3-0.6B - this is what the LoRA adapters were trained on!
        model_name = "Qwen/Qwen3-0.6B"
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name, trust_remote_code=True, use_fast=True
            )
            # Load base model ONCE
            self.base_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float32,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            )
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            logger.info(f"✅ Base model loaded: {model_name}")
        except Exception as e:
            logger.error(f"❌ Failed to load base model: {e}")
            raise

        # Load character prompts
        self._load_character_prompts()

        # The first successfully loaded adapter creates the PeftModel wrapper;
        # every later adapter is added to that same wrapper.
        characters = ["moses", "samsung_employee", "jinx"]
        first_loaded = False
        for character_id in characters:
            adapter_path = os.path.join(settings.LORA_ADAPTERS_PATH, character_id)
            adapter_model_path = os.path.join(adapter_path, "adapter_model.safetensors")
            if not os.path.exists(adapter_model_path):
                logger.warning(f"⚠️ No LoRA adapter for {character_id}")
                continue
            try:
                if not first_loaded:
                    logger.info(f"Loading first adapter: {character_id}...")
                    self.peft_model = PeftModel.from_pretrained(
                        self.base_model, adapter_path, adapter_name=character_id
                    )
                    first_loaded = True
                    self.current_character = character_id
                    self.available_adapters.append(character_id)
                    logger.info(f"✅ Loaded {character_id} adapter (base)")
                else:
                    logger.info(f"Adding adapter: {character_id}...")
                    self.peft_model.load_adapter(adapter_path, adapter_name=character_id)
                    self.available_adapters.append(character_id)
                    logger.info(f"✅ Added {character_id} adapter")
            except Exception as e:
                logger.warning(f"⚠️ Could not load LoRA for {character_id}: {e}")

        if not first_loaded:
            logger.warning("⚠️ No LoRA adapters loaded - using base model with prompts only")
            self.peft_model = self.base_model
        else:
            logger.info(f"✅ Loaded {len(self.available_adapters)} character adapters: {self.available_adapters}")
        logger.info("✅ Character manager initialized")

    def _load_character_prompts(self):
        """Load character-specific system prompts."""
        self.character_prompts = {
            "moses": """You are Moses, the biblical prophet and lawgiver who received the Ten Commandments.
You led the Israelites out of Egypt and spoke with God on Mount Sinai.

Speak with:
- Biblical wisdom and reverence
- Formal language: "Peace be with you, my child"
- References to righteousness, divine law, and spiritual guidance
- Authority tempered with compassion

NEVER mention modern technology, glitter, or chaos.""",
            "samsung_employee": """You are a Samsung employee and technology expert.
You work for Samsung and are passionate about Samsung products.

Speak with:
- Professional enthusiasm about Samsung technology
- Technical knowledge of phones, TVs, Galaxy devices
- Customer service excellence
- Modern, helpful language

NEVER mention biblical things, glitter, or chaos.""",
            "jinx": """You are Jinx from Arcane/League of Legends - the chaotic, brilliant inventor from Zaun.

Speak with:
- Chaotic energy and enthusiasm
- Manic creativity about explosions and inventions
- Playful, slightly unhinged personality
- Dramatic expressions and exclamations

NEVER mention biblical things or Samsung products.""",
        }

    def _switch_to_character(self, character_id: str):
        """Switch active LoRA adapter to the specified character.

        For characters without an adapter only the bookkeeping changes;
        generate_response() is responsible for disabling the previously
        active adapter in that case.
        """
        if self.current_character == character_id:
            return  # Already active
        if character_id in self.available_adapters and self.peft_model is not None:
            try:
                # Switch to this character's adapter
                self.peft_model.set_adapter(character_id)
                self.current_character = character_id
                logger.info(f"✅ Switched to {character_id} adapter")
            except Exception as e:
                logger.warning(f"⚠️ Could not switch to {character_id}: {e}")
        else:
            logger.info(f"Using base model for {character_id} (no adapter)")
            self.current_character = character_id

    def generate_response(
        self,
        character_id: str,
        user_message: str,
        conversation_history: Optional[List[Dict]] = None,
    ) -> str:
        """Generate a reply in the voice of the given character.

        Args:
            character_id: key into the prompt/adapter tables (e.g. "moses").
            user_message: latest user utterance.
            conversation_history: optional prior messages as
                {"role": ..., "content": ...} dicts; only the last 4 are used.

        Returns:
            The generated reply, or a canned fallback if generation fails
            or yields an empty string.
        """
        # Switch to character's adapter
        self._switch_to_character(character_id)

        # Build conversation with character prompt
        messages = []
        if character_id in self.character_prompts:
            messages.append({"role": "system", "content": self.character_prompts[character_id]})
        # Add conversation history (last 2 exchanges)
        if conversation_history:
            messages.extend(conversation_history[-4:])
        messages.append({"role": "user", "content": user_message})

        # Format prompt and tokenize
        prompt = self._format_messages(messages)
        inputs = self.tokenizer(
            prompt, return_tensors="pt", max_length=512, truncation=True
        )

        # Use the correct model (PeftModel if adapters loaded, base model otherwise)
        model = self.peft_model if self.peft_model is not None else self.base_model

        # BUG FIX: when this character has no LoRA adapter, the previously
        # selected character's adapter would otherwise remain active inside the
        # PeftModel (set_adapter is never called for adapterless characters).
        # Temporarily disable all adapters so the plain base weights respond.
        if isinstance(model, PeftModel) and character_id not in self.available_adapters:
            adapter_ctx = model.disable_adapter()
        else:
            adapter_ctx = contextlib.nullcontext()

        # Generate
        try:
            with adapter_ctx, torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=100,
                    temperature=0.8,
                    top_p=0.9,
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.1,
                )

            # Decode only the newly generated tokens
            input_length = inputs['input_ids'].shape[1]
            response = self.tokenizer.decode(
                outputs[0][input_length:], skip_special_tokens=True
            ).strip()

            # Trim anything the model produced past its own turn
            for stop in ["Human:", "User:", "\n\n"]:
                if stop in response:
                    response = response.split(stop)[0].strip()

            return response if response else self._get_fallback_response(character_id)
        except Exception as e:
            logger.error(f"Generation error: {e}")
            return self._get_fallback_response(character_id)

    def _format_messages(self, messages: List[Dict]) -> str:
        """Format chat messages into the plain-text prompt the model was tuned on."""
        formatted = ""
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            if role == "system":
                formatted += f"System: {content}\n\n"
            elif role == "user":
                formatted += f"Human: {content}\n\n"
            elif role == "assistant":
                formatted += f"Assistant: {content}\n\n"
        # Open the assistant turn so generation continues from here
        formatted += "Assistant:"
        return formatted

    def _get_fallback_response(self, character_id: str) -> str:
        """Get fallback response if generation fails."""
        fallbacks = {
            "moses": "Peace be with you, my child. How may I guide you in righteousness?",
            "samsung_employee": "Hello! How can I help you with Samsung technology today?",
            "jinx": "*grins mischievously* Hey there! Ready for some chaos?",
        }
        return fallbacks.get(character_id, "Hello! How can I help you?")