pythonprincess committed on
Commit
e785234
·
verified ·
1 Parent(s): f583ae5

Delete app/model_loader.py

Browse files
Files changed (1) hide show
  1. app/model_loader.py +0 -886
app/model_loader.py DELETED
@@ -1,886 +0,0 @@
1
- # app/model_loader.py
2
- """
3
- 🧠 PENNY Model Loader - Azure-Ready Multi-Model Orchestration
4
-
5
- This is Penny's brain loader. She manages multiple specialized models:
6
- - Gemma 7B for conversational reasoning
7
- - NLLB-200 for 27-language translation
8
- - Sentiment analysis for resident wellbeing
9
- - Bias detection for equitable service
10
- - LayoutLM for civic document processing
11
-
12
- MISSION: Load AI models efficiently in memory-constrained environments while
13
- maintaining Penny's warm, civic-focused personality across all interactions.
14
-
15
- FEATURES:
16
- - Lazy loading (models only load when needed)
17
- - 8-bit quantization for memory efficiency
18
- - GPU/CPU auto-detection
19
- - Model caching and reuse
20
- - Graceful fallbacks for Azure ML deployment
21
- - Memory monitoring and cleanup
22
- """
23
-
24
- import json
25
- import os
26
- import torch
27
- from typing import Dict, Any, Callable, Optional, Union, List
28
- from pathlib import Path
29
- import logging
30
- from dataclasses import dataclass
31
- from enum import Enum
32
- from datetime import datetime
33
-
34
- # ============================================================
35
- # HUGGING FACE AUTHENTICATION
36
- # ============================================================
37
-
38
def setup_huggingface_auth() -> bool:
    """
    🔐 Authenticates with Hugging Face Hub using the HF_TOKEN env var.

    Returns:
        True if authentication succeeded; False if the token is missing,
        huggingface_hub is not installed, or login failed.
    """
    # BUG FIX: this function is invoked at module import time, *before*
    # the module-level `logger = logging.getLogger(__name__)` assignment
    # runs, so referencing the global `logger` raised NameError on every
    # call path. Resolve a logger locally instead (getLogger returns the
    # same underlying logger object, so log output is unchanged).
    log = logging.getLogger(__name__)

    hf_token = os.getenv("HF_TOKEN")

    if not hf_token:
        log.warning("⚠️ HF_TOKEN not found in environment")
        log.warning(" Some models may not be accessible")
        log.warning(" Set HF_TOKEN in your environment or Hugging Face Spaces secrets")
        return False

    try:
        # Imported lazily so the module still works without huggingface_hub.
        from huggingface_hub import login
        login(token=hf_token, add_to_git_credential=False)
        log.info("✅ Authenticated with Hugging Face Hub")
        return True
    except ImportError:
        log.warning("⚠️ huggingface_hub not installed, skipping authentication")
        return False
    except Exception as e:
        log.error(f"❌ Failed to authenticate with Hugging Face: {e}")
        return False
64
-
65
# --- LOGGING SETUP ---
# BUG FIX: the logger must exist *before* setup_huggingface_auth() runs at
# import time — the original called the function first and the function logs
# on every path, so the module raised NameError on import.
logger = logging.getLogger(__name__)

# Attempt authentication at module load
setup_huggingface_auth()

# --- PATH CONFIGURATION (Environment-Aware) ---
# Support both local development and Azure ML deployment.
# Read the env var once instead of twice (the original called os.getenv
# in the condition and again inside the branch).
_azureml_model_dir = os.getenv("AZUREML_MODEL_DIR")
if _azureml_model_dir:
    # Azure ML deployment - models are in AZUREML_MODEL_DIR
    MODEL_ROOT = Path(_azureml_model_dir)
    CONFIG_PATH = MODEL_ROOT / "model_config.json"
    logger.info("☁️ Running in Azure ML environment")
else:
    # Local development - models are in project structure
    PROJECT_ROOT = Path(__file__).parent.parent
    MODEL_ROOT = PROJECT_ROOT / "models"
    CONFIG_PATH = MODEL_ROOT / "model_config.json"
    logger.info("💻 Running in local development environment")

logger.info(f"📂 Model config path: {CONFIG_PATH}")
86
-
87
# ============================================================
# PENNY'S CIVIC IDENTITY & PERSONALITY
# ============================================================

