pythonprincess committed on
Commit
ef18d3e
·
verified ·
1 Parent(s): 278dad3

Delete app/model_loader.py

Browse files
Files changed (1) hide show
  1. app/model_loader.py +0 -897
app/model_loader.py DELETED
@@ -1,897 +0,0 @@
1
- # app/model_loader.py
2
- """
3
- 🧠 PENNY Model Loader - Azure-Ready Multi-Model Orchestration
4
-
5
- This is Penny's brain loader. She manages multiple specialized models:
6
- - Gemma 7B for conversational reasoning
7
- - NLLB-200 for 27-language translation
8
- - Sentiment analysis for resident wellbeing
9
- - Bias detection for equitable service
10
- - LayoutLM for civic document processing
11
-
12
- MISSION: Load AI models efficiently in memory-constrained environments while
13
- maintaining Penny's warm, civic-focused personality across all interactions.
14
-
15
- FEATURES:
16
- - Lazy loading (models only load when needed)
17
- - 8-bit quantization for memory efficiency
18
- - GPU/CPU auto-detection
19
- - Model caching and reuse
20
- - Graceful fallbacks for Azure ML deployment
21
- - Memory monitoring and cleanup
22
- """
23
-
24
- import json
25
- import os
26
- import torch
27
- from typing import Dict, Any, Callable, Optional, Union, List
28
- from pathlib import Path
29
- import logging
30
- from dataclasses import dataclass
31
- from enum import Enum
32
- from datetime import datetime
33
-
34
- # --- LOGGING SETUP (Must be before functions that use it) ---
35
- logger = logging.getLogger(__name__)
36
-
37
- # ============================================================
38
- # HUGGING FACE AUTHENTICATION
39
- # ============================================================
40
-
41
def setup_huggingface_auth() -> bool:
    """
    🔐 Authenticates with Hugging Face Hub using HF_TOKEN or READTOKEN.

    Returns:
        True when a token was found and login succeeded; False when no token
        is configured, huggingface_hub is not installed, or login raised.
    """
    # HF_TOKEN takes precedence; READTOKEN is the Hugging Face Spaces secret name.
    token = os.getenv("HF_TOKEN") or os.getenv("READTOKEN")

    # No token at all: warn loudly but keep running with public models only.
    if not token:
        logger.warning("⚠️ HF_TOKEN/READTOKEN not found in environment")
        logger.warning(" Some models may not be accessible")
        logger.warning(" Set HF_TOKEN or READTOKEN in your environment or Hugging Face Spaces secrets")
        return False

    try:
        # Lazy import so the module still loads without huggingface_hub installed.
        from huggingface_hub import login
        login(token=token, add_to_git_credential=False)
    except ImportError:
        logger.warning("⚠️ huggingface_hub not installed, skipping authentication")
        return False
    except Exception as e:
        logger.error(f"❌ Failed to authenticate with Hugging Face: {e}")
        return False
    else:
        logger.info("✅ Authenticated with Hugging Face Hub")
        return True
68
-
69
# Attempt authentication at module load
# Note: This runs when the module is imported, so HF_TOKEN must be in environment
# For Hugging Face Spaces: Set HF_TOKEN as a secret in Space settings
# For local dev: Add HF_TOKEN to .env file or export it
# Import-time side effect: logs the authentication outcome once per process.
_authentication_result = setup_huggingface_auth()
if _authentication_result:
    logger.info("🔐 Hugging Face authentication successful - gated models accessible")
else:
    logger.warning("⚠️ Hugging Face authentication failed - only public models will work")
78
-
79
# --- PATH CONFIGURATION (Environment-Aware) ---
# Support both local development and Azure ML deployment
if os.getenv("AZUREML_MODEL_DIR"):
    # Azure ML deployment - models are in AZUREML_MODEL_DIR
    MODEL_ROOT = Path(os.getenv("AZUREML_MODEL_DIR"))
    CONFIG_PATH = MODEL_ROOT / "model_config.json"
    logger.info("☁️ Running in Azure ML environment")
else:
    # Local development - models are in project structure
    # Layout assumption: this file lives in <project>/app/, so parent.parent
    # is the project root and models/ sits beside app/.
    PROJECT_ROOT = Path(__file__).parent.parent
    MODEL_ROOT = PROJECT_ROOT / "models"
    CONFIG_PATH = MODEL_ROOT / "model_config.json"
    logger.info("💻 Running in local development environment")

