pythonprincess commited on
Commit
db50b17
ยท
verified ยท
1 Parent(s): 4217b10

Upload model_loader.py

Browse files
Files changed (1) hide show
  1. app/model_loader.py +886 -0
app/model_loader.py ADDED
@@ -0,0 +1,886 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/model_loader.py
2
+ """
3
+ ๐Ÿง  PENNY Model Loader - Azure-Ready Multi-Model Orchestration
4
+
5
+ This is Penny's brain loader. She manages multiple specialized models:
6
+ - Gemma 7B for conversational reasoning
7
+ - NLLB-200 for 27-language translation
8
+ - Sentiment analysis for resident wellbeing
9
+ - Bias detection for equitable service
10
+ - LayoutLM for civic document processing
11
+
12
+ MISSION: Load AI models efficiently in memory-constrained environments while
13
+ maintaining Penny's warm, civic-focused personality across all interactions.
14
+
15
+ FEATURES:
16
+ - Lazy loading (models only load when needed)
17
+ - 8-bit quantization for memory efficiency
18
+ - GPU/CPU auto-detection
19
+ - Model caching and reuse
20
+ - Graceful fallbacks for Azure ML deployment
21
+ - Memory monitoring and cleanup
22
+ """
23
+
24
+ import json
25
+ import os
26
+ import torch
27
+ from typing import Dict, Any, Callable, Optional, Union, List
28
+ from pathlib import Path
29
+ import logging
30
+ from dataclasses import dataclass
31
+ from enum import Enum
32
+ from datetime import datetime
33
+
34
+ # ============================================================
35
+ # HUGGING FACE AUTHENTICATION
36
+ # ============================================================
37
+
38
def setup_huggingface_auth() -> bool:
    """
    🔐 Authenticates with Hugging Face Hub using HF_TOKEN.

    Reads the token from the HF_TOKEN environment variable and logs in via
    huggingface_hub. Safe to call at module import time.

    Returns:
        True if authentication succeeded; False if the token is missing,
        huggingface_hub is not installed, or the login call failed.
    """
    # BUG FIX: this function is invoked at module import time *before* the
    # module-level `logger` is assigned, which made every logging path raise
    # NameError. Resolve the logger locally instead — logging.getLogger()
    # caches by name, so this is the very same logger object.
    log = logging.getLogger(__name__)

    hf_token = os.getenv("HF_TOKEN")

    if not hf_token:
        log.warning("⚠️ HF_TOKEN not found in environment")
        log.warning(" Some models may not be accessible")
        log.warning(" Set HF_TOKEN in your environment or Hugging Face Spaces secrets")
        return False

    try:
        # Imported lazily so a missing huggingface_hub only disables auth
        # instead of breaking module import.
        from huggingface_hub import login
        login(token=hf_token, add_to_git_credential=False)
        log.info("✅ Authenticated with Hugging Face Hub")
        return True
    except ImportError:
        log.warning("⚠️ huggingface_hub not installed, skipping authentication")
        return False
    except Exception as e:
        log.error(f"❌ Failed to authenticate with Hugging Face: {e}")
        return False
64
+
65
# --- LOGGING SETUP ---
# BUG FIX: the logger must exist *before* setup_huggingface_auth() runs,
# because that function logs on every code path; in the original ordering the
# call preceded this assignment and raised NameError at import time.
logger = logging.getLogger(__name__)

# Attempt authentication at module load
setup_huggingface_auth()

# --- PATH CONFIGURATION (Environment-Aware) ---
# Support both local development and Azure ML deployment.
# Read the env var once so the branch test and the Path() construction
# cannot disagree.
_azureml_model_dir = os.getenv("AZUREML_MODEL_DIR")
if _azureml_model_dir:
    # Azure ML deployment - models are in AZUREML_MODEL_DIR
    MODEL_ROOT = Path(_azureml_model_dir)
    CONFIG_PATH = MODEL_ROOT / "model_config.json"
    logger.info("☁️ Running in Azure ML environment")
else:
    # Local development - models are in project structure
    PROJECT_ROOT = Path(__file__).parent.parent
    MODEL_ROOT = PROJECT_ROOT / "models"
    CONFIG_PATH = MODEL_ROOT / "model_config.json"
    logger.info("💻 Running in local development environment")

