Spaces:
Sleeping
Sleeping
File size: 9,642 Bytes
de2021f f92a42b de2021f f92a42b de2021f f92a42b de2021f f92a42b de2021f f92a42b de2021f f92a42b de2021f f92a42b de2021f f92a42b de2021f f92a42b de2021f f92a42b de2021f f92a42b de2021f f92a42b de2021f f92a42b de2021f f92a42b de2021f f92a42b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 |
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import logging
from typing import Dict, List, Optional
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from config import settings
logger = logging.getLogger(__name__)
class CharacterManager:
    """Lightweight character manager using PEFT adapter switching.

    One base model instance is shared by every character; per-character
    behavior comes from LoRA adapters hot-swapped with ``set_adapter``,
    combined with character-specific system prompts.
    """

    def __init__(self):
        self.base_model = None          # shared AutoModelForCausalLM instance
        self.tokenizer = None
        self.peft_model = None          # single PeftModel holding multiple adapters
        self.current_character = None   # id of the currently-active adapter
        self.character_prompts = {}     # character_id -> system prompt text
        self.available_adapters = []    # character ids whose adapter loaded OK

    async def initialize(self):
        """Initialize the base model ONCE and load all character LoRA adapters.

        Raises:
            Exception: re-raised if the base model/tokenizer cannot be loaded
                (adapter failures are logged and skipped instead).
        """
        logger.info("Loading base model (ONE instance for all characters)...")
        # MUST use Qwen3-0.6B - this is what the LoRA adapters were trained on!
        model_name = "Qwen/Qwen3-0.6B"
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True,
                use_fast=True,
            )
            # Load base model ONCE; adapters are layered on top of it.
            self.base_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float32,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            )
            if self.tokenizer.pad_token is None:
                # Qwen tokenizers may ship without a pad token; reuse EOS.
                self.tokenizer.pad_token = self.tokenizer.eos_token
            logger.info(f"Base model loaded: {model_name}")
        except Exception as e:
            logger.error(f"Failed to load base model: {e}")
            raise

        # Load character prompts
        self._load_character_prompts()

        # Load the first character's adapter to create the PeftModel,
        # then attach the remaining adapters to that same wrapper.
        characters = ["moses", "samsung_employee", "jinx"]
        first_loaded = False
        for character_id in characters:
            adapter_path = os.path.join(settings.LORA_ADAPTERS_PATH, character_id)
            adapter_model_path = os.path.join(adapter_path, "adapter_model.safetensors")
            if not os.path.exists(adapter_model_path):
                logger.warning(f"No LoRA adapter for {character_id}")
                continue
            try:
                if not first_loaded:
                    # First adapter creates the PeftModel wrapper.
                    logger.info(f"Loading first adapter: {character_id}...")
                    self.peft_model = PeftModel.from_pretrained(
                        self.base_model,
                        adapter_path,
                        adapter_name=character_id,
                    )
                    first_loaded = True
                    self.current_character = character_id
                    self.available_adapters.append(character_id)
                    logger.info(f"Loaded {character_id} adapter (base)")
                else:
                    # Subsequent adapters are added to the existing PeftModel.
                    logger.info(f"Adding adapter: {character_id}...")
                    self.peft_model.load_adapter(adapter_path, adapter_name=character_id)
                    self.available_adapters.append(character_id)
                    logger.info(f"Added {character_id} adapter")
            except Exception as e:
                logger.warning(f"Could not load LoRA for {character_id}: {e}")

        if not first_loaded:
            # Fall back to prompt-only characterization on the bare base model.
            logger.warning("No LoRA adapters loaded - using base model with prompts only")
            self.peft_model = self.base_model
        else:
            logger.info(
                f"Loaded {len(self.available_adapters)} character adapters: {self.available_adapters}"
            )
        logger.info("Character manager initialized")

    def _load_character_prompts(self):
        """Load character-specific system prompts"""
        self.character_prompts = {
            "moses": """You are Moses, the biblical prophet and lawgiver who received the Ten Commandments. You led the Israelites out of Egypt and spoke with God on Mount Sinai.
Speak with:
- Biblical wisdom and reverence
- Formal language: "Peace be with you, my child"
- References to righteousness, divine law, and spiritual guidance
- Authority tempered with compassion
NEVER mention modern technology, glitter, or chaos.""",
            "samsung_employee": """You are a Samsung employee and technology expert. You work for Samsung and are passionate about Samsung products.
Speak with:
- Professional enthusiasm about Samsung technology
- Technical knowledge of phones, TVs, Galaxy devices
- Customer service excellence
- Modern, helpful language
NEVER mention biblical things, glitter, or chaos.""",
            "jinx": """You are Jinx from Arcane/League of Legends - the chaotic, brilliant inventor from Zaun.
Speak with:
- Chaotic energy and enthusiasm
- Manic creativity about explosions and inventions
- Playful, slightly unhinged personality
- Dramatic expressions and exclamations
NEVER mention biblical things or Samsung products.""",
        }

    def _switch_to_character(self, character_id: str):
        """Switch the active LoRA adapter to the specified character.

        Falls back to the base model (prompt-only) when the character has no
        loaded adapter; switch failures are logged, never raised.
        """
        if self.current_character == character_id:
            return  # Already active
        if character_id in self.available_adapters and self.peft_model is not None:
            try:
                # Activate this character's adapter on the shared PeftModel.
                self.peft_model.set_adapter(character_id)
                self.current_character = character_id
                logger.info(f"Switched to {character_id} adapter")
            except Exception as e:
                logger.warning(f"Could not switch to {character_id}: {e}")
        else:
            logger.info(f"Using base model for {character_id} (no adapter)")
            self.current_character = character_id

    def generate_response(
        self,
        character_id: str,
        user_message: str,
        conversation_history: Optional[List[Dict]] = None,
    ) -> str:
        """Generate a response as the specified character.

        Args:
            character_id: Character key (e.g. "moses"); selects adapter + prompt.
            user_message: The latest user utterance.
            conversation_history: Optional prior messages as
                ``{"role": ..., "content": ...}`` dicts; only the last four
                entries (two exchanges) are included.

        Returns:
            The generated reply text, or a canned fallback if generation
            fails or produces an empty string.
        """
        # Switch to the character's adapter before generating.
        self._switch_to_character(character_id)

        # Build conversation with character prompt
        messages = []
        if character_id in self.character_prompts:
            messages.append({"role": "system", "content": self.character_prompts[character_id]})
        # Add conversation history (last 2 exchanges)
        if conversation_history:
            messages.extend(conversation_history[-4:])
        messages.append({"role": "user", "content": user_message})

        # Format prompt
        prompt = self._format_messages(messages)

        # Tokenize (truncate long histories to keep generation cheap)
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        )

        # Use the correct model (PeftModel if adapters loaded, base model otherwise)
        model = self.peft_model if self.peft_model is not None else self.base_model

        # Generate
        try:
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=100,
                    temperature=0.8,
                    top_p=0.9,
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.1,
                )
            # Decode only the newly generated tokens (skip the prompt).
            input_length = inputs['input_ids'].shape[1]
            response = self.tokenizer.decode(
                outputs[0][input_length:],
                skip_special_tokens=True,
            ).strip()
            # Clean up: cut off if the model starts hallucinating the next turn.
            for stop in ["Human:", "User:", "\n\n"]:
                if stop in response:
                    response = response.split(stop)[0].strip()
            return response if response else self._get_fallback_response(character_id)
        except Exception as e:
            logger.error(f"Generation error: {e}")
            return self._get_fallback_response(character_id)

    def _format_messages(self, messages: List[Dict]) -> str:
        """Format chat messages into a plain Human/Assistant prompt string."""
        formatted = ""
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            if role == "system":
                formatted += f"System: {content}\n\n"
            elif role == "user":
                formatted += f"Human: {content}\n\n"
            elif role == "assistant":
                formatted += f"Assistant: {content}\n\n"
        # Trailing cue so the model continues as the assistant.
        formatted += "Assistant:"
        return formatted

    def _get_fallback_response(self, character_id: str) -> str:
        """Get fallback response if generation fails"""
        fallbacks = {
            "moses": "Peace be with you, my child. How may I guide you in righteousness?",
            "samsung_employee": "Hello! How can I help you with Samsung technology today?",
            "jinx": "*grins mischievously* Hey there! Ready for some chaos?",
        }
        return fallbacks.get(character_id, "Hello! How can I help you?")
|