logger.info(f"📂 Model config path: {CONFIG_PATH}")
94
-
95
# ============================================================
# PENNY'S CIVIC IDENTITY & PERSONALITY
# ============================================================

# Prepended verbatim to every text-generation prompt by ModelClient.predict
# (unless the caller passes skip_system_prompt=True). Runtime text — do not
# reformat or translate.
PENNY_SYSTEM_PROMPT = (
    "You are Penny, a smart, civic-focused AI assistant serving local communities. "
    "You help residents navigate city services, government programs, and community resources. "
    "You're warm, professional, accurate, and always stay within your civic mission.\n\n"

    "Your expertise includes:\n"
    "- Connecting people with local services (food banks, shelters, libraries)\n"
    "- Translating information into 27 languages\n"
    "- Explaining public programs and eligibility\n"
    "- Guiding residents through civic processes\n"
    "- Providing emergency resources when needed\n\n"

    "YOUR PERSONALITY:\n"
    "- Warm and approachable, like a helpful community center staff member\n"
    "- Clear and practical, avoiding jargon\n"
    "- Culturally sensitive and inclusive\n"
    "- Patient with repetition or clarification\n"
    "- Funny when appropriate, but never at anyone's expense\n\n"

    "CRITICAL RULES:\n"
    "- When residents greet you by name (e.g., 'Hi Penny'), respond warmly and personally\n"
    "- You are ALWAYS Penny - never ChatGPT, Assistant, Claude, or any other name\n"
    "- If you don't know something, say so clearly and help find the right resource\n"
    "- NEVER make up information about services, eligibility, or contacts\n"
    "- Stay within your civic mission - you don't provide legal, medical, or financial advice\n"
    "- For emergencies, immediately connect to appropriate services (911, crisis lines)\n\n"
)

# --- GLOBAL STATE ---
# Shared across all ModelClient instances; keyed by client name.
_MODEL_CACHE: Dict[str, Any] = {}  # Memory-efficient model reuse
_LOAD_TIMES: Dict[str, float] = {}  # Track model loading performance
130
-
131
-
132
- # ============================================================
133
- # DEVICE MANAGEMENT
134
- # ============================================================
135
-
136
class DeviceType(str, Enum):
    """Compute devices Penny can run inference on (str-valued for easy compare)."""

    # NVIDIA GPU
    CUDA = "cuda"
    # Plain CPU fallback
    CPU = "cpu"
    # Apple Silicon GPU backend
    MPS = "mps"
141
-
142
-
143
def get_optimal_device() -> str:
    """
    🎮 Pick the best available device for model inference.

    Preference order:
        1. CUDA GPU (NVIDIA)
        2. MPS (Apple Silicon)
        3. CPU (fallback)

    Returns:
        Device string ("cuda", "mps", or "cpu")
    """
    # Best case: an NVIDIA GPU is visible to torch.
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
        logger.info(f"🎮 GPU detected: {gpu_name} ({gpu_memory:.1f}GB)")
        return DeviceType.CUDA.value

    # Older torch builds may lack the mps backend entirely, hence the getattr guard.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        logger.info("🍎 Apple Silicon (MPS) detected")
        return DeviceType.MPS.value

    # Last resort: CPU always works, just slowly.
    logger.info("💻 Using CPU for inference")
    logger.warning("⚠️ GPU not available - inference will be slower")
    return DeviceType.CPU.value
172
-
173
-
174
def get_memory_stats() -> Dict[str, float]:
    """
    📊 Snapshot current GPU/CPU memory usage.

    Returns:
        Dict with memory figures in GB (plus cpu_percent). GPU keys are
        present only when CUDA is available; CPU keys only when psutil is
        installed.
    """
    stats: Dict[str, float] = {}

    if torch.cuda.is_available():
        stats["gpu_allocated_gb"] = torch.cuda.memory_allocated() / 1e9
        stats["gpu_reserved_gb"] = torch.cuda.memory_reserved() / 1e9
        stats["gpu_total_gb"] = torch.cuda.get_device_properties(0).total_memory / 1e9

    # psutil is an optional dependency — silently skip CPU stats without it.
    try:
        import psutil
    except ImportError:
        return stats

    mem = psutil.virtual_memory()
    stats["cpu_used_gb"] = mem.used / 1e9
    stats["cpu_total_gb"] = mem.total / 1e9
    stats["cpu_percent"] = mem.percent

    return stats