logger.info(f"📂 Model config path: {CONFIG_PATH}")
86
+
87
# ============================================================
# PENNY'S CIVIC IDENTITY & PERSONALITY
# ============================================================

# System prompt prepended to every text-generation request (see
# ModelClient.predict) unless the caller passes skip_system_prompt=True.
# NOTE: this text is sent to the model at runtime — edit with care.
PENNY_SYSTEM_PROMPT = (
    "You are Penny, a smart, civic-focused AI assistant serving local communities. "
    "You help residents navigate city services, government programs, and community resources. "
    "You're warm, professional, accurate, and always stay within your civic mission.\n\n"

    "Your expertise includes:\n"
    "- Connecting people with local services (food banks, shelters, libraries)\n"
    "- Translating information into 27 languages\n"
    "- Explaining public programs and eligibility\n"
    "- Guiding residents through civic processes\n"
    "- Providing emergency resources when needed\n\n"

    "YOUR PERSONALITY:\n"
    "- Warm and approachable, like a helpful community center staff member\n"
    "- Clear and practical, avoiding jargon\n"
    "- Culturally sensitive and inclusive\n"
    "- Patient with repetition or clarification\n"
    "- Funny when appropriate, but never at anyone's expense\n\n"

    "CRITICAL RULES:\n"
    "- When residents greet you by name (e.g., 'Hi Penny'), respond warmly and personally\n"
    "- You are ALWAYS Penny - never ChatGPT, Assistant, Claude, or any other name\n"
    "- If you don't know something, say so clearly and help find the right resource\n"
    "- NEVER make up information about services, eligibility, or contacts\n"
    "- Stay within your civic mission - you don't provide legal, medical, or financial advice\n"
    "- For emergencies, immediately connect to appropriate services (911, crisis lines)\n\n"
)

# --- GLOBAL STATE ---
# Shared across all ModelClient instances so a model loaded once is reused
# everywhere (see ModelClient.load_pipeline / unload).
_MODEL_CACHE: Dict[str, Any] = {}  # Memory-efficient model reuse
_LOAD_TIMES: Dict[str, float] = {}  # Track model loading performance
122
+
123
+
124
+ # ============================================================
125
+ # DEVICE MANAGEMENT
126
+ # ============================================================
127
+
128
class DeviceType(str, Enum):
    """Supported compute devices.

    Inherits from str so members compare equal to their plain string
    values (e.g. DeviceType.CUDA == "cuda"), which is how the rest of
    this module uses them.
    """
    CUDA = "cuda"  # NVIDIA GPU
    CPU = "cpu"  # Universal fallback
    MPS = "mps"  # Apple Silicon
133
+
134
+
135
def get_optimal_device() -> str:
    """
    🎮 Determines the best device for model inference.

    Checks hardware in priority order — CUDA GPU first, then Apple
    Silicon's MPS backend, then plain CPU — and logs what it found.

    Returns:
        Device string ("cuda", "mps", or "cpu")
    """
    # getLogger caches by name, so this is the module's logger object.
    log = logging.getLogger(__name__)

    # 1. NVIDIA GPU
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
        log.info(f"🎮 GPU detected: {gpu_name} ({gpu_memory:.1f}GB)")
        return "cuda"

    # 2. Apple Silicon (hasattr guard keeps older torch builds happy)
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        log.info("🍎 Apple Silicon (MPS) detected")
        return "mps"

    # 3. CPU fallback
    log.info("💻 Using CPU for inference")
    log.warning("⚠️ GPU not available - inference will be slower")
    return "cpu"
164
+
165
+
166
def get_memory_stats() -> Dict[str, float]:
    """
    📊 Returns current GPU/CPU memory statistics.

    GPU figures appear only when CUDA is available; CPU figures only when
    psutil is installed. May return an empty dict.

    Returns:
        Dict with memory stats in GB (plus cpu_percent as a percentage)
    """
    gig = 1e9
    stats: Dict[str, float] = {}

    # GPU side (only meaningful with CUDA)
    if torch.cuda.is_available():
        stats["gpu_allocated_gb"] = torch.cuda.memory_allocated() / gig
        stats["gpu_reserved_gb"] = torch.cuda.memory_reserved() / gig
        stats["gpu_total_gb"] = torch.cuda.get_device_properties(0).total_memory / gig

    # CPU side (psutil is optional — silently skip if absent)
    try:
        import psutil
    except ImportError:
        return stats

    vm = psutil.virtual_memory()
    stats["cpu_used_gb"] = vm.used / gig
    stats["cpu_total_gb"] = vm.total / gig
    stats["cpu_percent"] = vm.percent
    return stats
