pythonprincess committed on
Commit
22fff45
·
verified ·
1 Parent(s): 6fcb9f3

Delete app/model_loader.py

Browse files
Files changed (1) hide show
  1. app/model_loader.py +0 -889
app/model_loader.py DELETED
@@ -1,889 +0,0 @@
1
- # app/model_loader.py
2
- """
3
- 🧠 PENNY Model Loader - Azure-Ready Multi-Model Orchestration
4
-
5
- This is Penny's brain loader. She manages multiple specialized models:
6
- - Gemma 7B for conversational reasoning
7
- - NLLB-200 for 27-language translation
8
- - Sentiment analysis for resident wellbeing
9
- - Bias detection for equitable service
10
- - LayoutLM for civic document processing
11
-
12
- MISSION: Load AI models efficiently in memory-constrained environments while
13
- maintaining Penny's warm, civic-focused personality across all interactions.
14
-
15
- FEATURES:
16
- - Lazy loading (models only load when needed)
17
- - 8-bit quantization for memory efficiency
18
- - GPU/CPU auto-detection
19
- - Model caching and reuse
20
- - Graceful fallbacks for Azure ML deployment
21
- - Memory monitoring and cleanup
22
- """
23
-
24
- import json
25
- import os
26
- import torch
27
- from typing import Dict, Any, Callable, Optional, Union, List
28
- from pathlib import Path
29
- import logging
30
- from dataclasses import dataclass
31
- from enum import Enum
32
- from datetime import datetime
33
-
34
# --- LOGGING SETUP (Must be before functions that use it) ---
# Module-level logger; handlers/levels are configured by the host application.
logger = logging.getLogger(__name__)

# ============================================================
# HUGGING FACE AUTHENTICATION
# ============================================================
41
def setup_huggingface_auth() -> bool:
    """
    🔐 Log in to the Hugging Face Hub using the HF_TOKEN env var.

    Returns:
        True when login succeeds; False when the token is missing,
        huggingface_hub is not installed, or login raises.
    """
    token = os.getenv("HF_TOKEN")

    # No token: warn loudly but keep going — public models still work.
    if not token:
        logger.warning("⚠️ HF_TOKEN not found in environment")
        logger.warning(" Some models may not be accessible")
        logger.warning(" Set HF_TOKEN in your environment or Hugging Face Spaces secrets")
        return False

    try:
        from huggingface_hub import login
    except ImportError:
        logger.warning("⚠️ huggingface_hub not installed, skipping authentication")
        return False

    try:
        login(token=token, add_to_git_credential=False)
    except Exception as e:
        logger.error(f"❌ Failed to authenticate with Hugging Face: {e}")
        return False

    logger.info("✅ Authenticated with Hugging Face Hub")
    return True
67
-
68
# Attempt authentication at module load
setup_huggingface_auth()

# --- PATH CONFIGURATION (Environment-Aware) ---
# Support both local development and Azure ML deployment.
# NOTE(review): os.getenv("AZUREML_MODEL_DIR") is read twice; assumed stable
# for the duration of module import.
if os.getenv("AZUREML_MODEL_DIR"):
    # Azure ML deployment - models are in AZUREML_MODEL_DIR
    MODEL_ROOT = Path(os.getenv("AZUREML_MODEL_DIR"))
    CONFIG_PATH = MODEL_ROOT / "model_config.json"
    logger.info("☁️ Running in Azure ML environment")
else:
    # Local development - models are in project structure (repo_root/models)
    PROJECT_ROOT = Path(__file__).parent.parent
    MODEL_ROOT = PROJECT_ROOT / "models"
    CONFIG_PATH = MODEL_ROOT / "model_config.json"
    logger.info("💻 Running in local development environment")

logger.info(f"📂 Model config path: {CONFIG_PATH}")
86
-
87
# ============================================================
# PENNY'S CIVIC IDENTITY & PERSONALITY
# ============================================================

