pythonprincess committed on
Commit
1bfa24a
·
verified ·
1 Parent(s): 08067f2

Delete model_loader.py

Browse files
Files changed (1) hide show
  1. model_loader.py +0 -861
model_loader.py DELETED
@@ -1,861 +0,0 @@
1
- # app/model_loader.py
2
- """
3
- 🧠 PENNY Model Loader - Azure-Ready Multi-Model Orchestration
4
-
5
- This is Penny's brain loader. She manages multiple specialized models:
6
- - Gemma 7B for conversational reasoning
7
- - NLLB-200 for 27-language translation
8
- - Sentiment analysis for resident wellbeing
9
- - Bias detection for equitable service
10
- - LayoutLM for civic document processing
11
-
12
- MISSION: Load AI models efficiently in memory-constrained environments while
13
- maintaining Penny's warm, civic-focused personality across all interactions.
14
-
15
- FEATURES:
16
- - Lazy loading (models only load when needed)
17
- - 8-bit quantization for memory efficiency
18
- - GPU/CPU auto-detection
19
- - Model caching and reuse
20
- - Graceful fallbacks for Azure ML deployment
21
- - Memory monitoring and cleanup
22
- """
23
-
24
- import json
25
- import os
26
- import torch
27
- from typing import Dict, Any, Callable, Optional, Union, List
28
- from pathlib import Path
29
- import logging
30
- from dataclasses import dataclass
31
- from enum import Enum
32
- from datetime import datetime
33
-
34
- from transformers import (
35
- AutoTokenizer,
36
- AutoModelForCausalLM,
37
- AutoModelForSeq2SeqLM,
38
- pipeline,
39
- PreTrainedModel,
40
- PreTrainedTokenizer
41
- )
42
-
43
# --- LOGGING SETUP ---
logger = logging.getLogger(__name__)

# --- PATH CONFIGURATION (Environment-Aware) ---
# Supports both local development and Azure ML deployment. In Azure ML the
# registered model artifacts live under AZUREML_MODEL_DIR; locally they sit
# under <project>/models.
if _azureml_model_dir := os.getenv("AZUREML_MODEL_DIR"):
    # Azure ML deployment - models are in AZUREML_MODEL_DIR.
    # Walrus avoids the double getenv() and the Optional[str] -> Path hole.
    MODEL_ROOT = Path(_azureml_model_dir)
    logger.info("☁️ Running in Azure ML environment")
else:
    # Local development - models are in the project structure
    PROJECT_ROOT = Path(__file__).parent.parent
    MODEL_ROOT = PROJECT_ROOT / "models"
    logger.info("💻 Running in local development environment")

# Single source of truth for the model registry file (was duplicated in
# both branches above).
CONFIG_PATH = MODEL_ROOT / "model_config.json"

logger.info(f"📂 Model config path: {CONFIG_PATH}")
61
-
62
- # ============================================================
63
- # PENNY'S CIVIC IDENTITY & PERSONALITY
64
- # ============================================================
65
-
66
# System prompt prepended to every text-generation request (predict() skips it
# only when the caller passes skip_system_prompt=True). Adjacent string
# literals inside the parentheses are concatenated into one string.
PENNY_SYSTEM_PROMPT = (
    "You are Penny, a smart, civic-focused AI assistant serving local communities. "
    "You help residents navigate city services, government programs, and community resources. "
    "You're warm, professional, accurate, and always stay within your civic mission.\n\n"

    "Your expertise includes:\n"
    "- Connecting people with local services (food banks, shelters, libraries)\n"
    "- Translating information into 27 languages\n"
    "- Explaining public programs and eligibility\n"
    "- Guiding residents through civic processes\n"
    "- Providing emergency resources when needed\n\n"

    "YOUR PERSONALITY:\n"
    "- Warm and approachable, like a helpful community center staff member\n"
    "- Clear and practical, avoiding jargon\n"
    "- Culturally sensitive and inclusive\n"
    "- Patient with repetition or clarification\n"
    "- Funny when appropriate, but never at anyone's expense\n\n"

    "CRITICAL RULES:\n"
    "- When residents greet you by name (e.g., 'Hi Penny'), respond warmly and personally\n"
    "- You are ALWAYS Penny - never ChatGPT, Assistant, Claude, or any other name\n"
    "- If you don't know something, say so clearly and help find the right resource\n"
    "- NEVER make up information about services, eligibility, or contacts\n"
    "- Stay within your civic mission - you don't provide legal, medical, or financial advice\n"
    "- For emergencies, immediately connect to appropriate services (911, crisis lines)\n\n"
)