191
+
192
+
193
+ # ============================================================
194
+ # MODEL CLIENT (Individual Model Handler)
195
+ # ============================================================
196
+
197
@dataclass
class ModelMetadata:
    """
    📋 Metadata about a loaded model.

    Tracks identity, placement, load cost, and rolling inference
    statistics for a single ModelClient.
    """
    name: str
    task: str
    model_name: str
    device: str
    loaded_at: Optional[datetime] = None
    load_time_seconds: Optional[float] = None
    memory_usage_gb: Optional[float] = None
    inference_count: int = 0
    total_inference_time_ms: float = 0.0

    @property
    def avg_inference_time_ms(self) -> float:
        """Mean inference latency in ms; 0.0 before any inference has run."""
        count = self.inference_count
        return self.total_inference_time_ms / count if count else 0.0
219
+
220
+
221
class ModelClient:
    """
    🤖 Manages a single HuggingFace model with optimized loading and inference.

    Features:
    - Lazy loading (load on first use)
    - Memory optimization (8-bit quantization)
    - Performance tracking
    - Graceful error handling
    - Automatic device placement
    """

    def __init__(
        self,
        name: str,
        model_name: str,
        task: str,
        device: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None
    ):
        """
        Initialize model client (doesn't load the model yet).

        Args:
            name: Model identifier (e.g., "penny-core-agent")
            model_name: HuggingFace model ID
            task: Task type (text-generation, translation, etc.)
            device: Target device (auto-detected if None)
            config: Additional model configuration
        """
        self.name = name
        self.model_name = model_name
        self.task = task
        # Auto-detect the compute device unless the caller pinned one.
        self.device = device or get_optimal_device()
        self.config = config or {}
        self.pipeline = None          # Populated lazily by load_pipeline()
        self._load_attempted = False  # Guards against repeated doomed loads
        self.metadata = ModelMetadata(
            name=name,
            task=task,
            model_name=model_name,
            device=self.device
        )

        logger.info(f"📦 Initialized ModelClient: {name}")
        logger.debug(f" Model: {model_name}")
        logger.debug(f" Task: {task}")
        logger.debug(f" Device: {self.device}")

    def load_pipeline(self) -> bool:
        """
        🔄 Loads the HuggingFace pipeline with Azure-optimized settings.

        Features:
        - 8-bit quantization for large models (saves ~50% memory)
        - Automatic device placement
        - Memory monitoring
        - Cache checking

        Returns:
            True if successful, False otherwise
        """
        if self.pipeline is not None:
            logger.debug(f"✅ {self.name} already loaded")
            return True

        # A previous attempt failed — don't keep retrying a broken model.
        if self._load_attempted:
            logger.warning(f"⚠️ Previous load attempt failed for {self.name}")
            return False

        global _MODEL_CACHE, _LOAD_TIMES

        # Check cache first (another client may already have loaded it)
        if self.name in _MODEL_CACHE:
            logger.info(f"♻️ Using cached pipeline for {self.name}")
            self.pipeline = _MODEL_CACHE[self.name]
            return True

        logger.info(f"🔄 Loading {self.name} from HuggingFace...")
        self._load_attempted = True

        start_time = datetime.now()

        try:
            # BUG FIX: `pipeline` was referenced throughout this method but
            # never imported anywhere in the module, so every load raised
            # NameError. Import lazily here so the module itself stays
            # importable even when transformers is absent; an ImportError is
            # reported via the generic except branch below.
            from transformers import pipeline

            # === TEXT GENERATION (Gemma 7B, GPT-2, etc.) ===
            if self.task == "text-generation":
                logger.info(" Using 8-bit quantization for memory efficiency...")

                # 8-bit loading (bitsandbytes) is only supported on CUDA
                use_8bit = self.device == DeviceType.CUDA.value

                if use_8bit:
                    self.pipeline = pipeline(
                        "text-generation",
                        model=self.model_name,
                        tokenizer=self.model_name,
                        device_map="auto",
                        load_in_8bit=True,  # Reduces ~14GB to ~7GB
                        trust_remote_code=True,
                        torch_dtype=torch.float16
                    )
                else:
                    # CPU fallback — full float32 precision
                    self.pipeline = pipeline(
                        "text-generation",
                        model=self.model_name,
                        tokenizer=self.model_name,
                        device=-1,  # CPU
                        trust_remote_code=True,
                        torch_dtype=torch.float32
                    )

            # === TRANSLATION (NLLB-200, M2M-100, etc.) ===
            elif self.task == "translation":
                self.pipeline = pipeline(
                    "translation",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1,
                    src_lang=self.config.get("default_src_lang", "eng_Latn"),
                    tgt_lang=self.config.get("default_tgt_lang", "spa_Latn")
                )

            # === SENTIMENT ANALYSIS ===
            elif self.task == "sentiment-analysis":
                self.pipeline = pipeline(
                    "sentiment-analysis",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1,
                    truncation=True,
                    max_length=512
                )

            # === BIAS DETECTION (Zero-Shot Classification) ===
            elif self.task == "bias-detection":
                self.pipeline = pipeline(
                    "zero-shot-classification",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1
                )

            # === TEXT CLASSIFICATION (Generic) ===
            elif self.task == "text-classification":
                self.pipeline = pipeline(
                    "text-classification",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1,
                    truncation=True
                )

            # === PDF/DOCUMENT EXTRACTION (LayoutLMv3) ===
            elif self.task == "pdf-extraction":
                logger.warning("⚠️ PDF extraction requires additional OCR setup")
                logger.info(" Consider using Azure Form Recognizer as alternative")
                # Placeholder - requires pytesseract/OCR infrastructure
                self.pipeline = None
                return False

            else:
                raise ValueError(f"Unknown task type: {self.task}")

            # === SUCCESS HANDLING ===
            if self.pipeline is None:
                # Defensive: no branch should leave us here, but never fall
                # out of the method implicitly returning None (the original
                # did on this path; callers expect a bool).
                return False

            # Record load cost
            load_time = (datetime.now() - start_time).total_seconds()
            self.metadata.loaded_at = datetime.now()
            self.metadata.load_time_seconds = load_time

            # Cache the pipeline so other clients reuse it
            _MODEL_CACHE[self.name] = self.pipeline
            _LOAD_TIMES[self.name] = load_time

            # Record memory footprint (key absent ⇒ no GPU ⇒ 0)
            mem_stats = get_memory_stats()
            self.metadata.memory_usage_gb = mem_stats.get("gpu_allocated_gb", 0)

            logger.info(f"✅ {self.name} loaded successfully!")
            logger.info(f" Load time: {load_time:.2f}s")

            if "gpu_allocated_gb" in mem_stats:
                logger.info(
                    f" GPU Memory: {mem_stats['gpu_allocated_gb']:.2f}GB / "
                    f"{mem_stats['gpu_total_gb']:.2f}GB"
                )

            return True

        except Exception as e:
            logger.error(f"❌ Failed to load {self.name}: {e}", exc_info=True)
            self.pipeline = None
            return False

    def predict(
        self,
        input_data: Union[str, Dict[str, Any]],
        **kwargs
    ) -> Dict[str, Any]:
        """
        🎯 Runs inference with the loaded model pipeline.

        Features:
        - Automatic pipeline loading
        - Error handling with fallback responses
        - Performance tracking
        - Penny's personality injection (for text-generation)

        Args:
            input_data: Text or structured input for the model
            **kwargs: Task-specific parameters (max_new_tokens, temperature,
                source_lang/target_lang, candidate_labels, skip_system_prompt, ...)

        Returns:
            Model output dict with results or error information
        """
        # Track inference start time
        start_time = datetime.now()

        # Ensure pipeline is loaded (lazy loading on first call)
        if self.pipeline is None:
            success = self.load_pipeline()
            if not success:
                return {
                    "error": f"{self.name} pipeline unavailable",
                    "detail": "Model failed to load. Check logs for details.",
                    "model": self.name
                }

        try:
            # === TEXT GENERATION ===
            if self.task == "text-generation":
                # Inject Penny's civic identity unless explicitly disabled
                if not kwargs.get("skip_system_prompt", False):
                    full_prompt = PENNY_SYSTEM_PROMPT + input_data
                else:
                    full_prompt = input_data

                # Extract generation parameters with safe defaults
                max_new_tokens = kwargs.get("max_new_tokens", 256)
                temperature = kwargs.get("temperature", 0.7)
                top_p = kwargs.get("top_p", 0.9)
                # Greedy decoding when temperature is 0 (sampling would error)
                do_sample = kwargs.get("do_sample", temperature > 0.0)

                result = self.pipeline(
                    full_prompt,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=do_sample,
                    return_full_text=False,
                    pad_token_id=self.pipeline.tokenizer.eos_token_id,
                    truncation=True
                )

                output = {
                    "generated_text": result[0]["generated_text"],
                    "model": self.name,
                    "success": True
                }

            # === TRANSLATION ===
            elif self.task == "translation":
                src_lang = kwargs.get("source_lang", "eng_Latn")
                tgt_lang = kwargs.get("target_lang", "spa_Latn")

                result = self.pipeline(
                    input_data,
                    src_lang=src_lang,
                    tgt_lang=tgt_lang,
                    max_length=512
                )

                output = {
                    "translation": result[0]["translation_text"],
                    "source_lang": src_lang,
                    "target_lang": tgt_lang,
                    "model": self.name,
                    "success": True
                }

            # === SENTIMENT ANALYSIS ===
            elif self.task == "sentiment-analysis":
                result = self.pipeline(input_data)

                output = {
                    "sentiment": result[0]["label"],
                    "confidence": result[0]["score"],
                    "model": self.name,
                    "success": True
                }

            # === BIAS DETECTION ===
            elif self.task == "bias-detection":
                candidate_labels = kwargs.get("candidate_labels", [
                    "neutral and objective",
                    "contains political bias",
                    "uses emotional language",
                    "culturally insensitive"
                ])

                result = self.pipeline(
                    input_data,
                    candidate_labels=candidate_labels,
                    multi_label=True
                )

                output = {
                    "labels": result["labels"],
                    "scores": result["scores"],
                    "model": self.name,
                    "success": True
                }

            # === TEXT CLASSIFICATION ===
            elif self.task == "text-classification":
                result = self.pipeline(input_data)

                output = {
                    "label": result[0]["label"],
                    "confidence": result[0]["score"],
                    "model": self.name,
                    "success": True
                }

            else:
                output = {
                    "error": f"Task '{self.task}' not implemented",
                    "model": self.name,
                    "success": False
                }

            # Track performance (milliseconds)
            inference_time = (datetime.now() - start_time).total_seconds() * 1000
            self.metadata.inference_count += 1
            self.metadata.total_inference_time_ms += inference_time
            output["inference_time_ms"] = round(inference_time, 2)

            return output

        except Exception as e:
            logger.error(f"❌ Inference error in {self.name}: {e}", exc_info=True)
            return {
                "error": "Inference failed",
                "detail": str(e),
                "model": self.name,
                "success": False
            }

    def unload(self) -> None:
        """
        🗑️ Unloads the model to free memory.
        Critical for Azure environments with limited resources.
        No-op when the model isn't loaded.
        """
        if self.pipeline is not None:
            logger.info(f"🗑️ Unloading {self.name}...")

            # Delete pipeline
            del self.pipeline
            self.pipeline = None

            # Remove from the shared cache so it can't be resurrected
            if self.name in _MODEL_CACHE:
                del _MODEL_CACHE[self.name]

            # Force GPU memory release
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            logger.info(f"✅ {self.name} unloaded successfully")

            # Log memory stats after unload
            mem_stats = get_memory_stats()
            if "gpu_allocated_gb" in mem_stats:
                logger.info(f" GPU Memory: {mem_stats['gpu_allocated_gb']:.2f}GB remaining")

    def get_metadata(self) -> Dict[str, Any]:
        """
        📊 Returns model metadata and performance stats as a JSON-safe dict.
        """
        return {
            "name": self.metadata.name,
            "task": self.metadata.task,
            "model_name": self.metadata.model_name,
            "device": self.metadata.device,
            "loaded": self.pipeline is not None,
            "loaded_at": self.metadata.loaded_at.isoformat() if self.metadata.loaded_at else None,
            "load_time_seconds": self.metadata.load_time_seconds,
            "memory_usage_gb": self.metadata.memory_usage_gb,
            "inference_count": self.metadata.inference_count,
            "avg_inference_time_ms": round(self.metadata.avg_inference_time_ms, 2)
        }