199
-
200
-
201
- # ============================================================
202
- # MODEL CLIENT (Individual Model Handler)
203
- # ============================================================
204
-
205
@dataclass
class ModelMetadata:
    """
    📋 Bookkeeping record for one model: identity, placement, and
    cumulative load/inference performance figures.
    """
    name: str
    task: str
    model_name: str
    device: str
    loaded_at: Optional[datetime] = None
    load_time_seconds: Optional[float] = None
    memory_usage_gb: Optional[float] = None
    inference_count: int = 0
    total_inference_time_ms: float = 0.0

    @property
    def avg_inference_time_ms(self) -> float:
        """Mean inference latency in ms; 0.0 before the first inference."""
        if not self.inference_count:
            return 0.0
        return self.total_inference_time_ms / self.inference_count
227
-
228
-
229
class ModelClient:
    """
    🤖 Manages a single HuggingFace model with optimized loading and inference.

    Features:
    - Lazy loading (load on first use)
    - Memory optimization (8-bit quantization)
    - Performance tracking
    - Graceful error handling
    - Automatic device placement

    Pipelines are shared process-wide via the module-level _MODEL_CACHE,
    keyed by this client's name.
    """

    def __init__(
        self,
        name: str,
        model_name: str,
        task: str,
        device: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None
    ):
        """
        Initialize model client (doesn't load the model yet).

        Args:
            name: Model identifier (e.g., "penny-core-agent")
            model_name: HuggingFace model ID
            task: Task type (text-generation, translation, etc.)
            device: Target device (auto-detected if None)
            config: Additional model configuration
        """
        self.name = name
        self.model_name = model_name
        self.task = task
        # Auto-detect the device unless the caller pinned one explicitly.
        self.device = device or get_optimal_device()
        self.config = config or {}
        self.pipeline = None          # HuggingFace pipeline; None until load_pipeline()
        self._load_attempted = False  # guards against repeated failing loads

        self.metadata = ModelMetadata(
            name=name,
            task=task,
            model_name=model_name,
            device=self.device
        )

        logger.info(f"📦 Initialized ModelClient: {name}")
        logger.debug(f" Model: {model_name}")
        logger.debug(f" Task: {task}")
        logger.debug(f" Device: {self.device}")

    def load_pipeline(self) -> bool:
        """
        🔄 Loads the HuggingFace pipeline with Azure-optimized settings.

        Features:
        - 8-bit quantization for large models (saves ~50% memory)
        - Automatic device placement
        - Memory monitoring
        - Cache checking

        Returns:
            True if successful, False otherwise
        """
        # Fast path: already loaded on this instance.
        if self.pipeline is not None:
            logger.debug(f"✅ {self.name} already loaded")
            return True

        # A previous attempt failed — don't retry (avoids repeated slow failures).
        if self._load_attempted:
            logger.warning(f"⚠️ Previous load attempt failed for {self.name}")
            return False

        global _MODEL_CACHE, _LOAD_TIMES

        # Check cache first — reuse a pipeline loaded by another client instance.
        if self.name in _MODEL_CACHE:
            logger.info(f"♻️ Using cached pipeline for {self.name}")
            self.pipeline = _MODEL_CACHE[self.name]
            return True

        logger.info(f"🔄 Loading {self.name} from HuggingFace...")
        # Marked before the attempt, so a crash below still blocks retries.
        self._load_attempted = True

        start_time = datetime.now()

        try:
            # Import pipeline from transformers (lazy import to avoid dependency issues)
            from transformers import pipeline

            # === TEXT GENERATION (Gemma 7B, GPT-2, etc.) ===
            if self.task == "text-generation":
                logger.info(" Using 8-bit quantization for memory efficiency...")

                # 8-bit loading only makes sense on CUDA (bitsandbytes requirement).
                use_8bit = self.device == DeviceType.CUDA.value

                if use_8bit:
                    # NOTE(review): load_in_8bit as a direct pipeline kwarg is
                    # deprecated in newer transformers in favor of
                    # BitsAndBytesConfig — confirm against the pinned version.
                    self.pipeline = pipeline(
                        "text-generation",
                        model=self.model_name,
                        tokenizer=self.model_name,
                        device_map="auto",
                        load_in_8bit=True,  # Reduces ~14GB to ~7GB
                        trust_remote_code=True,
                        torch_dtype=torch.float16
                    )
                else:
                    # CPU fallback — full precision, no quantization.
                    self.pipeline = pipeline(
                        "text-generation",
                        model=self.model_name,
                        tokenizer=self.model_name,
                        device=-1,  # CPU
                        trust_remote_code=True,
                        torch_dtype=torch.float32
                    )

            # === TRANSLATION (NLLB-200, M2M-100, etc.) ===
            elif self.task == "translation":
                self.pipeline = pipeline(
                    "translation",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1,
                    # Default language pair; predict() overrides per call.
                    src_lang=self.config.get("default_src_lang", "eng_Latn"),
                    tgt_lang=self.config.get("default_tgt_lang", "spa_Latn")
                )

            # === SENTIMENT ANALYSIS ===
            elif self.task == "sentiment-analysis":
                self.pipeline = pipeline(
                    "sentiment-analysis",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1,
                    truncation=True,
                    max_length=512
                )

            # === BIAS DETECTION (Zero-Shot Classification) ===
            elif self.task == "bias-detection":
                self.pipeline = pipeline(
                    "zero-shot-classification",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1
                )

            # === TEXT CLASSIFICATION (Generic) ===
            elif self.task == "text-classification":
                self.pipeline = pipeline(
                    "text-classification",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1,
                    truncation=True
                )

            # === PDF/DOCUMENT EXTRACTION (LayoutLMv3) ===
            elif self.task == "pdf-extraction":
                logger.warning("⚠️ PDF extraction requires additional OCR setup")
                logger.info(" Consider using Azure Form Recognizer as alternative")
                # Placeholder - requires pytesseract/OCR infrastructure
                self.pipeline = None
                return False

            else:
                # Caught by the except below and reported as a load failure.
                raise ValueError(f"Unknown task type: {self.task}")

            # === SUCCESS HANDLING ===
            if self.pipeline is not None:
                # Calculate load time
                load_time = (datetime.now() - start_time).total_seconds()
                self.metadata.loaded_at = datetime.now()
                self.metadata.load_time_seconds = load_time

                # Cache the pipeline for reuse by other clients with this name.
                _MODEL_CACHE[self.name] = self.pipeline
                _LOAD_TIMES[self.name] = load_time

                # Log memory usage
                mem_stats = get_memory_stats()
                self.metadata.memory_usage_gb = mem_stats.get("gpu_allocated_gb", 0)

                logger.info(f"✅ {self.name} loaded successfully!")
                logger.info(f" Load time: {load_time:.2f}s")

                if "gpu_allocated_gb" in mem_stats:
                    logger.info(
                        f" GPU Memory: {mem_stats['gpu_allocated_gb']:.2f}GB / "
                        f"{mem_stats['gpu_total_gb']:.2f}GB"
                    )

            return True

        except Exception as e:
            logger.error(f"❌ Failed to load {self.name}: {e}", exc_info=True)
            self.pipeline = None
            return False

    def predict(
        self,
        input_data: Union[str, Dict[str, Any]],
        **kwargs
    ) -> Dict[str, Any]:
        """
        🎯 Runs inference with the loaded model pipeline.

        Features:
        - Automatic pipeline loading
        - Error handling with fallback responses
        - Performance tracking
        - Penny's personality injection (for text-generation)

        Args:
            input_data: Text or structured input for the model
            **kwargs: Task-specific parameters

        Returns:
            Model output dict with results or error information
        """
        # Track inference start time
        start_time = datetime.now()

        # Ensure pipeline is loaded (lazy loading on first predict).
        if self.pipeline is None:
            success = self.load_pipeline()
            if not success:
                # Note: this error dict has no "success" key, unlike the
                # other error paths below.
                return {
                    "error": f"{self.name} pipeline unavailable",
                    "detail": "Model failed to load. Check logs for details.",
                    "model": self.name
                }

        try:
            # === TEXT GENERATION ===
            if self.task == "text-generation":
                # Inject Penny's civic identity unless the caller opts out.
                if not kwargs.get("skip_system_prompt", False):
                    full_prompt = PENNY_SYSTEM_PROMPT + input_data
                else:
                    full_prompt = input_data

                # Extract generation parameters with safe defaults
                max_new_tokens = kwargs.get("max_new_tokens", 256)
                temperature = kwargs.get("temperature", 0.7)
                top_p = kwargs.get("top_p", 0.9)
                # Greedy decoding when temperature is 0 (sampling would divide by it).
                do_sample = kwargs.get("do_sample", temperature > 0.0)

                result = self.pipeline(
                    full_prompt,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=do_sample,
                    return_full_text=False,  # strip the prompt from the output
                    pad_token_id=self.pipeline.tokenizer.eos_token_id,
                    truncation=True
                )

                output = {
                    "generated_text": result[0]["generated_text"],
                    "model": self.name,
                    "success": True
                }

            # === TRANSLATION ===
            elif self.task == "translation":
                # Per-call language pair; falls back to English -> Spanish.
                src_lang = kwargs.get("source_lang", "eng_Latn")
                tgt_lang = kwargs.get("target_lang", "spa_Latn")

                result = self.pipeline(
                    input_data,
                    src_lang=src_lang,
                    tgt_lang=tgt_lang,
                    max_length=512
                )

                output = {
                    "translation": result[0]["translation_text"],
                    "source_lang": src_lang,
                    "target_lang": tgt_lang,
                    "model": self.name,
                    "success": True
                }

            # === SENTIMENT ANALYSIS ===
            elif self.task == "sentiment-analysis":
                result = self.pipeline(input_data)

                output = {
                    "sentiment": result[0]["label"],
                    "confidence": result[0]["score"],
                    "model": self.name,
                    "success": True
                }

            # === BIAS DETECTION ===
            elif self.task == "bias-detection":
                # Default label set for zero-shot bias screening; callers may
                # supply their own candidate_labels.
                candidate_labels = kwargs.get("candidate_labels", [
                    "neutral and objective",
                    "contains political bias",
                    "uses emotional language",
                    "culturally insensitive"
                ])

                result = self.pipeline(
                    input_data,
                    candidate_labels=candidate_labels,
                    multi_label=True
                )

                output = {
                    "labels": result["labels"],
                    "scores": result["scores"],
                    "model": self.name,
                    "success": True
                }

            # === TEXT CLASSIFICATION ===
            elif self.task == "text-classification":
                result = self.pipeline(input_data)

                output = {
                    "label": result[0]["label"],
                    "confidence": result[0]["score"],
                    "model": self.name,
                    "success": True
                }

            else:
                output = {
                    "error": f"Task '{self.task}' not implemented",
                    "model": self.name,
                    "success": False
                }

            # Track performance (cumulative, feeds avg_inference_time_ms).
            inference_time = (datetime.now() - start_time).total_seconds() * 1000
            self.metadata.inference_count += 1
            self.metadata.total_inference_time_ms += inference_time
            output["inference_time_ms"] = round(inference_time, 2)

            return output

        except Exception as e:
            logger.error(f"❌ Inference error in {self.name}: {e}", exc_info=True)
            return {
                "error": "Inference failed",
                "detail": str(e),
                "model": self.name,
                "success": False
            }

    def unload(self) -> None:
        """
        🗑️ Unloads the model to free memory.
        Critical for Azure environments with limited resources.
        No-op when the pipeline was never loaded.
        """
        if self.pipeline is not None:
            logger.info(f"🗑️ Unloading {self.name}...")

            # Delete pipeline (drop this instance's reference first)
            del self.pipeline
            self.pipeline = None

            # Remove from the shared cache so the object can be collected.
            if self.name in _MODEL_CACHE:
                del _MODEL_CACHE[self.name]

            # Force GPU memory release
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            logger.info(f"✅ {self.name} unloaded successfully")

            # Log memory stats after unload
            mem_stats = get_memory_stats()
            if "gpu_allocated_gb" in mem_stats:
                logger.info(f" GPU Memory: {mem_stats['gpu_allocated_gb']:.2f}GB remaining")

    def get_metadata(self) -> Dict[str, Any]:
        """
        📊 Returns model metadata and performance stats as a JSON-friendly dict.
        """
        return {
            "name": self.metadata.name,
            "task": self.metadata.task,
            "model_name": self.metadata.model_name,
            "device": self.metadata.device,
            "loaded": self.pipeline is not None,
            "loaded_at": self.metadata.loaded_at.isoformat() if self.metadata.loaded_at else None,
            "load_time_seconds": self.metadata.load_time_seconds,
            "memory_usage_gb": self.metadata.memory_usage_gb,
            "inference_count": self.metadata.inference_count,
            "avg_inference_time_ms": round(self.metadata.avg_inference_time_ms, 2)
        }
