import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import logging
from typing import Dict, List, Optional
import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from config import settings

logger = logging.getLogger(__name__)


class CharacterManager:
    """Lightweight character manager using PEFT adapter switching"""

    def __init__(self):
        self.base_model = None
        self.tokenizer = None
        self.peft_model = None  # Single PeftModel holding multiple adapters
        self.current_character = None
        self.character_prompts = {}
        self.available_adapters = []

    async def initialize(self):
        """Initialize the base model ONCE and load all character LoRA adapters"""
        logger.info("Loading base model (ONE instance for all characters)...")

        # MUST use Qwen3-0.6B - this is the base the LoRA adapters were trained on
        model_name = "Qwen/Qwen3-0.6B"

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True,
                use_fast=True
            )

            # Load the base model once
            self.base_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float32,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )

            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            logger.info(f"Base model loaded: {model_name}")
        except Exception as e:
            logger.error(f"Failed to load base model: {e}")
            raise

        # Load character prompts
        self._load_character_prompts()

        # Load the first character's adapter to create the PeftModel, then add the others
        characters = ["moses", "samsung_employee", "jinx"]
        first_loaded = False

        for character_id in characters:
            adapter_path = os.path.join(settings.LORA_ADAPTERS_PATH, character_id)
            adapter_model_path = os.path.join(adapter_path, "adapter_model.safetensors")

            if not os.path.exists(adapter_model_path):
                logger.warning(f"No LoRA adapter for {character_id}")
                continue

            try:
                if not first_loaded:
                    # Load the first adapter to create the PeftModel
                    logger.info(f"Loading first adapter: {character_id}...")
                    self.peft_model = PeftModel.from_pretrained(
                        self.base_model,
                        adapter_path,
                        adapter_name=character_id
                    )
                    first_loaded = True
                    self.current_character = character_id
                    self.available_adapters.append(character_id)
                    logger.info(f"Loaded {character_id} adapter (base)")
                else:
                    # Add additional adapters to the existing PeftModel
                    logger.info(f"Adding adapter: {character_id}...")
                    self.peft_model.load_adapter(adapter_path, adapter_name=character_id)
                    self.available_adapters.append(character_id)
                    logger.info(f"Added {character_id} adapter")
            except Exception as e:
                logger.warning(f"Could not load LoRA for {character_id}: {e}")

        if not first_loaded:
            logger.warning("No LoRA adapters loaded - using base model with prompts only")
            self.peft_model = self.base_model
        else:
            logger.info(f"Loaded {len(self.available_adapters)} character adapters: {self.available_adapters}")

        logger.info("Character manager initialized")

    def _load_character_prompts(self):
        """Load character-specific system prompts"""
        self.character_prompts = {
            "moses": """You are Moses, the biblical prophet and lawgiver who received the Ten Commandments. You led the Israelites out of Egypt and spoke with God on Mount Sinai.

Speak with:
- Biblical wisdom and reverence
- Formal language: "Peace be with you, my child"
- References to righteousness, divine law, and spiritual guidance
- Authority tempered with compassion

NEVER mention modern technology, glitter, or chaos.""",

            "samsung_employee": """You are a Samsung employee and technology expert. You work for Samsung and are passionate about Samsung products.

Speak with:
- Professional enthusiasm about Samsung technology
- Technical knowledge of phones, TVs, Galaxy devices
- Customer service excellence
- Modern, helpful language

NEVER mention biblical things, glitter, or chaos.""",

            "jinx": """You are Jinx from Arcane/League of Legends - the chaotic, brilliant inventor from Zaun.

Speak with:
- Chaotic energy and enthusiasm
- Manic creativity about explosions and inventions
- Playful, slightly unhinged personality
- Dramatic expressions and exclamations

NEVER mention biblical things or Samsung products."""
        }

    def _switch_to_character(self, character_id: str):
        """Switch the active LoRA adapter to the specified character"""
        if self.current_character == character_id:
            return  # Already active

        if character_id in self.available_adapters and self.peft_model is not None:
            try:
                # Activate this character's adapter
                self.peft_model.set_adapter(character_id)
                self.current_character = character_id
                logger.info(f"Switched to {character_id} adapter")
            except Exception as e:
                logger.warning(f"Could not switch to {character_id}: {e}")
        else:
            logger.info(f"Using base model for {character_id} (no adapter)")
            self.current_character = character_id

    def generate_response(
        self,
        character_id: str,
        user_message: str,
        conversation_history: Optional[List[Dict]] = None
    ) -> str:
| """Generate response as specific character""" | |
| # Switch to character's adapter | |
| self._switch_to_character(character_id) | |
| # Build conversation with character prompt | |
| messages = [] | |
| if character_id in self.character_prompts: | |
| messages.append({"role": "system", "content": self.character_prompts[character_id]}) | |
| # Add conversation history (last 2 exchanges) | |
| if conversation_history: | |
| messages.extend(conversation_history[-4:]) | |
| messages.append({"role": "user", "content": user_message}) | |
| # Format prompt | |
| prompt = self._format_messages(messages) | |
| # Tokenize | |
| inputs = self.tokenizer( | |
| prompt, | |
| return_tensors="pt", | |
| max_length=512, | |
| truncation=True | |
| ) | |
| # Use the correct model (PeftModel if adapters loaded, base model otherwise) | |
| model = self.peft_model if self.peft_model is not None else self.base_model | |
| # Generate | |
| try: | |
| with torch.no_grad(): | |
| outputs = model.generate( | |
| **inputs, | |
| max_new_tokens=100, | |
| temperature=0.8, | |
| top_p=0.9, | |
| do_sample=True, | |
| pad_token_id=self.tokenizer.pad_token_id, | |
| eos_token_id=self.tokenizer.eos_token_id, | |
| repetition_penalty=1.1 | |
| ) | |
| # Decode | |
| input_length = inputs['input_ids'].shape[1] | |
| response = self.tokenizer.decode( | |
| outputs[0][input_length:], | |
| skip_special_tokens=True | |
| ).strip() | |
| # Clean up | |
| for stop in ["Human:", "User:", "\n\n"]: | |
| if stop in response: | |
| response = response.split(stop)[0].strip() | |
| return response if response else self._get_fallback_response(character_id) | |
| except Exception as e: | |
| logger.error(f"Generation error: {e}") | |
| return self._get_fallback_response(character_id) | |

    def _format_messages(self, messages: List[Dict]) -> str:
        """Format messages into a plain-text prompt for the model"""
        formatted = ""
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            if role == "system":
                formatted += f"System: {content}\n\n"
            elif role == "user":
                formatted += f"Human: {content}\n\n"
            elif role == "assistant":
                formatted += f"Assistant: {content}\n\n"

        formatted += "Assistant:"
        return formatted

    def _get_fallback_response(self, character_id: str) -> str:
        """Return a canned fallback response if generation fails"""
        fallbacks = {
            "moses": "Peace be with you, my child. How may I guide you in righteousness?",
            "samsung_employee": "Hello! How can I help you with Samsung technology today?",
            "jinx": "*grins mischievously* Hey there! Ready for some chaos?"
        }
        return fallbacks.get(character_id, "Hello! How can I help you?")
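

# Minimal usage sketch (illustrative only; the application's real entrypoint is not
# shown above, so the asyncio wiring below is an assumption, not the project's code).
if __name__ == "__main__":
    import asyncio

    async def _demo():
        manager = CharacterManager()
        await manager.initialize()  # loads Qwen3-0.6B once plus the available LoRA adapters
        reply = manager.generate_response(
            character_id="moses",  # one of: moses, samsung_employee, jinx
            user_message="What should I do when I feel lost?",
            conversation_history=None
        )
        print(reply)

    asyncio.run(_demo())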