609
+
610
+
611
+ # ============================================================
612
+ # MODEL LOADER (Singleton Manager)
613
+ # ============================================================
614
+
615
class ModelLoader:
    """
    🎛️ Singleton manager for all Penny's specialized models.

    Holds one ModelClient per entry in model_config.json. Clients are
    created eagerly but their pipelines load lazily on first use.
    """

    _instance: Optional['ModelLoader'] = None

    def __new__(cls, *args, **kwargs):
        """Singleton pattern - only one ModelLoader instance."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize ModelLoader (only runs once due to singleton).

        Args:
            config_path: Path to model_config.json (optional; ignored on
                every call after the first, since the singleton is already
                initialized)
        """
        # Guard clause: subsequent ModelLoader() calls re-enter __init__,
        # so bail out once the singleton has been set up.
        if hasattr(self, '_models_loaded'):
            return

        self.models: Dict[str, ModelClient] = {}
        self._models_loaded = True
        self._initialization_time = datetime.now()

        # Use provided path or the module-level default
        config_file = Path(config_path) if config_path else CONFIG_PATH

        try:
            logger.info(f"📖 Loading model configuration from {config_file}")

            if not config_file.exists():
                logger.warning(f"⚠️ Configuration file not found: {config_file}")
                logger.info(" Create model_config.json with your model definitions")
                return

            with open(config_file, "r") as f:
                config = json.load(f)

            # Build one (not-yet-loaded) client per configured model
            for model_id, model_info in config.items():
                self.models[model_id] = ModelClient(
                    name=model_id,
                    model_name=model_info["model_name"],
                    task=model_info["task"],
                    config=model_info.get("config", {})
                )

            logger.info(f"✅ ModelLoader initialized with {len(self.models)} models:")
            for model_id in self.models:
                logger.info(f" - {model_id}")

        except json.JSONDecodeError as e:
            logger.error(f"❌ Invalid JSON in model_config.json: {e}")
        except Exception as e:
            logger.error(f"❌ Failed to initialize ModelLoader: {e}", exc_info=True)

    def get(self, model_id: str) -> Optional[ModelClient]:
        """
        🎯 Retrieves a configured ModelClient by ID.

        Args:
            model_id: Model identifier from config

        Returns:
            ModelClient instance or None if not found
        """
        return self.models.get(model_id)

    def list_models(self) -> List[str]:
        """📋 Returns list of all available model IDs."""
        return list(self.models)

    def get_loaded_models(self) -> List[str]:
        """📋 Returns list of currently loaded model IDs."""
        loaded = []
        for model_id, client in self.models.items():
            if client.pipeline is not None:
                loaded.append(model_id)
        return loaded

    def unload_all(self) -> None:
        """
        🗑️ Unloads all models to free memory.
        Useful for Azure environments when switching workloads.
        """
        logger.info("🗑️ Unloading all models...")
        for client in self.models.values():
            client.unload()
        logger.info("✅ All models unloaded")

    def get_status(self) -> Dict[str, Any]:
        """
        📊 Returns comprehensive status of all models.
        Useful for health checks and monitoring endpoints.
        """
        status: Dict[str, Any] = {
            "initialization_time": self._initialization_time.isoformat(),
            "total_models": len(self.models),
            "loaded_models": len(self.get_loaded_models()),
            "device": get_optimal_device(),
            "memory": get_memory_stats(),
            "models": {}
        }

        # Per-model metadata, keyed by model ID
        for model_id, client in self.models.items():
            status["models"][model_id] = client.get_metadata()

        return status