# --- GLOBAL STATE ---
# Process-wide pipeline cache shared by all ModelClient instances; entries are
# added in ModelClient.load_pipeline() and removed in ModelClient.unload().
_MODEL_CACHE: Dict[str, Any] = {} # Memory-efficient model reuse
_LOAD_TIMES: Dict[str, float] = {} # Track model loading performance (seconds)
97
-
98
-
99
- # ============================================================
100
- # DEVICE MANAGEMENT
101
- # ============================================================
102
-
103
class DeviceType(str, Enum):
    """Compute devices Penny's models can run on (str-valued for pipeline args)."""

    CUDA = "cuda"  # NVIDIA GPU
    CPU = "cpu"    # Universal fallback
    MPS = "mps"    # Apple Silicon
108
-
109
-
110
def get_optimal_device() -> str:
    """
    🎮 Picks the best available device for model inference.

    Preference order: CUDA GPU (NVIDIA), then MPS (Apple Silicon), then CPU.

    Returns:
        Device string ("cuda", "mps", or "cpu")
    """
    # Guard-clause style: return as soon as the best option is found.
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
        logger.info(f"🎮 GPU detected: {gpu_name} ({gpu_memory:.1f}GB)")
        return DeviceType.CUDA.value

    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        logger.info("🍎 Apple Silicon (MPS) detected")
        return DeviceType.MPS.value

    # No accelerator found - CPU fallback.
    logger.info("💻 Using CPU for inference")
    logger.warning("⚠️ GPU not available - inference will be slower")
    return DeviceType.CPU.value
139
-
140
-
141
def get_memory_stats() -> Dict[str, float]:
    """
    📊 Returns current GPU/CPU memory statistics.

    Returns:
        Dict with memory stats in GB (GPU keys only when CUDA is available;
        CPU keys only when psutil is installed).
    """
    snapshot: Dict[str, float] = {}

    if torch.cuda.is_available():
        snapshot.update(
            gpu_allocated_gb=torch.cuda.memory_allocated() / 1e9,
            gpu_reserved_gb=torch.cuda.memory_reserved() / 1e9,
            gpu_total_gb=torch.cuda.get_device_properties(0).total_memory / 1e9,
        )

    # CPU memory is best-effort: psutil is an optional dependency.
    try:
        import psutil
    except ImportError:
        return snapshot

    vm = psutil.virtual_memory()
    snapshot["cpu_used_gb"] = vm.used / 1e9
    snapshot["cpu_total_gb"] = vm.total / 1e9
    snapshot["cpu_percent"] = vm.percent
    return snapshot
166
-
167
-
168
- # ============================================================
169
- # MODEL CLIENT (Individual Model Handler)
170
- # ============================================================
171
-
172
@dataclass
class ModelMetadata:
    """
    📋 Bookkeeping record for a loaded model.

    Tracks identity, placement, load cost and cumulative inference stats.
    """
    name: str                                   # model identifier (cache key)
    task: str                                   # pipeline task type
    model_name: str                             # HuggingFace model ID
    device: str                                 # "cuda" / "mps" / "cpu"
    loaded_at: Optional[datetime] = None        # when the pipeline finished loading
    load_time_seconds: Optional[float] = None   # wall-clock load duration
    memory_usage_gb: Optional[float] = None     # GPU allocation after load
    inference_count: int = 0                    # number of predict() calls
    total_inference_time_ms: float = 0.0        # cumulative predict() time

    @property
    def avg_inference_time_ms(self) -> float:
        """Mean wall-clock time per inference (0.0 before the first call)."""
        count = self.inference_count
        return self.total_inference_time_ms / count if count else 0.0