# System prompt prepended to every text-generation request (see
# ModelClient.predict) unless the caller passes skip_system_prompt=True.
# Adjacent string literals are concatenated into one string at import time.
PENNY_SYSTEM_PROMPT = (
    "You are Penny, a smart, civic-focused AI assistant serving local communities. "
    "You help residents navigate city services, government programs, and community resources. "
    "You're warm, professional, accurate, and always stay within your civic mission.\n\n"

    "Your expertise includes:\n"
    "- Connecting people with local services (food banks, shelters, libraries)\n"
    "- Translating information into 27 languages\n"
    "- Explaining public programs and eligibility\n"
    "- Guiding residents through civic processes\n"
    "- Providing emergency resources when needed\n\n"

    "YOUR PERSONALITY:\n"
    "- Warm and approachable, like a helpful community center staff member\n"
    "- Clear and practical, avoiding jargon\n"
    "- Culturally sensitive and inclusive\n"
    "- Patient with repetition or clarification\n"
    "- Funny when appropriate, but never at anyone's expense\n\n"

    "CRITICAL RULES:\n"
    "- When residents greet you by name (e.g., 'Hi Penny'), respond warmly and personally\n"
    "- You are ALWAYS Penny - never ChatGPT, Assistant, Claude, or any other name\n"
    "- If you don't know something, say so clearly and help find the right resource\n"
    "- NEVER make up information about services, eligibility, or contacts\n"
    "- Stay within your civic mission - you don't provide legal, medical, or financial advice\n"
    "- For emergencies, immediately connect to appropriate services (911, crisis lines)\n\n"
)
118
-
119
# --- GLOBAL STATE ---
# Shared across all ModelClient instances; keyed by the client `name`.
_MODEL_CACHE: Dict[str, Any] = {}  # Memory-efficient model reuse (name -> HF pipeline)
_LOAD_TIMES: Dict[str, float] = {}  # Track model loading performance (name -> seconds)
122
-
123
-
124
# ============================================================
# DEVICE MANAGEMENT
# ============================================================

class DeviceType(str, Enum):
    """Supported compute devices (str-valued, so members compare equal to
    the plain device strings torch expects)."""
    CUDA = "cuda"  # NVIDIA GPU
    CPU = "cpu"    # universal fallback
    MPS = "mps"    # Apple Silicon
133
-
134
-
135
def get_optimal_device() -> str:
    """
    🎮 Picks the best available device for model inference.

    Preference order: CUDA GPU (NVIDIA), then MPS (Apple Silicon),
    then CPU as the universal fallback.

    Returns:
        Device string ("cuda", "mps", or "cpu")
    """
    # NVIDIA GPU wins outright when present.
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
        logger.info(f"🎮 GPU detected: {gpu_name} ({gpu_memory:.1f}GB)")
        return DeviceType.CUDA.value

    # Apple Silicon backend (guarded: older torch builds lack torch.backends.mps).
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        logger.info("🍎 Apple Silicon (MPS) detected")
        return DeviceType.MPS.value

    # No accelerator found — fall back to CPU and warn about speed.
    logger.info("💻 Using CPU for inference")
    logger.warning("⚠️ GPU not available - inference will be slower")
    return DeviceType.CPU.value
164
-
165
-
166
def get_memory_stats() -> Dict[str, float]:
    """
    📊 Collects current GPU/CPU memory statistics.

    Returns:
        Dict with memory stats in GB. GPU keys appear only when CUDA is
        available; CPU keys appear only when psutil is installed.
    """
    stats: Dict[str, float] = {}

    if torch.cuda.is_available():
        stats.update(
            gpu_allocated_gb=torch.cuda.memory_allocated() / 1e9,
            gpu_reserved_gb=torch.cuda.memory_reserved() / 1e9,
            gpu_total_gb=torch.cuda.get_device_properties(0).total_memory / 1e9,
        )

    # CPU stats need the optional psutil dependency; skip silently without it.
    try:
        import psutil
    except ImportError:
        return stats

    mem = psutil.virtual_memory()
    stats["cpu_used_gb"] = mem.used / 1e9
    stats["cpu_total_gb"] = mem.total / 1e9
    stats["cpu_percent"] = mem.percent
    return stats