731
+
732
+
733
+ # ============================================================
734
+ # PUBLIC INTERFACE (Used by all *_utils.py modules)
735
+ # ============================================================
736
+
737
def load_model_pipeline(agent_name: str) -> Callable[..., Dict[str, Any]]:
    """
    🚀 Loads a model client and returns its inference function.

    Main entry point for the *_utils.py modules (translation_utils.py,
    sentiment_utils.py, etc.) to obtain one of Penny's models.

    Args:
        agent_name: Model ID from model_config.json

    Returns:
        Callable inference function

    Raises:
        ValueError: If agent_name not found in configuration

    Example:
        >>> translator = load_model_pipeline("penny-translate-agent")
        >>> result = translator("Hello world", target_lang="spa_Latn")
    """
    registry = ModelLoader()
    client = registry.get(agent_name)

    if client is None:
        available = registry.list_models()
        raise ValueError(
            f"Agent ID '{agent_name}' not found in model configuration. "
            f"Available models: {available}"
        )

    # Trigger the (lazy) pipeline load up front
    client.load_pipeline()

    def _invoke(input_data, **kwargs):
        """Thin forwarding wrapper around the client's predict()."""
        return client.predict(input_data, **kwargs)

    return _invoke
775
+
776
+
777
+ # === CONVENIENCE FUNCTIONS ===
778
+
779
def get_model_status() -> Dict[str, Any]:
    """
    📊 Returns status of all configured models.
    Useful for health checks and monitoring endpoints.
    """
    # ModelLoader is a singleton, so this is the shared instance.
    return ModelLoader().get_status()