194
-
195
-
196
class ModelClient:
    """
    🤖 Manages a single HuggingFace model with optimized loading and inference.

    Features:
    - Lazy loading (load on first use)
    - Memory optimization (8-bit quantization)
    - Performance tracking
    - Graceful error handling
    - Automatic device placement
    """

    def __init__(
        self,
        name: str,
        model_name: str,
        task: str,
        device: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None
    ):
        """
        Initialize model client (doesn't load the model yet).

        Args:
            name: Model identifier (e.g., "penny-core-agent")
            model_name: HuggingFace model ID
            task: Task type (text-generation, translation, etc.)
            device: Target device (auto-detected if None)
            config: Additional model configuration
        """
        self.name = name
        self.model_name = model_name
        self.task = task
        self.device = device or get_optimal_device()
        self.config = config or {}
        self.pipeline = None
        # One-shot guard: a failed load is never retried for this instance.
        self._load_attempted = False
        self.metadata = ModelMetadata(
            name=name,
            task=task,
            model_name=model_name,
            device=self.device
        )

        logger.info(f"📦 Initialized ModelClient: {name}")
        logger.debug(f" Model: {model_name}")
        logger.debug(f" Task: {task}")
        logger.debug(f" Device: {self.device}")

    def load_pipeline(self) -> bool:
        """
        🔄 Loads the HuggingFace pipeline with Azure-optimized settings.

        Features:
        - 8-bit quantization for large models (saves ~50% memory)
        - Automatic device placement
        - Memory monitoring
        - Cache checking

        Returns:
            True if successful, False otherwise
        """
        if self.pipeline is not None:
            logger.debug(f"✅ {self.name} already loaded")
            return True

        # A previous failure short-circuits: no retry storms on a bad model ID.
        if self._load_attempted:
            logger.warning(f"⚠️ Previous load attempt failed for {self.name}")
            return False

        global _MODEL_CACHE, _LOAD_TIMES

        # Check the process-wide cache before hitting HuggingFace.
        if self.name in _MODEL_CACHE:
            logger.info(f"♻️ Using cached pipeline for {self.name}")
            self.pipeline = _MODEL_CACHE[self.name]
            return True

        logger.info(f"🔄 Loading {self.name} from HuggingFace...")
        self._load_attempted = True

        start_time = datetime.now()

        try:
            # === TEXT GENERATION (Gemma 7B, GPT-2, etc.) ===
            if self.task == "text-generation":
                logger.info(" Using 8-bit quantization for memory efficiency...")

                # 8-bit loading only makes sense on CUDA (bitsandbytes requirement).
                use_8bit = self.device == DeviceType.CUDA.value

                if use_8bit:
                    # NOTE(review): load_in_8bit as a direct pipeline kwarg is
                    # deprecated in newer transformers in favor of
                    # quantization_config=BitsAndBytesConfig(...) — confirm the
                    # pinned transformers version supports this.
                    self.pipeline = pipeline(
                        "text-generation",
                        model=self.model_name,
                        tokenizer=self.model_name,
                        device_map="auto",
                        load_in_8bit=True, # Reduces ~14GB to ~7GB
                        trust_remote_code=True,
                        torch_dtype=torch.float16
                    )
                else:
                    # CPU fallback: full fp32 weights, no quantization.
                    self.pipeline = pipeline(
                        "text-generation",
                        model=self.model_name,
                        tokenizer=self.model_name,
                        device=-1, # CPU
                        trust_remote_code=True,
                        torch_dtype=torch.float32
                    )

            # === TRANSLATION (NLLB-200, M2M-100, etc.) ===
            elif self.task == "translation":
                self.pipeline = pipeline(
                    "translation",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1,
                    # FLORES-200 language codes; overridable via model config.
                    src_lang=self.config.get("default_src_lang", "eng_Latn"),
                    tgt_lang=self.config.get("default_tgt_lang", "spa_Latn")
                )

            # === SENTIMENT ANALYSIS ===
            elif self.task == "sentiment-analysis":
                self.pipeline = pipeline(
                    "sentiment-analysis",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1,
                    truncation=True,
                    max_length=512
                )

            # === BIAS DETECTION (Zero-Shot Classification) ===
            elif self.task == "bias-detection":
                self.pipeline = pipeline(
                    "zero-shot-classification",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1
                )

            # === TEXT CLASSIFICATION (Generic) ===
            elif self.task == "text-classification":
                self.pipeline = pipeline(
                    "text-classification",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1,
                    truncation=True
                )

            # === PDF/DOCUMENT EXTRACTION (LayoutLMv3) ===
            elif self.task == "pdf-extraction":
                logger.warning("⚠️ PDF extraction requires additional OCR setup")
                logger.info(" Consider using Azure Form Recognizer as alternative")
                # Placeholder - requires pytesseract/OCR infrastructure
                self.pipeline = None
                return False

            else:
                raise ValueError(f"Unknown task type: {self.task}")

            # === SUCCESS HANDLING ===
            if self.pipeline is not None:
                # Record load time on the metadata for monitoring endpoints.
                load_time = (datetime.now() - start_time).total_seconds()
                self.metadata.loaded_at = datetime.now()
                self.metadata.load_time_seconds = load_time

                # Cache the pipeline so other clients with this name reuse it.
                _MODEL_CACHE[self.name] = self.pipeline
                _LOAD_TIMES[self.name] = load_time

                # Snapshot memory usage (0 when no GPU stats available).
                mem_stats = get_memory_stats()
                self.metadata.memory_usage_gb = mem_stats.get("gpu_allocated_gb", 0)

                logger.info(f"✅ {self.name} loaded successfully!")
                logger.info(f" Load time: {load_time:.2f}s")

                if "gpu_allocated_gb" in mem_stats:
                    logger.info(
                        f" GPU Memory: {mem_stats['gpu_allocated_gb']:.2f}GB / "
                        f"{mem_stats['gpu_total_gb']:.2f}GB"
                    )

            return True

        except Exception as e:
            logger.error(f"❌ Failed to load {self.name}: {e}", exc_info=True)
            self.pipeline = None
            return False

    def predict(
        self,
        input_data: Union[str, Dict[str, Any]],
        **kwargs
    ) -> Dict[str, Any]:
        """
        🎯 Runs inference with the loaded model pipeline.

        Features:
        - Automatic pipeline loading
        - Error handling with fallback responses
        - Performance tracking
        - Penny's personality injection (for text-generation)

        Args:
            input_data: Text or structured input for the model
                (NOTE(review): text-generation concatenates it to the system
                prompt, so it must be a str for that task — confirm callers)
            **kwargs: Task-specific parameters

        Returns:
            Model output dict with results or error information
            (never raises; failures come back as an "error" dict)
        """
        # Track inference start time
        start_time = datetime.now()

        # Ensure pipeline is loaded (lazy loading path)
        if self.pipeline is None:
            success = self.load_pipeline()
            if not success:
                return {
                    "error": f"{self.name} pipeline unavailable",
                    "detail": "Model failed to load. Check logs for details.",
                    "model": self.name
                }

        try:
            # === TEXT GENERATION ===
            if self.task == "text-generation":
                # Inject Penny's civic identity unless explicitly skipped.
                if not kwargs.get("skip_system_prompt", False):
                    full_prompt = PENNY_SYSTEM_PROMPT + input_data
                else:
                    full_prompt = input_data

                # Extract generation parameters with safe defaults.
                max_new_tokens = kwargs.get("max_new_tokens", 256)
                temperature = kwargs.get("temperature", 0.7)
                top_p = kwargs.get("top_p", 0.9)
                # Greedy decoding when temperature is 0 (sampling otherwise).
                do_sample = kwargs.get("do_sample", temperature > 0.0)

                result = self.pipeline(
                    full_prompt,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=do_sample,
                    return_full_text=False,
                    pad_token_id=self.pipeline.tokenizer.eos_token_id,
                    truncation=True
                )

                output = {
                    "generated_text": result[0]["generated_text"],
                    "model": self.name,
                    "success": True
                }

            # === TRANSLATION ===
            elif self.task == "translation":
                src_lang = kwargs.get("source_lang", "eng_Latn")
                tgt_lang = kwargs.get("target_lang", "spa_Latn")

                result = self.pipeline(
                    input_data,
                    src_lang=src_lang,
                    tgt_lang=tgt_lang,
                    max_length=512
                )

                output = {
                    "translation": result[0]["translation_text"],
                    "source_lang": src_lang,
                    "target_lang": tgt_lang,
                    "model": self.name,
                    "success": True
                }

            # === SENTIMENT ANALYSIS ===
            elif self.task == "sentiment-analysis":
                result = self.pipeline(input_data)

                output = {
                    "sentiment": result[0]["label"],
                    "confidence": result[0]["score"],
                    "model": self.name,
                    "success": True
                }

            # === BIAS DETECTION ===
            elif self.task == "bias-detection":
                # Default label set for civic-content screening; callers may
                # override with their own candidate_labels.
                candidate_labels = kwargs.get("candidate_labels", [
                    "neutral and objective",
                    "contains political bias",
                    "uses emotional language",
                    "culturally insensitive"
                ])

                result = self.pipeline(
                    input_data,
                    candidate_labels=candidate_labels,
                    multi_label=True
                )

                output = {
                    "labels": result["labels"],
                    "scores": result["scores"],
                    "model": self.name,
                    "success": True
                }

            # === TEXT CLASSIFICATION ===
            elif self.task == "text-classification":
                result = self.pipeline(input_data)

                output = {
                    "label": result[0]["label"],
                    "confidence": result[0]["score"],
                    "model": self.name,
                    "success": True
                }

            else:
                output = {
                    "error": f"Task '{self.task}' not implemented",
                    "model": self.name,
                    "success": False
                }

            # Track performance (cumulative stats feed avg_inference_time_ms).
            inference_time = (datetime.now() - start_time).total_seconds() * 1000
            self.metadata.inference_count += 1
            self.metadata.total_inference_time_ms += inference_time
            output["inference_time_ms"] = round(inference_time, 2)

            return output

        except Exception as e:
            logger.error(f"❌ Inference error in {self.name}: {e}", exc_info=True)
            return {
                "error": "Inference failed",
                "detail": str(e),
                "model": self.name,
                "success": False
            }

    def unload(self) -> None:
        """
        🗑️ Unloads the model to free memory.
        Critical for Azure environments with limited resources.
        """
        if self.pipeline is not None:
            logger.info(f"🗑️ Unloading {self.name}...")

            # Drop our reference first so the cache entry is the last holder.
            del self.pipeline
            self.pipeline = None

            # Remove from the process-wide cache.
            if self.name in _MODEL_CACHE:
                del _MODEL_CACHE[self.name]

            # Force GPU memory release back to the allocator pool.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            logger.info(f"✅ {self.name} unloaded successfully")

            # Log memory stats after unload
            mem_stats = get_memory_stats()
            if "gpu_allocated_gb" in mem_stats:
                logger.info(f" GPU Memory: {mem_stats['gpu_allocated_gb']:.2f}GB remaining")

    def get_metadata(self) -> Dict[str, Any]:
        """
        📊 Returns model metadata and performance stats as a JSON-safe dict.
        """
        return {
            "name": self.metadata.name,
            "task": self.metadata.task,
            "model_name": self.metadata.model_name,
            "device": self.metadata.device,
            "loaded": self.pipeline is not None,
            "loaded_at": self.metadata.loaded_at.isoformat() if self.metadata.loaded_at else None,
            "load_time_seconds": self.metadata.load_time_seconds,
            "memory_usage_gb": self.metadata.memory_usage_gb,
            "inference_count": self.metadata.inference_count,
            "avg_inference_time_ms": round(self.metadata.avg_inference_time_ms, 2)
        }
