ButterM40 commited on
Commit
f92a42b
·
1 Parent(s): de2021f

Fix: Use Qwen3-0.6B (correct model) with proper PEFT adapter switching via set_adapter()

Browse files
backend/models/lightweight_character_manager.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
- from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
3
- from peft import PeftModel, PeftConfig, set_peft_model_state_dict, get_peft_model_state_dict
4
  import logging
5
  from typing import Dict, List
6
  import os
@@ -11,20 +11,22 @@ from config import settings
11
  logger = logging.getLogger(__name__)
12
 
13
  class CharacterManager:
14
- """Lightweight character manager that swaps LoRA adapters on a single base model"""
15
 
16
  def __init__(self):
17
  self.base_model = None
18
  self.tokenizer = None
 
19
  self.current_character = None
20
- self.character_adapters = {} # Store adapter weights, not full models
21
  self.character_prompts = {}
 
22
 
23
  async def initialize(self):
24
  """Initialize base model ONCE and load all character LoRA adapters"""
25
  logger.info("🔄 Loading base model (ONE instance for all characters)...")
26
 
27
- model_name = "Qwen/Qwen2.5-0.5B-Instruct" # Smaller model for HF Spaces
 
28
 
29
  try:
30
  self.tokenizer = AutoTokenizer.from_pretrained(
@@ -33,7 +35,7 @@ class CharacterManager:
33
  use_fast=True
34
  )
35
 
36
- # Load base model ONCE (CPU for HF Spaces free tier)
37
  self.base_model = AutoModelForCausalLM.from_pretrained(
38
  model_name,
39
  torch_dtype=torch.float32,
@@ -53,9 +55,46 @@ class CharacterManager:
53
  # Load character prompts
54
  self._load_character_prompts()
55
 
56
- # Try to load LoRA adapters (optional - graceful degradation)
57
- for character_id in ["moses", "samsung_employee", "jinx"]:
58
- await self._load_character_adapter(character_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  logger.info("✅ Character manager initialized")
61
 
@@ -93,54 +132,22 @@ Speak with:
93
  NEVER mention biblical things or Samsung products."""
94
  }
95
 
96
- async def _load_character_adapter(self, character_id: str):
97
- """Try to load LoRA adapter weights (graceful failure if missing)"""
98
- adapter_path = os.path.join(settings.LORA_ADAPTERS_PATH, character_id)
99
- adapter_model_path = os.path.join(adapter_path, "adapter_model.safetensors")
100
-
101
- if not os.path.exists(adapter_model_path):
102
- logger.warning(f"⚠️ No LoRA adapter for {character_id} - will use prompts only")
103
- return
104
-
105
- try:
106
- logger.info(f"Loading LoRA adapter for {character_id}...")
107
-
108
- # Load adapter onto base model temporarily
109
- model_with_adapter = PeftModel.from_pretrained(
110
- self.base_model,
111
- adapter_path,
112
- adapter_name=character_id
113
- )
114
-
115
- # Extract and store just the adapter weights (tiny!)
116
- self.character_adapters[character_id] = get_peft_model_state_dict(model_with_adapter)
117
-
118
- # Clean up - we only need the weights
119
- del model_with_adapter
120
- torch.cuda.empty_cache() if torch.cuda.is_available() else None
121
-
122
- logger.info(f"✅ Loaded LoRA adapter for {character_id}")
123
-
124
- except Exception as e:
125
- logger.warning(f"⚠️ Could not load LoRA for {character_id}: {e}")
126
- logger.info(f"Will use system prompts only for {character_id}")
127
-
128
  def _switch_to_character(self, character_id: str):
129
- """Switch to a character by loading their LoRA adapter (if available)"""
130
  if self.current_character == character_id:
131
- return # Already loaded
132
 
133
- # If character has LoRA adapter, apply it
134
- if character_id in self.character_adapters:
135
  try:
136
- # Create PeftModel with this character's adapter
137
- self.base_model = PeftModel(self.base_model, character_id)
138
- set_peft_model_state_dict(self.base_model, self.character_adapters[character_id])
139
- logger.info(f"✅ Switched to {character_id} with LoRA")
140
- except:
141
- logger.warning(f"⚠️ Using base model + prompts for {character_id}")
142
-
143
- self.current_character = character_id
 
144
 
145
  def generate_response(
146
  self,
@@ -150,7 +157,7 @@ NEVER mention biblical things or Samsung products."""
150
  ) -> str:
151
  """Generate response as specific character"""
152
 
153
- # Switch to character (applies LoRA if available)
154
  self._switch_to_character(character_id)
155
 
156
  # Build conversation with character prompt
@@ -175,10 +182,13 @@ NEVER mention biblical things or Samsung products."""
175
  truncation=True
176
  )
177
 
 
 
 
178
  # Generate
179
  try:
180
  with torch.no_grad():
181
- outputs = self.base_model.generate(
182
  **inputs,
183
  max_new_tokens=100,
184
  temperature=0.8,
@@ -230,3 +240,4 @@ NEVER mention biblical things or Samsung products."""
230
  "jinx": "*grins mischievously* Hey there! Ready for some chaos?"
231
  }