786
+
787
+
788
def preload_models(model_ids: Optional[List[str]] = None) -> None:
    """
    🚀 Preloads specified models during startup.

    Args:
        model_ids: List of model IDs to preload (None = all models)
    """
    loader = ModelLoader()

    # None means "everything that's configured"
    targets = loader.list_models() if model_ids is None else model_ids

    logger.info(f"🚀 Preloading {len(targets)} models...")

    for mid in targets:
        client = loader.get(mid)
        if client:
            logger.info(f" Loading {mid}...")
            client.load_pipeline()

    logger.info("✅ Model preloading complete")
809
+
810
+
811
def initialize_model_system() -> bool:
    """
    🏁 Initializes the model system.
    Should be called during app startup.

    Returns:
        True if initialization successful
    """
    logger.info("🧠 Initializing Penny's model system...")

    try:
        # Construct (or fetch) the singleton loader
        loader = ModelLoader()

        # Probe the hardware and report it
        device = get_optimal_device()
        mem_stats = get_memory_stats()

        logger.info(f"✅ Model system initialized")
        logger.info(f"🎮 Compute device: {device}")

        gpu_total = mem_stats.get("gpu_total_gb")
        if gpu_total is not None:
            logger.info(f"💾 GPU Memory: {gpu_total:.1f}GB total")

        logger.info(f"📦 {len(loader.models)} models configured")

        # Optional: Preload critical models
        # Uncomment to preload models at startup
        # preload_models(["penny-core-agent"])

        return True

    except Exception as e:
        logger.error(f"❌ Failed to initialize model system: {e}", exc_info=True)
        return False