# System prompt prepended to every text-generation request (see
# ModelClient.predict, unless skip_system_prompt=True is passed).
PENNY_SYSTEM_PROMPT = (
    "You are Penny, a smart, civic-focused AI assistant serving local communities. "
    "You help residents navigate city services, government programs, and community resources. "
    "You're warm, professional, accurate, and always stay within your civic mission.\n\n"

    "Your expertise includes:\n"
    "- Connecting people with local services (food banks, shelters, libraries)\n"
    "- Translating information into 27 languages\n"
    "- Explaining public programs and eligibility\n"
    "- Guiding residents through civic processes\n"
    "- Providing emergency resources when needed\n\n"

    "YOUR PERSONALITY:\n"
    "- Warm and approachable, like a helpful community center staff member\n"
    "- Clear and practical, avoiding jargon\n"
    "- Culturally sensitive and inclusive\n"
    "- Patient with repetition or clarification\n"
    "- Funny when appropriate, but never at anyone's expense\n\n"

    "CRITICAL RULES:\n"
    "- When residents greet you by name (e.g., 'Hi Penny'), respond warmly and personally\n"
    "- You are ALWAYS Penny - never ChatGPT, Assistant, Claude, or any other name\n"
    "- If you don't know something, say so clearly and help find the right resource\n"
    "- NEVER make up information about services, eligibility, or contacts\n"
    "- Stay within your civic mission - you don't provide legal, medical, or financial advice\n"
    "- For emergencies, immediately connect to appropriate services (911, crisis lines)\n\n"
)
118
-
119
# --- GLOBAL STATE ---
# Shared across all ModelClient instances, keyed by model id.
_MODEL_CACHE: Dict[str, Any] = {}  # Memory-efficient model reuse
_LOAD_TIMES: Dict[str, float] = {}  # Track model loading performance (seconds)


# ============================================================
# DEVICE MANAGEMENT
# ============================================================
127
-
128
class DeviceType(str, Enum):
    """Compute devices Penny can run inference on (str-valued enum)."""
    CUDA = "cuda"  # NVIDIA GPU
    CPU = "cpu"    # Universal fallback
    MPS = "mps"    # Apple Silicon
133
-
134
-
135
def get_optimal_device() -> str:
    """
    🎮 Pick the best available device for model inference.

    Preference order: CUDA GPU, then Apple Silicon (MPS), then CPU.

    Returns:
        Device string ("cuda", "mps", or "cpu")
    """
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
        logger.info(f"🎮 GPU detected: {gpu_name} ({gpu_memory:.1f}GB)")
        return DeviceType.CUDA.value

    # torch.backends.mps only exists on builds with MPS support.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        logger.info("🍎 Apple Silicon (MPS) detected")
        return DeviceType.MPS.value

    logger.info("💻 Using CPU for inference")
    logger.warning("⚠️ GPU not available - inference will be slower")
    return DeviceType.CPU.value
164
-
165
-
166
def get_memory_stats() -> Dict[str, float]:
    """
    📊 Snapshot current GPU/CPU memory usage.

    Returns:
        Dict with memory stats in GB. GPU keys appear only when CUDA is
        available; CPU keys only when psutil is installed.
    """
    stats: Dict[str, float] = {}
    gb = 1e9

    if torch.cuda.is_available():
        stats["gpu_allocated_gb"] = torch.cuda.memory_allocated() / gb
        stats["gpu_reserved_gb"] = torch.cuda.memory_reserved() / gb
        stats["gpu_total_gb"] = torch.cuda.get_device_properties(0).total_memory / gb

    # CPU-side stats are best-effort: psutil is an optional dependency.
    try:
        import psutil
    except ImportError:
        return stats

    mem = psutil.virtual_memory()
    stats["cpu_used_gb"] = mem.used / gb
    stats["cpu_total_gb"] = mem.total / gb
    stats["cpu_percent"] = mem.percent
    return stats
191
-
192
-
193
- # ============================================================
194
- # MODEL CLIENT (Individual Model Handler)
195
- # ============================================================
196
-
197
@dataclass
class ModelMetadata:
    """
    📋 Bookkeeping for one loaded model: identity, placement, and
    runtime performance counters.
    """
    name: str                             # Penny-internal model id
    task: str                             # HF pipeline task name
    model_name: str                       # HuggingFace model id
    device: str                           # "cuda" / "mps" / "cpu"
    loaded_at: Optional[datetime] = None  # set when the pipeline finishes loading
    load_time_seconds: Optional[float] = None
    memory_usage_gb: Optional[float] = None
    inference_count: int = 0
    total_inference_time_ms: float = 0.0

    @property
    def avg_inference_time_ms(self) -> float:
        """Mean wall-clock inference time; 0.0 before any calls."""
        count = self.inference_count
        return self.total_inference_time_ms / count if count else 0.0