191
-
192
-
193
# ============================================================
# MODEL CLIENT (Individual Model Handler)
# ============================================================

@dataclass
class ModelMetadata:
    """
    📋 Metadata about a loaded model.

    Tracks identity, load performance, and cumulative inference stats.
    """
    name: str                              # client identifier
    task: str                              # HF task string
    model_name: str                        # HuggingFace model ID
    device: str                            # "cuda" / "mps" / "cpu"
    loaded_at: Optional[datetime] = None   # set on successful load
    load_time_seconds: Optional[float] = None
    memory_usage_gb: Optional[float] = None
    inference_count: int = 0
    total_inference_time_ms: float = 0.0

    @property
    def avg_inference_time_ms(self) -> float:
        """Mean inference latency in ms; 0.0 before any inference has run."""
        count = self.inference_count
        return self.total_inference_time_ms / count if count else 0.0
219
-
220
-
221
class ModelClient:
    """
    🤖 Manages a single HuggingFace model with optimized loading and inference.

    Features:
    - Lazy loading (load on first use)
    - Memory optimization (8-bit quantization)
    - Performance tracking
    - Graceful error handling
    - Automatic device placement

    NOTE(review): load_pipeline() calls `pipeline(...)`, but `pipeline` is
    never imported in this module — loading will raise NameError until
    `from transformers import pipeline` is added at the top of the file.
    """

    def __init__(
        self,
        name: str,
        model_name: str,
        task: str,
        device: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None
    ):
        """
        Initialize model client (doesn't load the model yet).

        Args:
            name: Model identifier (e.g., "penny-core-agent")
            model_name: HuggingFace model ID
            task: Task type (text-generation, translation, etc.)
            device: Target device (auto-detected if None)
            config: Additional model configuration
        """
        self.name = name
        self.model_name = model_name
        self.task = task
        self.device = device or get_optimal_device()
        self.config = config or {}
        self.pipeline = None          # HF pipeline object; populated by load_pipeline()
        self._load_attempted = False  # guards against retrying a failed load forever
        self.metadata = ModelMetadata(
            name=name,
            task=task,
            model_name=model_name,
            device=self.device
        )

        logger.info(f"📦 Initialized ModelClient: {name}")
        logger.debug(f" Model: {model_name}")
        logger.debug(f" Task: {task}")
        logger.debug(f" Device: {self.device}")

    def load_pipeline(self) -> bool:
        """
        🔄 Loads the HuggingFace pipeline with Azure-optimized settings.

        Features:
        - 8-bit quantization for large models (saves ~50% memory)
        - Automatic device placement
        - Memory monitoring
        - Cache checking

        Returns:
            True if successful, False otherwise
        """
        # Fast path: already loaded on this client.
        if self.pipeline is not None:
            logger.debug(f"✅ {self.name} already loaded")
            return True

        # A previous load failed — don't retry (caller gets an error dict
        # from predict() instead).
        if self._load_attempted:
            logger.warning(f"⚠️ Previous load attempt failed for {self.name}")
            return False

        global _MODEL_CACHE, _LOAD_TIMES

        # Check cache first: another client with the same name may have
        # loaded the pipeline already.
        if self.name in _MODEL_CACHE:
            logger.info(f"♻️ Using cached pipeline for {self.name}")
            self.pipeline = _MODEL_CACHE[self.name]
            return True

        logger.info(f"🔄 Loading {self.name} from HuggingFace...")
        self._load_attempted = True

        start_time = datetime.now()

        try:
            # === TEXT GENERATION (Gemma 7B, GPT-2, etc.) ===
            # NOTE(review): `pipeline` below is undefined in this module; it
            # needs `from transformers import pipeline` to work at runtime.
            if self.task == "text-generation":
                logger.info(" Using 8-bit quantization for memory efficiency...")

                # Check if model supports 8-bit loading
                # (8-bit path is only taken on CUDA devices)
                use_8bit = self.device == DeviceType.CUDA.value

                if use_8bit:
                    # NOTE(review): trust_remote_code=True executes code from
                    # the model repository — confirm model sources are trusted.
                    self.pipeline = pipeline(
                        "text-generation",
                        model=self.model_name,
                        tokenizer=self.model_name,
                        device_map="auto",
                        load_in_8bit=True,  # Reduces ~14GB to ~7GB
                        trust_remote_code=True,
                        torch_dtype=torch.float16
                    )
                else:
                    # CPU fallback (full-precision float32, device=-1 is CPU
                    # in the transformers pipeline API)
                    self.pipeline = pipeline(
                        "text-generation",
                        model=self.model_name,
                        tokenizer=self.model_name,
                        device=-1,  # CPU
                        trust_remote_code=True,
                        torch_dtype=torch.float32
                    )

            # === TRANSLATION (NLLB-200, M2M-100, etc.) ===
            elif self.task == "translation":
                # Default language pair comes from per-model config;
                # predict() may override per call.
                self.pipeline = pipeline(
                    "translation",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1,
                    src_lang=self.config.get("default_src_lang", "eng_Latn"),
                    tgt_lang=self.config.get("default_tgt_lang", "spa_Latn")
                )

            # === SENTIMENT ANALYSIS ===
            elif self.task == "sentiment-analysis":
                self.pipeline = pipeline(
                    "sentiment-analysis",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1,
                    truncation=True,
                    max_length=512
                )

            # === BIAS DETECTION (Zero-Shot Classification) ===
            elif self.task == "bias-detection":
                self.pipeline = pipeline(
                    "zero-shot-classification",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1
                )

            # === TEXT CLASSIFICATION (Generic) ===
            elif self.task == "text-classification":
                self.pipeline = pipeline(
                    "text-classification",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1,
                    truncation=True
                )

            # === PDF/DOCUMENT EXTRACTION (LayoutLMv3) ===
            elif self.task == "pdf-extraction":
                # Not implemented: returns False without raising.
                logger.warning("⚠️ PDF extraction requires additional OCR setup")
                logger.info(" Consider using Azure Form Recognizer as alternative")
                # Placeholder - requires pytesseract/OCR infrastructure
                self.pipeline = None
                return False

            else:
                raise ValueError(f"Unknown task type: {self.task}")

            # === SUCCESS HANDLING ===
            if self.pipeline is not None:
                # Calculate load time
                load_time = (datetime.now() - start_time).total_seconds()
                self.metadata.loaded_at = datetime.now()
                self.metadata.load_time_seconds = load_time

                # Cache the pipeline for reuse by other clients with this name
                _MODEL_CACHE[self.name] = self.pipeline
                _LOAD_TIMES[self.name] = load_time

                # Log memory usage (0 when no GPU is present)
                mem_stats = get_memory_stats()
                self.metadata.memory_usage_gb = mem_stats.get("gpu_allocated_gb", 0)

                logger.info(f"✅ {self.name} loaded successfully!")
                logger.info(f" Load time: {load_time:.2f}s")

                if "gpu_allocated_gb" in mem_stats:
                    logger.info(
                        f" GPU Memory: {mem_stats['gpu_allocated_gb']:.2f}GB / "
                        f"{mem_stats['gpu_total_gb']:.2f}GB"
                    )

            return True

        except Exception as e:
            # Any loading failure is logged and reported as False;
            # _load_attempted stays True so we don't retry.
            logger.error(f"❌ Failed to load {self.name}: {e}", exc_info=True)
            self.pipeline = None
            return False

    def predict(
        self,
        input_data: Union[str, Dict[str, Any]],
        **kwargs
    ) -> Dict[str, Any]:
        """
        🎯 Runs inference with the loaded model pipeline.

        Features:
        - Automatic pipeline loading
        - Error handling with fallback responses
        - Performance tracking
        - Penny's personality injection (for text-generation)

        Args:
            input_data: Text or structured input for the model
            **kwargs: Task-specific parameters

        Returns:
            Model output dict with results or error information
        """
        # Track inference start time
        start_time = datetime.now()

        # Ensure pipeline is loaded (lazy loading on first predict call)
        if self.pipeline is None:
            success = self.load_pipeline()
            if not success:
                # NOTE(review): unlike the other error dicts below, this one
                # has no "success": False key — callers checking "success"
                # will KeyError or treat it inconsistently; confirm intended.
                return {
                    "error": f"{self.name} pipeline unavailable",
                    "detail": "Model failed to load. Check logs for details.",
                    "model": self.name
                }

        try:
            # === TEXT GENERATION ===
            if self.task == "text-generation":
                # Inject Penny's civic identity
                # NOTE(review): the concatenation assumes input_data is str
                # for this task — confirm callers never pass a dict here.
                if not kwargs.get("skip_system_prompt", False):
                    full_prompt = PENNY_SYSTEM_PROMPT + input_data
                else:
                    full_prompt = input_data

                # Extract generation parameters with safe defaults
                max_new_tokens = kwargs.get("max_new_tokens", 256)
                temperature = kwargs.get("temperature", 0.7)
                top_p = kwargs.get("top_p", 0.9)
                # sampling on whenever temperature > 0, unless overridden
                do_sample = kwargs.get("do_sample", temperature > 0.0)

                result = self.pipeline(
                    full_prompt,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=do_sample,
                    return_full_text=False,
                    pad_token_id=self.pipeline.tokenizer.eos_token_id,
                    truncation=True
                )

                output = {
                    "generated_text": result[0]["generated_text"],
                    "model": self.name,
                    "success": True
                }

            # === TRANSLATION ===
            elif self.task == "translation":
                # Per-call language override; defaults English -> Spanish
                src_lang = kwargs.get("source_lang", "eng_Latn")
                tgt_lang = kwargs.get("target_lang", "spa_Latn")

                result = self.pipeline(
                    input_data,
                    src_lang=src_lang,
                    tgt_lang=tgt_lang,
                    max_length=512
                )

                output = {
                    "translation": result[0]["translation_text"],
                    "source_lang": src_lang,
                    "target_lang": tgt_lang,
                    "model": self.name,
                    "success": True
                }

            # === SENTIMENT ANALYSIS ===
            elif self.task == "sentiment-analysis":
                result = self.pipeline(input_data)

                output = {
                    "sentiment": result[0]["label"],
                    "confidence": result[0]["score"],
                    "model": self.name,
                    "success": True
                }

            # === BIAS DETECTION ===
            elif self.task == "bias-detection":
                # Zero-shot labels; caller may supply their own set
                candidate_labels = kwargs.get("candidate_labels", [
                    "neutral and objective",
                    "contains political bias",
                    "uses emotional language",
                    "culturally insensitive"
                ])

                result = self.pipeline(
                    input_data,
                    candidate_labels=candidate_labels,
                    multi_label=True
                )

                output = {
                    "labels": result["labels"],
                    "scores": result["scores"],
                    "model": self.name,
                    "success": True
                }

            # === TEXT CLASSIFICATION ===
            elif self.task == "text-classification":
                result = self.pipeline(input_data)

                output = {
                    "label": result[0]["label"],
                    "confidence": result[0]["score"],
                    "model": self.name,
                    "success": True
                }

            else:
                output = {
                    "error": f"Task '{self.task}' not implemented",
                    "model": self.name,
                    "success": False
                }

            # Track performance (cumulative stats live on self.metadata)
            inference_time = (datetime.now() - start_time).total_seconds() * 1000
            self.metadata.inference_count += 1
            self.metadata.total_inference_time_ms += inference_time
            output["inference_time_ms"] = round(inference_time, 2)

            return output

        except Exception as e:
            logger.error(f"❌ Inference error in {self.name}: {e}", exc_info=True)
            return {
                "error": "Inference failed",
                "detail": str(e),
                "model": self.name,
                "success": False
            }

    def unload(self) -> None:
        """
        🗑️ Unloads the model to free memory.
        Critical for Azure environments with limited resources.
        No-op when the pipeline is not loaded.
        """
        if self.pipeline is not None:
            logger.info(f"🗑️ Unloading {self.name}...")

            # Delete pipeline (drop our reference so it can be collected)
            del self.pipeline
            self.pipeline = None

            # Remove from cache (the cache also holds a reference)
            if self.name in _MODEL_CACHE:
                del _MODEL_CACHE[self.name]

            # Force GPU memory release
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            logger.info(f"✅ {self.name} unloaded successfully")

            # Log memory stats after unload
            mem_stats = get_memory_stats()
            if "gpu_allocated_gb" in mem_stats:
                logger.info(f" GPU Memory: {mem_stats['gpu_allocated_gb']:.2f}GB remaining")

    def get_metadata(self) -> Dict[str, Any]:
        """
        📊 Returns model metadata and performance stats as a
        JSON-serializable dict (datetimes are ISO strings).
        """
        return {
            "name": self.metadata.name,
            "task": self.metadata.task,
            "model_name": self.metadata.model_name,
            "device": self.metadata.device,
            "loaded": self.pipeline is not None,
            "loaded_at": self.metadata.loaded_at.isoformat() if self.metadata.loaded_at else None,
            "load_time_seconds": self.metadata.load_time_seconds,
            "memory_usage_gb": self.metadata.memory_usage_gb,
            "inference_count": self.metadata.inference_count,
            "avg_inference_time_ms": round(self.metadata.avg_inference_time_ms, 2)
        }