584
-
585
-
586
- # ============================================================
587
- # MODEL LOADER (Singleton Manager)
588
- # ============================================================
589
-
590
class ModelLoader:
    """
    🎛️ Singleton manager for all Penny's specialized models.

    Features:
    - Centralized model configuration
    - Lazy loading (models only load when needed)
    - Memory management
    - Health monitoring
    - Unified access interface
    """

    # Singleton storage - all ModelLoader() calls return this instance.
    _instance: Optional['ModelLoader'] = None

    def __new__(cls, *args, **kwargs):
        """Singleton pattern - only one ModelLoader instance."""
        if cls._instance is None:
            cls._instance = super(ModelLoader, cls).__new__(cls)
        return cls._instance

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize ModelLoader (only runs once due to singleton).

        Args:
            config_path: Path to model_config.json (optional)

        NOTE(review): because of the singleton + `_models_loaded` guard,
        config_path is honored only on the very first construction; later
        calls silently ignore it — confirm callers expect this.
        """
        if not hasattr(self, '_models_loaded'):
            self.models: Dict[str, ModelClient] = {}
            self._models_loaded = True
            self._initialization_time = datetime.now()

            # Use provided path or module-level default.
            config_file = Path(config_path) if config_path else CONFIG_PATH

            try:
                logger.info(f"📖 Loading model configuration from {config_file}")

                # Missing config is non-fatal: loader starts with zero models.
                if not config_file.exists():
                    logger.warning(f"⚠️ Configuration file not found: {config_file}")
                    logger.info(" Create model_config.json with your model definitions")
                    return

                with open(config_file, "r") as f:
                    config = json.load(f)

                # Initialize ModelClients (doesn't load the models yet -
                # pipelines are loaded lazily on first predict()).
                for model_id, model_info in config.items():
                    self.models[model_id] = ModelClient(
                        name=model_id,
                        model_name=model_info["model_name"],
                        task=model_info["task"],
                        config=model_info.get("config", {})
                    )

                logger.info(f"✅ ModelLoader initialized with {len(self.models)} models:")
                for model_id in self.models.keys():
                    logger.info(f" - {model_id}")

            except json.JSONDecodeError as e:
                logger.error(f"❌ Invalid JSON in model_config.json: {e}")
            except Exception as e:
                logger.error(f"❌ Failed to initialize ModelLoader: {e}", exc_info=True)

    def get(self, model_id: str) -> Optional[ModelClient]:
        """
        🎯 Retrieves a configured ModelClient by ID.

        Args:
            model_id: Model identifier from config

        Returns:
            ModelClient instance or None if not found
        """
        return self.models.get(model_id)

    def list_models(self) -> List[str]:
        """📋 Returns list of all available model IDs."""
        return list(self.models.keys())

    def get_loaded_models(self) -> List[str]:
        """📋 Returns list of currently loaded model IDs (pipeline in memory)."""
        return [
            model_id
            for model_id, client in self.models.items()
            if client.pipeline is not None
        ]

    def unload_all(self) -> None:
        """
        🗑️ Unloads all models to free memory.
        Useful for Azure environments when switching workloads.
        """
        logger.info("🗑️ Unloading all models...")
        for model_client in self.models.values():
            model_client.unload()
        logger.info("✅ All models unloaded")

    def get_status(self) -> Dict[str, Any]:
        """
        📊 Returns comprehensive status of all models.
        Useful for health checks and monitoring.
        """
        status = {
            "initialization_time": self._initialization_time.isoformat(),
            "total_models": len(self.models),
            "loaded_models": len(self.get_loaded_models()),
            "device": get_optimal_device(),
            "memory": get_memory_stats(),
            "models": {}
        }

        # Per-model metadata (load time, inference stats, memory).
        for model_id, client in self.models.items():
            status["models"][model_id] = client.get_metadata()

        return status
706
-
707
-
708
- # ============================================================
709
- # PUBLIC INTERFACE (Used by all *_utils.py modules)
710
- # ============================================================
711
-
712
def load_model_pipeline(agent_name: str) -> Callable[..., Dict[str, Any]]:
    """
    🚀 Resolves a model client and returns its inference callable.

    This is the main entry point used by the *_utils.py modules
    (translation_utils.py, sentiment_utils.py, etc.) to access Penny's models.

    Args:
        agent_name: Model ID from model_config.json

    Returns:
        Callable inference function

    Raises:
        ValueError: If agent_name not found in configuration

    Example:
        >>> translator = load_model_pipeline("penny-translate-agent")
        >>> result = translator("Hello world", target_lang="spa_Latn")
    """
    loader = ModelLoader()
    client = loader.get(agent_name)

    if client is None:
        available = loader.list_models()
        raise ValueError(
            f"Agent ID '{agent_name}' not found in model configuration. "
            f"Available models: {available}"
        )

    # Warm the pipeline now so the caller's first inference doesn't pay for it.
    client.load_pipeline()

    # The bound method already has the (input_data, **kwargs) shape callers expect.
    return client.predict
750
-
751
-
752
- # === CONVENIENCE FUNCTIONS ===
753
-
754
def get_model_status() -> Dict[str, Any]:
    """
    📊 Returns status of all configured models.
    Useful for health checks and monitoring endpoints.
    """
    # ModelLoader is a singleton, so this reuses the existing instance.
    return ModelLoader().get_status()
761
-
762
-
763
def preload_models(model_ids: Optional[List[str]] = None) -> None:
    """
    🚀 Preloads specified models during startup.

    Args:
        model_ids: List of model IDs to preload (None = all models)
    """
    loader = ModelLoader()
    targets = loader.list_models() if model_ids is None else model_ids

    logger.info(f"🚀 Preloading {len(targets)} models...")

    for model_id in targets:
        client = loader.get(model_id)
        if client:
            logger.info(f" Loading {model_id}...")
            client.load_pipeline()

    logger.info("✅ Model preloading complete")
784
-
785
-
786
def initialize_model_system() -> bool:
    """
    🏁 Initializes the model system; call this during app startup.

    Returns:
        True if initialization successful
    """
    logger.info("🧠 Initializing Penny's model system...")

    try:
        loader = ModelLoader()  # singleton - creates or reuses the instance
        compute_device = get_optimal_device()
        memory = get_memory_stats()

        logger.info(f"✅ Model system initialized")
        logger.info(f"🎮 Compute device: {compute_device}")

        if "gpu_total_gb" in memory:
            logger.info(f"💾 GPU Memory: {memory['gpu_total_gb']:.1f}GB total")

        logger.info(f"📦 {len(loader.models)} models configured")

        # Optional: preload critical models at startup, e.g.
        # preload_models(["penny-core-agent"])

        return True

    except Exception as e:
        logger.error(f"❌ Failed to initialize model system: {e}", exc_info=True)
        return False
823
-
824
-
825
- # ============================================================
826
- # CLI TESTING & DEBUGGING
827
- # ============================================================
828
-
829
if __name__ == "__main__":
    """
    🧪 Test script for model loading and inference.
    Run with: python -m app.model_loader
    """
    print("=" * 60)
    print("🧪 Testing Penny's Model System")
    print("=" * 60)

    # Initialize the singleton (reads model_config.json).
    loader = ModelLoader()
    print(f"\n📋 Available models: {loader.list_models()}")

    # Dump full system status (default=str handles datetime fields).
    status = get_model_status()
    print(f"\n📊 System status:")
    print(json.dumps(status, indent=2, default=str))

    # Smoke-test the first configured model, if any are defined.
    if loader.models:
        test_model_id = list(loader.models.keys())[0]
        print(f"\n🧪 Testing model: {test_model_id}")

        client = loader.get(test_model_id)
        if client:
            print(f" Loading pipeline...")
            success = client.load_pipeline()

            if success:
                print(f" ✅ Model loaded successfully!")
                print(f" Metadata: {json.dumps(client.get_metadata(), indent=2, default=str)}")
            else:
                print(f" ❌ Model loading failed")