232
  return fallbacks.get(character_id, "Hello! How can I help you?")
 
 
1
  import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ from peft import PeftModel
4
  import logging
5
  from typing import Dict, List
6
  import os
 
11
  logger = logging.getLogger(__name__)
12
 
13
  class CharacterManager:
14
+ """Lightweight character manager using PEFT adapter switching"""
15
 
16
  def __init__(self):
17
  self.base_model = None
18
  self.tokenizer = None
19
+ self.peft_model = None # Single PeftModel with multiple adapters
20
  self.current_character = None
 
21
  self.character_prompts = {}
22
+ self.available_adapters = []
23
 
24
  async def initialize(self):
25
  """Initialize base model ONCE and load all character LoRA adapters"""
26
  logger.info("🔄 Loading base model (ONE instance for all characters)...")
27
 
28
+ # MUST use Qwen3-0.6B - this is what the LoRA adapters were trained on!
29
+ model_name = "Qwen/Qwen3-0.6B"
30
 
31
  try:
32
  self.tokenizer = AutoTokenizer.from_pretrained(
 
35
  use_fast=True
36
  )
37
 
38
+ # Load base model ONCE
39
  self.base_model = AutoModelForCausalLM.from_pretrained(
40
  model_name,
41
  torch_dtype=torch.float32,
 
55
  # Load character prompts
56
  self._load_character_prompts()
57
 
58
+ # Load first character's adapter to create PeftModel, then add others
59
+ characters = ["moses", "samsung_employee", "jinx"]
60
+ first_loaded = False
61
+
62
+ for idx, character_id in enumerate(characters):
63
+ adapter_path = os.path.join(settings.LORA_ADAPTERS_PATH, character_id)
64
+ adapter_model_path = os.path.join(adapter_path, "adapter_model.safetensors")
65
+
66
+ if not os.path.exists(adapter_model_path):
67
+ logger.warning(f"⚠️ No LoRA adapter for {character_id}")
68
+ continue
69
+
70
+ try:
71
+ if not first_loaded:
72
+ # Load first adapter to create PeftModel
73
+ logger.info(f"Loading first adapter: {character_id}...")
74
+ self.peft_model = PeftModel.from_pretrained(
75
+ self.base_model,
76
+ adapter_path,
77
+ adapter_name=character_id
78
+ )
79
+ first_loaded = True
80
+ self.current_character = character_id
81
+ self.available_adapters.append(character_id)
82
+ logger.info(f"✅ Loaded {character_id} adapter (base)")
83
+ else:
84
+ # Add additional adapters to existing PeftModel
85
+ logger.info(f"Adding adapter: {character_id}...")
86
+ self.peft_model.load_adapter(adapter_path, adapter_name=character_id)
87
+ self.available_adapters.append(character_id)
88
+ logger.info(f"✅ Added {character_id} adapter")
89
+
90
+ except Exception as e:
91
+ logger.warning(f"⚠️ Could not load LoRA for {character_id}: {e}")
92
+
93
+ if not first_loaded:
94
+ logger.warning("⚠️ No LoRA adapters loaded - using base model with prompts only")
95
+ self.peft_model = self.base_model
96
+ else:
97
+ logger.info(f"✅ Loaded {len(self.available_adapters)} character adapters: {self.available_adapters}")
98
 
99
  logger.info("✅ Character manager initialized")
100
 
 
132
  NEVER mention biblical things or Samsung products."""
133
  }
134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  def _switch_to_character(self, character_id: str):
136
+ """Switch active LoRA adapter to the specified character"""
137
  if self.current_character == character_id:
138
+ return # Already active
139
 
140
+ if character_id in self.available_adapters and self.peft_model is not None:
 
141
  try:
142
+ # Switch to this character's adapter
143
+ self.peft_model.set_adapter(character_id)
144
+ self.current_character = character_id
145
+ logger.info(f"✅ Switched to {character_id} adapter")
146
+ except Exception as e:
147
+ logger.warning(f"⚠️ Could not switch to {character_id}: {e}")
148
+ else:
149
+ logger.info(f"Using base model for {character_id} (no adapter)")
150
+ self.current_character = character_id
151
 
152
  def generate_response(
153
  self,
 
157
  ) -> str:
158
  """Generate response as specific character"""
159
 
160
+ # Switch to character's adapter
161
  self._switch_to_character(character_id)
162
 
163
  # Build conversation with character prompt
 
182
  truncation=True
183
  )
184
 
185
+ # Use the correct model (PeftModel if adapters loaded, base model otherwise)
186
+ model = self.peft_model if self.peft_model is not None else self.base_model
187
+
188
  # Generate
189
  try:
190
  with torch.no_grad():
191
+ outputs = model.generate(
192
  **inputs,
193
  max_new_tokens=100,
194
  temperature=0.8,
 
240
  "jinx": "*grins mischievously* Hey there! Ready for some chaos?"
241
  }
242
  return fallbacks.get(character_id, "Hello! How can I help you?")
243
+