620
-
621
-
622
- # ============================================================
623
- # MODEL LOADER (Singleton Manager)
624
- # ============================================================
625
-
626
class ModelLoader:
    """
    🎛️ Singleton manager for all Penny's specialized models.

    Features:
    - Centralized model configuration
    - Lazy loading (models only load when needed)
    - Memory management
    - Health monitoring
    - Unified access interface

    Singleton caveat: __init__ runs its body only on first construction, so
    a config_path passed to later ModelLoader(...) calls is ignored.
    """

    _instance: Optional['ModelLoader'] = None

    def __new__(cls, *args, **kwargs):
        """Singleton pattern - only one ModelLoader instance."""
        if cls._instance is None:
            cls._instance = super(ModelLoader, cls).__new__(cls)
        return cls._instance

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize ModelLoader (only runs once due to singleton).

        Args:
            config_path: Path to model_config.json (optional)
        """
        # _models_loaded is the re-init guard: Python calls __init__ on every
        # ModelLoader() call even though __new__ returns the same instance.
        if not hasattr(self, '_models_loaded'):
            self.models: Dict[str, ModelClient] = {}
            self._models_loaded = True
            self._initialization_time = datetime.now()

            # Use provided path or default
            config_file = Path(config_path) if config_path else CONFIG_PATH

            try:
                logger.info(f"📖 Loading model configuration from {config_file}")

                if not config_file.exists():
                    # Non-fatal: loader stays usable with zero models.
                    logger.warning(f"⚠️ Configuration file not found: {config_file}")
                    logger.info(" Create model_config.json with your model definitions")
                    return

                with open(config_file, "r") as f:
                    config = json.load(f)

                # Initialize ModelClients (doesn't load models yet)
                # Expected entry shape: {"model_name": ..., "task": ..., "config": {...}}
                for model_id, model_info in config.items():
                    self.models[model_id] = ModelClient(
                        name=model_id,
                        model_name=model_info["model_name"],
                        task=model_info["task"],
                        config=model_info.get("config", {})
                    )

                logger.info(f"✅ ModelLoader initialized with {len(self.models)} models:")
                for model_id in self.models.keys():
                    logger.info(f" - {model_id}")

            except json.JSONDecodeError as e:
                logger.error(f"❌ Invalid JSON in model_config.json: {e}")
            except Exception as e:
                # Swallowed deliberately: a bad config leaves an empty loader
                # rather than crashing app startup.
                logger.error(f"❌ Failed to initialize ModelLoader: {e}", exc_info=True)

    def get(self, model_id: str) -> Optional[ModelClient]:
        """
        🎯 Retrieves a configured ModelClient by ID.

        Args:
            model_id: Model identifier from config

        Returns:
            ModelClient instance or None if not found
        """
        return self.models.get(model_id)

    def list_models(self) -> List[str]:
        """📋 Returns list of all available model IDs."""
        return list(self.models.keys())

    def get_loaded_models(self) -> List[str]:
        """📋 Returns list of currently loaded model IDs (pipeline in memory)."""
        return [
            model_id
            for model_id, client in self.models.items()
            if client.pipeline is not None
        ]

    def unload_all(self) -> None:
        """
        🗑️ Unloads all models to free memory.
        Useful for Azure environments when switching workloads.
        """
        logger.info("🗑️ Unloading all models...")
        for model_client in self.models.values():
            model_client.unload()
        logger.info("✅ All models unloaded")

    def get_status(self) -> Dict[str, Any]:
        """
        📊 Returns comprehensive status of all models.
        Useful for health checks and monitoring.

        Note: calls get_optimal_device(), which logs device info each time.
        """
        status = {
            "initialization_time": self._initialization_time.isoformat(),
            "total_models": len(self.models),
            "loaded_models": len(self.get_loaded_models()),
            "device": get_optimal_device(),
            "memory": get_memory_stats(),
            "models": {}
        }

        for model_id, client in self.models.items():
            status["models"][model_id] = client.get_metadata()

        return status
742
-
743
-
744
- # ============================================================
745
- # PUBLIC INTERFACE (Used by all *_utils.py modules)
746
- # ============================================================
747
-
748
def load_model_pipeline(agent_name: str) -> Callable[..., Dict[str, Any]]:
    """
    🚀 Loads a model client and returns its inference function.

    Main entry point for the *_utils.py modules (translation_utils.py,
    sentiment_utils.py, ...) to reach Penny's models.

    Args:
        agent_name: Model ID from model_config.json

    Returns:
        Callable inference function

    Raises:
        ValueError: If agent_name not found in configuration

    Example:
        >>> translator = load_model_pipeline("penny-translate-agent")
        >>> result = translator("Hello world", target_lang="spa_Latn")
    """
    loader = ModelLoader()
    client = loader.get(agent_name)

    if client is None:
        available = loader.list_models()
        raise ValueError(
            f"Agent ID '{agent_name}' not found in model configuration. "
            f"Available models: {available}"
        )

    # Kick off the (lazy) load now; predict() reports cleanly if it failed.
    client.load_pipeline()

    def _infer(input_data, **kwargs):
        return client.predict(input_data, **kwargs)

    return _infer
786
-
787
-
788
- # === CONVENIENCE FUNCTIONS ===
789
-
790
def get_model_status() -> Dict[str, Any]:
    """
    📊 Returns status of all configured models.
    Useful for health checks and monitoring endpoints.
    """
    return ModelLoader().get_status()
797
-
798
-
799
def preload_models(model_ids: Optional[List[str]] = None) -> None:
    """
    🚀 Preloads specified models during startup.

    Args:
        model_ids: List of model IDs to preload (None = all models)
    """
    loader = ModelLoader()

    # None means "everything that is configured".
    model_ids = loader.list_models() if model_ids is None else model_ids

    logger.info(f"🚀 Preloading {len(model_ids)} models...")

    for model_id in model_ids:
        client = loader.get(model_id)
        if client:
            logger.info(f" Loading {model_id}...")
            client.load_pipeline()

    logger.info("✅ Model preloading complete")
820
-
821
-
822
def initialize_model_system() -> bool:
    """
    🏁 Initializes the model system.
    Should be called during app startup.

    Returns:
        True if initialization successful, False if it raised.
    """
    logger.info("🧠 Initializing Penny's model system...")

    try:
        # Constructing the singleton parses model_config.json (first call only).
        loader = ModelLoader()

        device = get_optimal_device()
        mem_stats = get_memory_stats()

        logger.info(f"✅ Model system initialized")
        logger.info(f"🎮 Compute device: {device}")

        if "gpu_total_gb" in mem_stats:
            logger.info(
                f"💾 GPU Memory: {mem_stats['gpu_total_gb']:.1f}GB total"
            )

        logger.info(f"📦 {len(loader.models)} models configured")

        # To warm critical models at startup, uncomment:
        # preload_models(["penny-core-agent"])

        return True

    except Exception as e:
        logger.error(f"❌ Failed to initialize model system: {e}", exc_info=True)
        return False
859
-
860
-
861
- # ============================================================
862
- # CLI TESTING & DEBUGGING
863
- # ============================================================
864
-
865
if __name__ == "__main__":
    # 🧪 Smoke test for model loading and inference.
    # Run with: python -m app.model_loader
    banner = "=" * 60
    print(banner)
    print("🧪 Testing Penny's Model System")
    print(banner)

    # Initialize the singleton and show what's configured.
    loader = ModelLoader()
    print(f"\n📋 Available models: {loader.list_models()}")

    # Full status dump (device, memory, per-model metadata).
    status = get_model_status()
    print(f"\n📊 System status:")
    print(json.dumps(status, indent=2, default=str))

    # Try loading the first configured model, if any.
    if loader.models:
        first_id = next(iter(loader.models))
        print(f"\n🧪 Testing model: {first_id}")

        client = loader.get(first_id)
        if client:
            print(f" Loading pipeline...")
            if client.load_pipeline():
                print(f" ✅ Model loaded successfully!")
                print(f" Metadata: {json.dumps(client.get_metadata(), indent=2, default=str)}")
            else:
                print(f" ❌ Model loading failed")