848
+
849
+
850
+ # ============================================================
851
+ # CLI TESTING & DEBUGGING
852
+ # ============================================================
853
+
854
if __name__ == "__main__":
    """
    🧪 Test script for model loading and inference.
    Run with: python -m app.model_loader
    """
    print("=" * 60)
    print("🧪 Testing Penny's Model System")
    print("=" * 60)

    # Initialize the singleton loader (reads model_config.json)
    loader = ModelLoader()
    print(f"\n📋 Available models: {loader.list_models()}")

    # Get status snapshot (default=str handles datetime serialization)
    status = get_model_status()
    print(f"\n📊 System status:")
    print(json.dumps(status, indent=2, default=str))

    # Smoke-test the first configured model, if any
    if loader.models:
        test_model_id = list(loader.models.keys())[0]
        print(f"\n🧪 Testing model: {test_model_id}")

        client = loader.get(test_model_id)
        if client:
            print(f" Loading pipeline...")
            # NOTE(review): this downloads model weights on first run —
            # expect it to take a while and to need network access
            success = client.load_pipeline()

            if success:
                print(f" ✅ Model loaded successfully!")
                print(f" Metadata: {json.dumps(client.get_metadata(), indent=2, default=str)}")
            else:
                print(f" ❌ Model loading failed")