219
-
220
-
221
class ModelClient:
    """
    🤖 Manages a single HuggingFace model with optimized loading and inference.

    Features:
    - Lazy loading (load on first use)
    - Memory optimization (8-bit quantization)
    - Performance tracking
    - Graceful error handling
    - Automatic device placement
    """

    def __init__(
        self,
        name: str,
        model_name: str,
        task: str,
        device: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Initialize model client (doesn't load the model yet).

        Args:
            name: Model identifier (e.g., "penny-core-agent")
            model_name: HuggingFace model ID
            task: Task type (text-generation, translation, etc.)
            device: Target device (auto-detected if None)
            config: Additional model configuration
        """
        self.name = name
        self.model_name = model_name
        self.task = task
        # Auto-detect the best device when the caller doesn't pin one.
        self.device = device or get_optimal_device()
        self.config = config or {}
        self.pipeline = None  # HF pipeline object; populated by load_pipeline()
        self._load_attempted = False  # prevents retry loops after a failed load
        self.metadata = ModelMetadata(
            name=name,
            task=task,
            model_name=model_name,
            device=self.device
        )

        logger.info(f"📦 Initialized ModelClient: {name}")
        logger.debug(f" Model: {model_name}")
        logger.debug(f" Task: {task}")
        logger.debug(f" Device: {self.device}")

    def load_pipeline(self) -> bool:
        """
        🔄 Loads the HuggingFace pipeline with Azure-optimized settings.

        Features:
        - 8-bit quantization for large models (saves ~50% memory)
        - Automatic device placement
        - Memory monitoring
        - Cache checking

        Returns:
            True if successful, False otherwise
        """
        if self.pipeline is not None:
            logger.debug(f"✅ {self.name} already loaded")
            return True

        # A previous in-process attempt failed: refuse to hammer the Hub again.
        # (unload() resets this flag so an unloaded model can be reloaded.)
        if self._load_attempted:
            logger.warning(f"⚠️ Previous load attempt failed for {self.name}")
            return False

        global _MODEL_CACHE, _LOAD_TIMES

        # Check cache first
        if self.name in _MODEL_CACHE:
            logger.info(f"♻️ Using cached pipeline for {self.name}")
            self.pipeline = _MODEL_CACHE[self.name]
            return True

        logger.info(f"🔄 Loading {self.name} from HuggingFace...")
        self._load_attempted = True

        start_time = datetime.now()

        try:
            # Import pipeline from transformers (lazy import to avoid dependency issues)
            from transformers import pipeline

            # === TEXT GENERATION (Gemma 7B, GPT-2, etc.) ===
            if self.task == "text-generation":
                logger.info(" Using 8-bit quantization for memory efficiency...")

                # 8-bit loading only makes sense on CUDA (bitsandbytes requirement).
                use_8bit = self.device == DeviceType.CUDA.value

                if use_8bit:
                    self.pipeline = pipeline(
                        "text-generation",
                        model=self.model_name,
                        tokenizer=self.model_name,
                        device_map="auto",
                        load_in_8bit=True,  # Reduces ~14GB to ~7GB
                        trust_remote_code=True,
                        torch_dtype=torch.float16
                    )
                else:
                    # CPU fallback
                    self.pipeline = pipeline(
                        "text-generation",
                        model=self.model_name,
                        tokenizer=self.model_name,
                        device=-1,  # CPU
                        trust_remote_code=True,
                        torch_dtype=torch.float32
                    )

            # === TRANSLATION (NLLB-200, M2M-100, etc.) ===
            elif self.task == "translation":
                self.pipeline = pipeline(
                    "translation",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1,
                    src_lang=self.config.get("default_src_lang", "eng_Latn"),
                    tgt_lang=self.config.get("default_tgt_lang", "spa_Latn")
                )

            # === SENTIMENT ANALYSIS ===
            elif self.task == "sentiment-analysis":
                self.pipeline = pipeline(
                    "sentiment-analysis",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1,
                    truncation=True,
                    max_length=512
                )

            # === BIAS DETECTION (Zero-Shot Classification) ===
            elif self.task == "bias-detection":
                self.pipeline = pipeline(
                    "zero-shot-classification",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1
                )

            # === TEXT CLASSIFICATION (Generic) ===
            elif self.task == "text-classification":
                self.pipeline = pipeline(
                    "text-classification",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1,
                    truncation=True
                )

            # === PDF/DOCUMENT EXTRACTION (LayoutLMv3) ===
            elif self.task == "pdf-extraction":
                logger.warning("⚠️ PDF extraction requires additional OCR setup")
                logger.info(" Consider using Azure Form Recognizer as alternative")
                # Placeholder - requires pytesseract/OCR infrastructure
                self.pipeline = None
                return False

            else:
                raise ValueError(f"Unknown task type: {self.task}")

            # === SUCCESS HANDLING ===
            if self.pipeline is not None:
                # Calculate load time
                load_time = (datetime.now() - start_time).total_seconds()
                self.metadata.loaded_at = datetime.now()
                self.metadata.load_time_seconds = load_time

                # Cache the pipeline
                _MODEL_CACHE[self.name] = self.pipeline
                _LOAD_TIMES[self.name] = load_time

                # Log memory usage
                mem_stats = get_memory_stats()
                self.metadata.memory_usage_gb = mem_stats.get("gpu_allocated_gb", 0)

                logger.info(f"✅ {self.name} loaded successfully!")
                logger.info(f" Load time: {load_time:.2f}s")

                if "gpu_allocated_gb" in mem_stats:
                    logger.info(
                        f" GPU Memory: {mem_stats['gpu_allocated_gb']:.2f}GB / "
                        f"{mem_stats['gpu_total_gb']:.2f}GB"
                    )

                return True

            # FIX: previously fell off the end and returned None implicitly;
            # return an explicit (still falsy) False for a clean bool contract.
            return False

        except Exception as e:
            logger.error(f"❌ Failed to load {self.name}: {e}", exc_info=True)
            self.pipeline = None
            return False

    def predict(
        self,
        input_data: Union[str, Dict[str, Any]],
        **kwargs
    ) -> Dict[str, Any]:
        """
        🎯 Runs inference with the loaded model pipeline.

        Features:
        - Automatic pipeline loading
        - Error handling with fallback responses
        - Performance tracking
        - Penny's personality injection (for text-generation)

        Args:
            input_data: Text or structured input for the model
            **kwargs: Task-specific parameters

        Returns:
            Model output dict with results or error information
        """
        # Track inference start time
        start_time = datetime.now()

        # Ensure pipeline is loaded
        if self.pipeline is None:
            success = self.load_pipeline()
            if not success:
                return {
                    "error": f"{self.name} pipeline unavailable",
                    "detail": "Model failed to load. Check logs for details.",
                    "model": self.name,
                    # FIX: every other error path sets this flag; add it here
                    # so callers can uniformly check output["success"].
                    "success": False
                }

        try:
            # === TEXT GENERATION ===
            if self.task == "text-generation":
                # Inject Penny's civic identity.
                # NOTE(review): assumes input_data is a str here — the dict
                # form of the Union would raise on concatenation.
                if not kwargs.get("skip_system_prompt", False):
                    full_prompt = PENNY_SYSTEM_PROMPT + input_data
                else:
                    full_prompt = input_data

                # Extract generation parameters with safe defaults
                max_new_tokens = kwargs.get("max_new_tokens", 256)
                temperature = kwargs.get("temperature", 0.7)
                top_p = kwargs.get("top_p", 0.9)
                do_sample = kwargs.get("do_sample", temperature > 0.0)

                result = self.pipeline(
                    full_prompt,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=do_sample,
                    return_full_text=False,
                    pad_token_id=self.pipeline.tokenizer.eos_token_id,
                    truncation=True
                )

                output = {
                    "generated_text": result[0]["generated_text"],
                    "model": self.name,
                    "success": True
                }

            # === TRANSLATION ===
            elif self.task == "translation":
                src_lang = kwargs.get("source_lang", "eng_Latn")
                tgt_lang = kwargs.get("target_lang", "spa_Latn")

                result = self.pipeline(
                    input_data,
                    src_lang=src_lang,
                    tgt_lang=tgt_lang,
                    max_length=512
                )

                output = {
                    "translation": result[0]["translation_text"],
                    "source_lang": src_lang,
                    "target_lang": tgt_lang,
                    "model": self.name,
                    "success": True
                }

            # === SENTIMENT ANALYSIS ===
            elif self.task == "sentiment-analysis":
                result = self.pipeline(input_data)

                output = {
                    "sentiment": result[0]["label"],
                    "confidence": result[0]["score"],
                    "model": self.name,
                    "success": True
                }

            # === BIAS DETECTION ===
            elif self.task == "bias-detection":
                candidate_labels = kwargs.get("candidate_labels", [
                    "neutral and objective",
                    "contains political bias",
                    "uses emotional language",
                    "culturally insensitive"
                ])

                result = self.pipeline(
                    input_data,
                    candidate_labels=candidate_labels,
                    multi_label=True
                )

                output = {
                    "labels": result["labels"],
                    "scores": result["scores"],
                    "model": self.name,
                    "success": True
                }

            # === TEXT CLASSIFICATION ===
            elif self.task == "text-classification":
                result = self.pipeline(input_data)

                output = {
                    "label": result[0]["label"],
                    "confidence": result[0]["score"],
                    "model": self.name,
                    "success": True
                }

            else:
                output = {
                    "error": f"Task '{self.task}' not implemented",
                    "model": self.name,
                    "success": False
                }

            # Track performance
            inference_time = (datetime.now() - start_time).total_seconds() * 1000
            self.metadata.inference_count += 1
            self.metadata.total_inference_time_ms += inference_time
            output["inference_time_ms"] = round(inference_time, 2)

            return output

        except Exception as e:
            logger.error(f"❌ Inference error in {self.name}: {e}", exc_info=True)
            return {
                "error": "Inference failed",
                "detail": str(e),
                "model": self.name,
                "success": False
            }

    def unload(self) -> None:
        """
        🗑️ Unloads the model to free memory.
        Critical for Azure environments with limited resources.
        """
        if self.pipeline is not None:
            logger.info(f"🗑️ Unloading {self.name}...")

            # Delete pipeline
            del self.pipeline
            self.pipeline = None

            # FIX: reset the failure latch so load_pipeline() can reload this
            # model later; previously an unloaded model could never be
            # reloaded because _load_attempted stayed True.
            self._load_attempted = False

            # Remove from cache
            if self.name in _MODEL_CACHE:
                del _MODEL_CACHE[self.name]

            # Force GPU memory release
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            logger.info(f"✅ {self.name} unloaded successfully")

            # Log memory stats after unload
            mem_stats = get_memory_stats()
            if "gpu_allocated_gb" in mem_stats:
                logger.info(f" GPU Memory: {mem_stats['gpu_allocated_gb']:.2f}GB remaining")

    def get_metadata(self) -> Dict[str, Any]:
        """
        📊 Returns model metadata and performance stats as a JSON-safe dict.
        """
        return {
            "name": self.metadata.name,
            "task": self.metadata.task,
            "model_name": self.metadata.model_name,
            "device": self.metadata.device,
            "loaded": self.pipeline is not None,
            "loaded_at": self.metadata.loaded_at.isoformat() if self.metadata.loaded_at else None,
            "load_time_seconds": self.metadata.load_time_seconds,
            "memory_usage_gb": self.metadata.memory_usage_gb,
            "inference_count": self.metadata.inference_count,
            "avg_inference_time_ms": round(self.metadata.avg_inference_time_ms, 2)
        }
612
-
613
-
614
- # ============================================================
615
- # MODEL LOADER (Singleton Manager)
616
- # ============================================================
617
-
618
class ModelLoader:
    """
    🎛️ Singleton manager for all Penny's specialized models.

    Features:
    - Centralized model configuration (model_config.json)
    - Lazy loading (models only load when needed)
    - Memory management
    - Health monitoring
    - Unified access interface
    """

    # Singleton storage; set by __new__ on first construction.
    _instance: Optional['ModelLoader'] = None

    def __new__(cls, *args, **kwargs) -> 'ModelLoader':
        """Singleton pattern - only one ModelLoader instance."""
        if cls._instance is None:
            cls._instance = super(ModelLoader, cls).__new__(cls)
        return cls._instance

    def __init__(self, config_path: Optional[str] = None) -> None:
        """
        Initialize ModelLoader (only runs once due to singleton).

        Args:
            config_path: Path to model_config.json (optional)

        NOTE(review): because of the _models_loaded guard, config_path is
        silently ignored on every construction after the first.
        """
        if not hasattr(self, '_models_loaded'):
            self.models: Dict[str, ModelClient] = {}
            self._models_loaded = True
            self._initialization_time = datetime.now()

            # Use provided path or module-level default (CONFIG_PATH)
            config_file = Path(config_path) if config_path else CONFIG_PATH

            try:
                logger.info(f"📖 Loading model configuration from {config_file}")

                # Missing config is non-fatal: the loader starts empty.
                if not config_file.exists():
                    logger.warning(f"⚠️ Configuration file not found: {config_file}")
                    logger.info(" Create model_config.json with your model definitions")
                    return

                with open(config_file, "r") as f:
                    config = json.load(f)

                # Initialize ModelClients (doesn't load models yet — lazy)
                for model_id, model_info in config.items():
                    self.models[model_id] = ModelClient(
                        name=model_id,
                        model_name=model_info["model_name"],
                        task=model_info["task"],
                        config=model_info.get("config", {})
                    )

                logger.info(f"✅ ModelLoader initialized with {len(self.models)} models:")
                for model_id in self.models.keys():
                    logger.info(f" - {model_id}")

            except json.JSONDecodeError as e:
                logger.error(f"❌ Invalid JSON in model_config.json: {e}")
            except Exception as e:
                logger.error(f"❌ Failed to initialize ModelLoader: {e}", exc_info=True)

    def get(self, model_id: str) -> Optional[ModelClient]:
        """
        🎯 Retrieves a configured ModelClient by ID.

        Args:
            model_id: Model identifier from config

        Returns:
            ModelClient instance or None if not found
        """
        return self.models.get(model_id)

    def list_models(self) -> List[str]:
        """📋 Returns list of all available model IDs."""
        return list(self.models.keys())

    def get_loaded_models(self) -> List[str]:
        """📋 Returns list of currently loaded model IDs (pipeline in memory)."""
        return [
            model_id
            for model_id, client in self.models.items()
            if client.pipeline is not None
        ]

    def unload_all(self) -> None:
        """
        🗑️ Unloads all models to free memory.
        Useful for Azure environments when switching workloads.
        """
        logger.info("🗑️ Unloading all models...")
        for model_client in self.models.values():
            model_client.unload()
        logger.info("✅ All models unloaded")

    def get_status(self) -> Dict[str, Any]:
        """
        📊 Returns comprehensive status of all models.
        Useful for health checks and monitoring.
        """
        status = {
            "initialization_time": self._initialization_time.isoformat(),
            "total_models": len(self.models),
            "loaded_models": len(self.get_loaded_models()),
            "device": get_optimal_device(),
            "memory": get_memory_stats(),
            "models": {}
        }

        for model_id, client in self.models.items():
            status["models"][model_id] = client.get_metadata()

        return status
734
-
735
-
736
- # ============================================================
737
- # PUBLIC INTERFACE (Used by all *_utils.py modules)
738
- # ============================================================
739
-
740
def load_model_pipeline(agent_name: str) -> Callable[..., Dict[str, Any]]:
    """
    🚀 Resolve a configured model and hand back a ready-to-call
    inference function.

    This is the entry point used by the *_utils.py modules
    (translation_utils.py, sentiment_utils.py, ...) to reach Penny's
    models.

    Args:
        agent_name: Model ID from model_config.json

    Returns:
        Callable that forwards (input_data, **kwargs) to the model's
        predict() method.

    Raises:
        ValueError: If agent_name is not present in the configuration.

    Example:
        >>> translator = load_model_pipeline("penny-translate-agent")
        >>> result = translator("Hello world", target_lang="spa_Latn")
    """
    loader = ModelLoader()
    client = loader.get(agent_name)

    if client is None:
        available = loader.list_models()
        raise ValueError(
            f"Agent ID '{agent_name}' not found in model configuration. "
            f"Available models: {available}"
        )

    # Trigger the lazy load now so the first real call doesn't pay it.
    client.load_pipeline()

    def inference_wrapper(input_data, **kwargs):
        return client.predict(input_data, **kwargs)

    return inference_wrapper
778
-
779
-
780
- # === CONVENIENCE FUNCTIONS ===
781
-
782
def get_model_status() -> Dict[str, Any]:
    """
    📊 Report the status of every configured model.

    Thin wrapper around the ModelLoader singleton; handy for
    health-check and monitoring endpoints.
    """
    return ModelLoader().get_status()
789
-
790
-
791
def preload_models(model_ids: Optional[List[str]] = None) -> None:
    """
    🚀 Eagerly load a set of models during startup.

    Args:
        model_ids: Model IDs to preload; None preloads every
            configured model.
    """
    loader = ModelLoader()
    model_ids = loader.list_models() if model_ids is None else model_ids

    logger.info(f"🚀 Preloading {len(model_ids)} models...")

    for model_id in model_ids:
        client = loader.get(model_id)
        # Skip ids that aren't configured instead of failing.
        if client:
            logger.info(f" Loading {model_id}...")
            client.load_pipeline()

    logger.info("✅ Model preloading complete")
812
-
813
-
814
def initialize_model_system() -> bool:
    """
    🏁 Bring up Penny's model system; call once during app startup.

    Returns:
        True if initialization completed, False on any error.
    """
    logger.info("🧠 Initializing Penny's model system...")

    try:
        # Create (or reuse) the singleton loader.
        loader = ModelLoader()

        # Surface device and memory info in the startup logs.
        device = get_optimal_device()
        mem_stats = get_memory_stats()

        logger.info(f"✅ Model system initialized")
        logger.info(f"🎮 Compute device: {device}")

        if "gpu_total_gb" in mem_stats:
            logger.info(
                f"💾 GPU Memory: {mem_stats['gpu_total_gb']:.1f}GB total"
            )

        logger.info(f"📦 {len(loader.models)} models configured")

        # Optional: preload critical models at startup, e.g.
        # preload_models(["penny-core-agent"])

        return True

    except Exception as e:
        logger.error(f"❌ Failed to initialize model system: {e}", exc_info=True)
        return False
851
-
852
-
853
# ============================================================
# CLI TESTING & DEBUGGING
# ============================================================

if __name__ == "__main__":
    # NOTE(review): the triple-quoted string below is a no-op expression,
    # not a docstring (docstrings don't attach to `if` blocks) — kept as-is.
    """
    🧪 Test script for model loading and inference.
    Run with: python -m app.model_loader
    """
    print("=" * 60)
    print("🧪 Testing Penny's Model System")
    print("=" * 60)

    # Initialize the singleton loader from model_config.json
    loader = ModelLoader()
    print(f"\n📋 Available models: {loader.list_models()}")

    # Get status (default=str handles datetime values in the status dict)
    status = get_model_status()
    print(f"\n📊 System status:")
    print(json.dumps(status, indent=2, default=str))

    # Test model loading (if models configured) — smoke-tests the first one
    if loader.models:
        test_model_id = list(loader.models.keys())[0]
        print(f"\n🧪 Testing model: {test_model_id}")

        client = loader.get(test_model_id)
        if client:
            print(f" Loading pipeline...")
            success = client.load_pipeline()

            if success:
                print(f" ✅ Model loaded successfully!")
                print(f" Metadata: {json.dumps(client.get_metadata(), indent=2, default=str)}")
            else:
                print(f" ❌ Model loading failed")