ianshank committed fcccb53 (1 parent: 414dadb)

fix: rename to bert_controller_v2.py to force cache invalidation

app.py CHANGED
@@ -36,7 +36,8 @@ from src.agents.meta_controller.base import MetaControllerFeatures
 
 # Robust import for BERTMetaController
 try:
-    from src.agents.meta_controller.bert_controller import BERTMetaController
+    # V2 import to bust cache
+    from src.agents.meta_controller.bert_controller_v2 import BERTMetaController
 except ImportError as e:
     print(f"CRITICAL WARNING: Failed to import BERTMetaController: {e}")
     print("Falling back to mock BERTMetaController to prevent crash.")
src/agents/meta_controller/bert_controller_v2.py ADDED
@@ -0,0 +1,423 @@
+"""
+BERT-based Meta-Controller with LoRA adapters for efficient fine-tuning.
+
+This module provides a BERT-based meta-controller that uses Low-Rank Adaptation (LoRA)
+for parameter-efficient fine-tuning. The controller converts agent state features into
+text and uses a sequence classification model to predict the optimal agent.
+"""
+
+import warnings
+from typing import Any
+
+import torch
+
+from src.agents.meta_controller.base import (
+    AbstractMetaController,
+    MetaControllerFeatures,
+    MetaControllerPrediction,
+)
+from src.agents.meta_controller.utils import features_to_text
+
+# Handle optional transformers and peft imports gracefully
+_TRANSFORMERS_AVAILABLE = False
+_PEFT_AVAILABLE = False
+
+try:
+    from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+    _TRANSFORMERS_AVAILABLE = True
+except ImportError:
+    warnings.warn(
+        "transformers library not installed. Install it with: pip install transformers",
+        ImportWarning,
+        stacklevel=2,
+    )
+    AutoTokenizer = None  # type: ignore
+    AutoModelForSequenceClassification = None  # type: ignore
+
+try:
+    from peft import LoraConfig, TaskType, get_peft_model
+
+    _PEFT_AVAILABLE = True
+except ImportError:
+    # Fallback if peft is missing or broken (e.g. version mismatch with transformers)
+    _PEFT_AVAILABLE = False
+    LoraConfig = None  # type: ignore
+    TaskType = None  # type: ignore
+    get_peft_model = None  # type: ignore
+
+
+class BERTMetaController(AbstractMetaController):
+    """
+    BERT-based meta-controller with optional LoRA adapters for efficient fine-tuning.
+
+    This controller converts agent state features into structured text and uses
+    a pre-trained BERT model (with optional LoRA adapters) to classify which
+    agent should handle the current query. LoRA enables parameter-efficient
+    fine-tuning by only training low-rank decomposition matrices.
+
+    Attributes:
+        DEFAULT_MODEL_NAME: Default BERT model to use.
+        NUM_LABELS: Number of output labels (agents to choose from).
+        device: PyTorch device for tensor operations.
+        model_name: Name of the pre-trained model.
+        lora_r: LoRA rank parameter.
+        lora_alpha: LoRA alpha scaling parameter.
+        lora_dropout: LoRA dropout rate.
+        use_lora: Whether to use LoRA adapters.
+        tokenizer: BERT tokenizer for text processing.
+        model: BERT sequence classification model (with or without LoRA).
+
+    Example:
+        >>> controller = BERTMetaController(name="BERTController", seed=42)
+        >>> features = MetaControllerFeatures(
+        ...     hrm_confidence=0.8,
+        ...     trm_confidence=0.6,
+        ...     mcts_value=0.75,
+        ...     consensus_score=0.7,
+        ...     last_agent='hrm',
+        ...     iteration=2,
+        ...     query_length=150,
+        ...     has_rag_context=True
+        ... )
+        >>> prediction = controller.predict(features)
+        >>> prediction.agent in ['hrm', 'trm', 'mcts']
+        True
+        >>> 0.0 <= prediction.confidence <= 1.0
+        True
+    """
+
+    DEFAULT_MODEL_NAME = "prajjwal1/bert-mini"
+    NUM_LABELS = 3
+
+    def __init__(
+        self,
+        name: str = "BERTMetaController",
+        seed: int = 42,
+        model_name: str | None = None,
+        lora_r: int = 4,
+        lora_alpha: int = 16,
+        lora_dropout: float = 0.1,
+        device: str | None = None,
+        use_lora: bool = True,
+    ) -> None:
+        """
+        Initialize the BERT meta-controller with optional LoRA adapters.
+
+        Args:
+            name: Name identifier for this controller. Defaults to "BERTMetaController".
+            seed: Random seed for reproducibility. Defaults to 42.
+            model_name: Pre-trained model name from HuggingFace. If None, uses DEFAULT_MODEL_NAME.
+            lora_r: LoRA rank parameter (lower = more compression). Defaults to 4.
+            lora_alpha: LoRA alpha scaling parameter. Defaults to 16.
+            lora_dropout: Dropout rate for LoRA layers. Defaults to 0.1.
+            device: Device to run model on ('cpu', 'cuda', 'mps', etc.).
+                If None, auto-detects best available device.
+            use_lora: Whether to apply LoRA adapters to the model. Defaults to True.
+
+        Raises:
+            ImportError: If transformers library is not installed.
+            ImportError: If use_lora is True and peft library is not installed.
+
+        Example:
+            >>> controller = BERTMetaController(
+            ...     name="CustomBERT",
+            ...     seed=123,
+            ...     lora_r=8,
+            ...     lora_alpha=32,
+            ...     use_lora=True
+            ... )
+        """
+        super().__init__(name=name, seed=seed)
+
+        # Check for required dependencies
+        if not _TRANSFORMERS_AVAILABLE:
+            raise ImportError(
+                "transformers library is required for BERTMetaController. Install it with: pip install transformers"
+            )
+
+        if use_lora and not _PEFT_AVAILABLE:
+            raise ImportError("peft library is required for LoRA support. Install it with: pip install peft")
+
+        # Set random seed for reproducibility
+        torch.manual_seed(seed)
+
+        # Auto-detect device if not specified
+        if device is None:
+            if torch.cuda.is_available():
+                self.device = torch.device("cuda")
+            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+                self.device = torch.device("mps")
+            else:
+                self.device = torch.device("cpu")
+        else:
+            self.device = torch.device(device)
+
+        # Store configuration parameters
+        self.model_name = model_name if model_name is not None else self.DEFAULT_MODEL_NAME
+        self.lora_r = lora_r
+        self.lora_alpha = lora_alpha
+        self.lora_dropout = lora_dropout
+        self.use_lora = use_lora
+
+        # Initialize tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+
+        # Initialize base model for sequence classification
+        base_model = AutoModelForSequenceClassification.from_pretrained(self.model_name, num_labels=self.NUM_LABELS)
+
+        # Apply LoRA adapters if requested
+        if self.use_lora:
+            lora_config = LoraConfig(
+                task_type=TaskType.SEQ_CLS,
+                r=self.lora_r,
+                lora_alpha=self.lora_alpha,
+                lora_dropout=self.lora_dropout,
+                target_modules=["query", "value"],
+            )
+            self.model = get_peft_model(base_model, lora_config)
+        else:
+            self.model = base_model
+
+        # Move model to device
+        self.model = self.model.to(self.device)
+
+        # Set model to evaluation mode
+        self.model.eval()
+
+        # Initialize tokenization cache for performance optimization
+        self._tokenization_cache: dict[str, Any] = {}
+
+    def predict(self, features: MetaControllerFeatures) -> MetaControllerPrediction:
+        """
+        Predict which agent should handle the current query.
+
+        Converts features to structured text, tokenizes the text, runs through
+        the BERT model, and returns a prediction with confidence scores.
+
+        Args:
+            features: Features extracted from the current agent state.
+
+        Returns:
+            Prediction containing the selected agent, confidence score,
+            and probability distribution over all agents.
+
+        Example:
+            >>> controller = BERTMetaController()
+            >>> features = MetaControllerFeatures(
+            ...     hrm_confidence=0.9,
+            ...     trm_confidence=0.3,
+            ...     mcts_value=0.5,
+            ...     consensus_score=0.8,
+            ...     last_agent='none',
+            ...     iteration=0,
+            ...     query_length=100,
+            ...     has_rag_context=False
+            ... )
+            >>> pred = controller.predict(features)
+            >>> isinstance(pred.agent, str)
+            >>> isinstance(pred.confidence, float)
+            >>> len(pred.probabilities) == 3
+        """
+        # Convert features to structured text
+        text = features_to_text(features)
+
+        # Check cache for tokenized text
+        if text in self._tokenization_cache:
+            inputs = self._tokenization_cache[text]
+        else:
+            # Tokenize the text
+            inputs = self.tokenizer(
+                text,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=512,
+            )
+            # Cache the tokenized result
+            self._tokenization_cache[text] = inputs
+
+        # Move inputs to device
+        inputs = {key: value.to(self.device) for key, value in inputs.items()}
+
+        # Perform inference without gradient tracking
+        with torch.no_grad():
+            # Get logits from model
+            outputs = self.model(**inputs)
+            logits = outputs.logits
+
+            # Apply softmax to get probabilities
+            probabilities = torch.nn.functional.softmax(logits, dim=-1)
+
+        # Get predicted agent index (argmax)
+        predicted_idx = torch.argmax(probabilities, dim=-1).item()
+
+        # Extract confidence for selected agent
+        confidence = probabilities[0, predicted_idx].item()
+
+        # Create probability dictionary
+        prob_dict: dict[str, float] = {}
+        for i, agent_name in enumerate(self.AGENT_NAMES):
+            prob_dict[agent_name] = probabilities[0, i].item()
+
+        # Get agent name
+        selected_agent = self.AGENT_NAMES[predicted_idx]
+
+        return MetaControllerPrediction(
+            agent=selected_agent,
+            confidence=float(confidence),
+            probabilities=prob_dict,
+        )
+
+    def load_model(self, path: str) -> None:
+        """
+        Load a trained model from disk.
+
+        For LoRA models, loads the PEFT adapter weights. For base models,
+        loads the full state dictionary.
+
+        Args:
+            path: Path to the saved model file or directory.
+                For LoRA models, this should be a directory containing
+                adapter_config.json and adapter_model.bin.
+                For base models, this should be a .pt or .pth file.
+
+        Raises:
+            FileNotFoundError: If the model file or directory does not exist.
+            RuntimeError: If the state dict is incompatible with the model.
+
+        Example:
+            >>> controller = BERTMetaController(use_lora=True)
+            >>> controller.load_model("/path/to/lora_adapter")
+            >>> controller = BERTMetaController(use_lora=False)
+            >>> controller.load_model("/path/to/model.pt")
+        """
+        if self.use_lora:
+            # Load PEFT adapter weights
+            # For PEFT models, the path should be a directory containing adapter files
+            from peft import PeftModel
+
+            # Get the base model from the PEFT wrapper
+            base_model = self.model.get_base_model()
+
+            # Load the PEFT model from the saved path
+            self.model = PeftModel.from_pretrained(base_model, path)
+            self.model = self.model.to(self.device)
+        else:
+            # Load base model state dict
+            state_dict = torch.load(path, map_location=self.device, weights_only=True)
+            self.model.load_state_dict(state_dict)
+
+        # Ensure model is in evaluation mode
+        self.model.eval()
+
+    def save_model(self, path: str) -> None:
+        """
+        Save the current model to disk.
+
+        For LoRA models, saves the PEFT adapter weights. For base models,
+        saves the full state dictionary.
+
+        Args:
+            path: Path where the model should be saved.
+                For LoRA models, this should be a directory path where
+                adapter_config.json and adapter_model.bin will be saved.
+                For base models, this should be a .pt or .pth file path.
+
+        Example:
+            >>> controller = BERTMetaController(use_lora=True)
+            >>> controller.save_model("/path/to/lora_adapter")
+            >>> controller = BERTMetaController(use_lora=False)
+            >>> controller.save_model("/path/to/model.pt")
+        """
+        if self.use_lora:
+            # Save PEFT adapter weights
+            # This saves only the LoRA adapter weights, not the full model
+            self.model.save_pretrained(path)
+        else:
+            # Save base model state dict
+            torch.save(self.model.state_dict(), path)
+
+    def clear_cache(self) -> None:
+        """
+        Clear the tokenization cache.
+
+        This method removes all cached tokenized inputs, freeing memory.
+        Useful when processing many different feature combinations or
+        when memory usage is a concern.
+
+        Example:
+            >>> controller = BERTMetaController()
+            >>> # After many predictions...
+            >>> controller.clear_cache()
+            >>> info = controller.get_cache_info()
+            >>> info['cache_size'] == 0
+            True
+        """
+        self._tokenization_cache.clear()
+
+    def get_cache_info(self) -> dict[str, Any]:
+        """
+        Get information about the current tokenization cache.
+
+        Returns:
+            Dictionary containing cache statistics:
+                - cache_size: Number of cached tokenizations
+                - cache_keys: List of cached text inputs (truncated for display)
+
+        Example:
+            >>> controller = BERTMetaController()
+            >>> features = MetaControllerFeatures(
+            ...     hrm_confidence=0.8,
+            ...     trm_confidence=0.6,
+            ...     mcts_value=0.75,
+            ...     consensus_score=0.7,
+            ...     last_agent='hrm',
+            ...     iteration=2,
+            ...     query_length=150,
+            ...     has_rag_context=True
+            ... )
+            >>> _ = controller.predict(features)
+            >>> info = controller.get_cache_info()
+            >>> 'cache_size' in info
+            True
+            >>> info['cache_size'] >= 1
+            True
+        """
+        # Truncate keys for display (first 50 chars)
+        truncated_keys = [key[:50] + "..." if len(key) > 50 else key for key in self._tokenization_cache]
+
+        return {
+            "cache_size": len(self._tokenization_cache),
+            "cache_keys": truncated_keys,
+        }
+
+    def get_trainable_parameters(self) -> dict[str, int | float]:
+        """
+        Get the number of trainable and total parameters in the model.
+
+        This is particularly useful for LoRA models to see the efficiency
+        gains from using low-rank adaptation.
+
+        Returns:
+            Dictionary containing:
+                - total_params: Total number of parameters in the model
+                - trainable_params: Number of trainable parameters
+                - trainable_percentage: Percentage of parameters that are trainable
+
+        Example:
+            >>> controller = BERTMetaController(use_lora=True)
+            >>> params = controller.get_trainable_parameters()
+            >>> params['trainable_percentage'] < 10.0  # LoRA trains <10% of params
+            True
+        """
+        total_params = sum(p.numel() for p in self.model.parameters())
+        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+        trainable_percentage = (trainable_params / total_params) * 100 if total_params > 0 else 0.0
+
+        return {
+            "total_params": total_params,
+            "trainable_params": trainable_params,
+            "trainable_percentage": round(trainable_percentage, 2),
+        }
+
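
As a sanity check after this rename, here is a minimal usage sketch distilled from the docstrings above (assumes transformers and peft are installed and the repository root is on PYTHONPATH; the feature values are illustrative):

    from src.agents.meta_controller.base import MetaControllerFeatures
    from src.agents.meta_controller.bert_controller_v2 import BERTMetaController

    # Instantiate with LoRA adapters (the default) and a fixed seed
    controller = BERTMetaController(seed=42, use_lora=True)

    features = MetaControllerFeatures(
        hrm_confidence=0.8,
        trm_confidence=0.6,
        mcts_value=0.75,
        consensus_score=0.7,
        last_agent="hrm",
        iteration=2,
        query_length=150,
        has_rag_context=True,
    )

    prediction = controller.predict(features)
    print(prediction.agent, prediction.confidence)  # one of 'hrm' / 'trm' / 'mcts'
    print(controller.get_trainable_parameters())    # LoRA should leave only a few % trainable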