609
-
610
-
611
# ============================================================
# MODEL LOADER (Singleton Manager)
# ============================================================

class ModelLoader:
    """
    🎛️ Singleton manager for every specialized model Penny serves.

    Responsibilities:
    - Centralized model configuration (model_config.json)
    - Lazy loading — models are only loaded when first used
    - Memory management (bulk unload)
    - Health monitoring / status reporting
    - Unified access interface for the *_utils modules
    """

    _instance: Optional['ModelLoader'] = None

    def __new__(cls, *args, **kwargs):
        """Singleton pattern — every construction yields the same instance."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, config_path: Optional[str] = None):
        """
        One-time initialization (repeat constructions are no-ops).

        Args:
            config_path: Path to model_config.json (optional)
        """
        if hasattr(self, '_models_loaded'):
            return  # singleton already initialized — nothing to do

        self.models: Dict[str, ModelClient] = {}
        self._models_loaded = True
        self._initialization_time = datetime.now()

        # Use the caller-supplied path when given, else the module default.
        config_file = Path(config_path) if config_path else CONFIG_PATH

        try:
            logger.info(f"📖 Loading model configuration from {config_file}")

            if not config_file.exists():
                logger.warning(f"⚠️ Configuration file not found: {config_file}")
                logger.info(" Create model_config.json with your model definitions")
                return

            with open(config_file, "r") as f:
                config = json.load(f)

            # Register one (not-yet-loaded) client per configured model.
            for model_id, model_info in config.items():
                self.models[model_id] = ModelClient(
                    name=model_id,
                    model_name=model_info["model_name"],
                    task=model_info["task"],
                    config=model_info.get("config", {})
                )

            logger.info(f"✅ ModelLoader initialized with {len(self.models)} models:")
            for model_id in self.models:
                logger.info(f" - {model_id}")

        except json.JSONDecodeError as e:
            logger.error(f"❌ Invalid JSON in model_config.json: {e}")
        except Exception as e:
            logger.error(f"❌ Failed to initialize ModelLoader: {e}", exc_info=True)

    def get(self, model_id: str) -> Optional[ModelClient]:
        """
        🎯 Looks up a configured ModelClient by ID.

        Args:
            model_id: Model identifier from config

        Returns:
            ModelClient instance, or None when the ID is unknown
        """
        return self.models.get(model_id)

    def list_models(self) -> List[str]:
        """📋 All configured model IDs."""
        return list(self.models)

    def get_loaded_models(self) -> List[str]:
        """📋 IDs of models whose pipelines are currently in memory."""
        loaded = []
        for model_id, client in self.models.items():
            if client.pipeline is not None:
                loaded.append(model_id)
        return loaded

    def unload_all(self) -> None:
        """
        🗑️ Drops every loaded pipeline to reclaim memory.
        Useful for Azure environments when switching workloads.
        """
        logger.info("🗑️ Unloading all models...")
        for model_client in self.models.values():
            model_client.unload()
        logger.info("✅ All models unloaded")

    def get_status(self) -> Dict[str, Any]:
        """
        📊 Comprehensive status snapshot of all models.
        Useful for health checks and monitoring.
        """
        return {
            "initialization_time": self._initialization_time.isoformat(),
            "total_models": len(self.models),
            "loaded_models": len(self.get_loaded_models()),
            "device": get_optimal_device(),
            "memory": get_memory_stats(),
            "models": {
                model_id: client.get_metadata()
                for model_id, client in self.models.items()
            },
        }
731
-
732
-
733
# ============================================================
# PUBLIC INTERFACE (Used by all *_utils.py modules)
# ============================================================

def load_model_pipeline(agent_name: str) -> Callable[..., Dict[str, Any]]:
    """
    🚀 Resolves a configured model client and returns its inference callable.

    This is the main entry point used by the other modules
    (translation_utils.py, sentiment_utils.py, etc.) to access
    Penny's models.

    Args:
        agent_name: Model ID from model_config.json

    Returns:
        Callable inference function

    Raises:
        ValueError: If agent_name not found in configuration

    Example:
        >>> translator = load_model_pipeline("penny-translate-agent")
        >>> result = translator("Hello world", target_lang="spa_Latn")
    """
    loader = ModelLoader()
    client = loader.get(agent_name)

    if client is None:
        available = loader.list_models()
        raise ValueError(
            f"Agent ID '{agent_name}' not found in model configuration. "
            f"Available models: {available}"
        )

    # Kick off the (lazy) pipeline load up front; even on failure the
    # wrapper below still works — predict() reports the error per call.
    client.load_pipeline()

    def inference_wrapper(input_data, **kwargs):
        return client.predict(input_data, **kwargs)

    return inference_wrapper
775
-
776
-
777
- # === CONVENIENCE FUNCTIONS ===
778
-
779
- def get_model_status() -> Dict[str, Any]:
780
- """
781
- 📊 Returns status of all configured models.
782
- Useful for health checks and monitoring endpoints.
783
- """
784
- loader = ModelLoader()
785
- return loader.get_status()
786
-
787
-
788
def preload_models(model_ids: Optional[List[str]] = None) -> None:
    """
    🚀 Eagerly loads the given models during startup.

    Args:
        model_ids: List of model IDs to preload (None = all models)
    """
    loader = ModelLoader()
    targets = loader.list_models() if model_ids is None else model_ids

    logger.info(f"🚀 Preloading {len(targets)} models...")

    for model_id in targets:
        client = loader.get(model_id)
        if client is None:
            continue  # skip IDs not present in the configuration
        logger.info(f" Loading {model_id}...")
        client.load_pipeline()

    logger.info("✅ Model preloading complete")
809
-
810
-
811
def initialize_model_system() -> bool:
    """
    🏁 Initializes the model system; call this during app startup.

    Returns:
        True if initialization successful
    """
    logger.info("🧠 Initializing Penny's model system...")

    try:
        # Create (or fetch) the singleton loader.
        loader = ModelLoader()

        # Report compute environment.
        device = get_optimal_device()
        mem_stats = get_memory_stats()

        logger.info(f"✅ Model system initialized")
        logger.info(f"🎮 Compute device: {device}")

        if "gpu_total_gb" in mem_stats:
            logger.info(f"💾 GPU Memory: {mem_stats['gpu_total_gb']:.1f}GB total")

        logger.info(f"📦 {len(loader.models)} models configured")

        # Optional: preload critical models at startup, e.g.
        # preload_models(["penny-core-agent"])

        return True

    except Exception as e:
        logger.error(f"❌ Failed to initialize model system: {e}", exc_info=True)
        return False
848
-
849
-
850
# ============================================================
# CLI TESTING & DEBUGGING
# ============================================================

if __name__ == "__main__":
    # 🧪 Smoke test for model loading and inference.
    # Run with: python -m app.model_loader
    print("=" * 60)
    print("🧪 Testing Penny's Model System")
    print("=" * 60)

    # Initialize the singleton and list what's configured.
    loader = ModelLoader()
    print(f"\n📋 Available models: {loader.list_models()}")

    # Dump the full system status.
    status = get_model_status()
    print(f"\n📊 System status:")
    print(json.dumps(status, indent=2, default=str))

    # Exercise a real load with the first configured model, if any.
    if loader.models:
        test_model_id = next(iter(loader.models))
        print(f"\n🧪 Testing model: {test_model_id}")

        client = loader.get(test_model_id)
        if client:
            print(f" Loading pipeline...")
            if client.load_pipeline():
                print(f" ✅ Model loaded successfully!")
                print(f" Metadata: {json.dumps(client.get_metadata(), indent=2, default=str)}")
            else:
                print(f" ❌ Model loading failed")