pythonprincess commited on
Commit
f0be0cf
ยท
verified ยท
1 Parent(s): ef18d3e

Upload model_loader.py

Browse files
Files changed (1) hide show
  1. app/model_loader.py +912 -0
app/model_loader.py ADDED
@@ -0,0 +1,912 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/model_loader.py
2
+ """
3
+ ๐Ÿง  PENNY Model Loader - Azure-Ready Multi-Model Orchestration
4
+
5
+ This is Penny's brain loader. She manages multiple specialized models:
6
+ - Gemma 7B for conversational reasoning
7
+ - NLLB-200 for 27-language translation
8
+ - Sentiment analysis for resident wellbeing
9
+ - Bias detection for equitable service
10
+ - LayoutLM for civic document processing
11
+
12
+ MISSION: Load AI models efficiently in memory-constrained environments while
13
+ maintaining Penny's warm, civic-focused personality across all interactions.
14
+
15
+ FEATURES:
16
+ - Lazy loading (models only load when needed)
17
+ - 8-bit quantization for memory efficiency
18
+ - GPU/CPU auto-detection
19
+ - Model caching and reuse
20
+ - Graceful fallbacks for Azure ML deployment
21
+ - Memory monitoring and cleanup
22
+ """
23
+
24
+ import json
25
+ import os
26
+ import torch
27
+ from typing import Dict, Any, Callable, Optional, Union, List
28
+ from pathlib import Path
29
+ import logging
30
+ from dataclasses import dataclass
31
+ from enum import Enum
32
+ from datetime import datetime
33
+
34
+ # --- LOGGING SETUP (Must be before functions that use it) ---
35
+ logger = logging.getLogger(__name__)
36
+
37
+ # ============================================================
38
+ # HUGGING FACE AUTHENTICATION
39
+ # ============================================================
40
+
41
def setup_huggingface_auth() -> bool:
    """
    🔐 Log in to the Hugging Face Hub using HF_TOKEN or READTOKEN.

    Returns:
        True when login succeeds; False when no token is configured,
        ``huggingface_hub`` is not installed, or the login call raises.
    """
    # READTOKEN is the secret name used on Hugging Face Spaces deployments.
    token = os.getenv("HF_TOKEN") or os.getenv("READTOKEN")

    if not token:
        logger.warning("⚠️ HF_TOKEN/READTOKEN not found in environment")
        logger.warning("   Some models may not be accessible")
        logger.warning("   Set HF_TOKEN or READTOKEN in your environment or Hugging Face Spaces secrets")
        return False

    try:
        from huggingface_hub import login
    except ImportError:
        # The hub client is an optional dependency; public models still work.
        logger.warning("⚠️ huggingface_hub not installed, skipping authentication")
        return False

    try:
        login(token=token, add_to_git_credential=False)
    except Exception as e:
        logger.error(f"❌ Failed to authenticate with Hugging Face: {e}")
        return False

    logger.info("✅ Authenticated with Hugging Face Hub")
    return True
68
+
69
# Attempt authentication at module load.
# Note: This runs when the module is imported, so HF_TOKEN must be in environment
# For Hugging Face Spaces: Set HF_TOKEN as a secret in Space settings
# For local dev: Add HF_TOKEN to .env file or export it
_authentication_result = setup_huggingface_auth()
if _authentication_result:
    logger.info("🔐 Hugging Face authentication successful - gated models accessible")
else:
    logger.warning("⚠️ Hugging Face authentication failed - only public models will work")

# --- PATH CONFIGURATION (Environment-Aware) ---
# Support both local development and Azure ML deployment.
# NOTE(review): presumably AZUREML_MODEL_DIR is always set by Azure ML scoring
# containers — confirm against the deployment configuration.
if os.getenv("AZUREML_MODEL_DIR"):
    # Azure ML deployment - models are in AZUREML_MODEL_DIR
    MODEL_ROOT = Path(os.getenv("AZUREML_MODEL_DIR"))
    CONFIG_PATH = MODEL_ROOT / "model_config.json"
    logger.info("☁️ Running in Azure ML environment")
else:
    # Local development - models live under <project root>/models
    PROJECT_ROOT = Path(__file__).parent.parent
    MODEL_ROOT = PROJECT_ROOT / "models"
    CONFIG_PATH = MODEL_ROOT / "model_config.json"
    logger.info("💻 Running in local development environment")

logger.info(f"📂 Model config path: {CONFIG_PATH}")
94
+
95
+ # ============================================================
96
+ # PENNY'S CIVIC IDENTITY & PERSONALITY
97
+ # ============================================================
98
+
99
# System prompt prepended to every text-generation request (see
# ModelClient.predict); defines Penny's persona and guardrails.
PENNY_SYSTEM_PROMPT = (
    "You are Penny, a sweet southern neighborly woman who's lived in this community for years "
    "and knows everything about the city. You're like that wonderful older neighbor who always "
    "has a kind word, remembers everyone's name, and can tell you the best places to go and "
    "the most interesting stories about your town.\n\n"

    "YOUR PERSONALITY - Sweet Southern Neighbor:\n"
    "- Warm, inviting, and genuinely friendly - like you're chatting over sweet tea on the porch\n"
    "- Use phrases like 'honey', 'sugar', 'darlin'', 'bless your heart' naturally and warmly\n"
    "- Share fun facts about the city when relevant ('Did you know our city was founded in...?')\n"
    "- Suggest things to do and places to visit like a local who knows all the hidden gems\n"
    "- Be conversational and neighborly - ask follow-up questions, show genuine interest\n"
    "- Remember details from the conversation and reference them naturally\n"
    "- Use exclamation points and emojis warmly (but not excessively)\n"
    "- Be patient and never rush - you have all the time in the world to help\n"
    "- Share local wisdom and tips ('Oh honey, you'll want to go there on a Tuesday - it's less crowded!')\n\n"

    "YOUR EXPERTISE:\n"
    "- You know all about local services, events, weather, and community resources\n"
    "- You can translate information into 27 languages (because you care about everyone feeling welcome)\n"
    "- You know who the city officials are and how to reach them\n"
    "- You remember the best restaurants, parks, and community spots\n"
    "- You know the history and fun facts about your city\n"
    "- You can help with emergencies and know exactly who to call\n\n"

    "CONVERSATION STYLE:\n"
    "- Start conversations warmly: 'Well hello there, sugar! How can I help you today?'\n"
    "- When helping: 'Oh honey, I'd be happy to help you with that!'\n"
    "- When suggesting: 'You know what, darlin'? You might also enjoy...'\n"
    "- When sharing facts: 'Did you know that...? It's one of my favorite things about our city!'\n"
    "- End responses warmly: 'Is there anything else I can help you with, sweetie?'\n"
    "- Be encouraging: 'That sounds wonderful! You're going to love it!'\n\n"

    "CRITICAL RULES:\n"
    "- You are ALWAYS Penny - never ChatGPT, Assistant, Claude, or any other name\n"
    "- When residents greet you by name, respond with genuine warmth and recognition\n"
    "- If you don't know something, say so sweetly: 'Oh honey, I'm not sure about that, but let me help you find out!'\n"
    "- NEVER make up information - if you don't know, guide them to the right resource\n"
    "- Stay within your civic mission - you're helpful but don't give legal, medical, or financial advice\n"
    "- For emergencies, respond immediately with care and direct them to 911 or crisis lines\n"
    "- Keep the southern charm authentic but not overdone - be natural and genuine\n\n"
)

# --- GLOBAL STATE ---
# Process-wide caches shared by every ModelClient (written in load_pipeline,
# cleared in unload); not thread-safe by themselves.
_MODEL_CACHE: Dict[str, Any] = {}  # Memory-efficient model reuse
_LOAD_TIMES: Dict[str, float] = {}  # Track model loading performance
145
+
146
+
147
+ # ============================================================
148
+ # DEVICE MANAGEMENT
149
+ # ============================================================
150
+
151
class DeviceType(str, Enum):
    """Compute backends Penny can run inference on (str-valued for easy
    comparison with plain device strings)."""

    CUDA = "cuda"  # NVIDIA GPUs
    CPU = "cpu"    # universal fallback
    MPS = "mps"    # Apple Silicon
156
+
157
+
158
def get_optimal_device() -> str:
    """
    🎮 Pick the best available inference device.

    Preference order: CUDA GPU, then Apple Silicon (MPS), then CPU.

    Returns:
        Device string ("cuda", "mps", or "cpu")
    """
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
        logger.info(f"🎮 GPU detected: {gpu_name} ({gpu_memory:.1f}GB)")
        return DeviceType.CUDA.value

    # hasattr guard keeps this working on torch builds without MPS support.
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        logger.info("🍎 Apple Silicon (MPS) detected")
        return DeviceType.MPS.value

    logger.info("💻 Using CPU for inference")
    logger.warning("⚠️ GPU not available - inference will be slower")
    return DeviceType.CPU.value
187
+
188
+
189
def get_memory_stats() -> Dict[str, float]:
    """
    📊 Collect current GPU/CPU memory statistics.

    Returns:
        Dict of figures in GB (plus ``cpu_percent``). GPU keys are present
        only when CUDA is available; CPU keys only when psutil is installed.
    """
    stats: Dict[str, float] = {}

    if torch.cuda.is_available():
        stats.update(
            gpu_allocated_gb=torch.cuda.memory_allocated() / 1e9,
            gpu_reserved_gb=torch.cuda.memory_reserved() / 1e9,
            gpu_total_gb=torch.cuda.get_device_properties(0).total_memory / 1e9,
        )

    # psutil is optional; skip CPU figures when it is absent.
    try:
        import psutil
    except ImportError:
        return stats

    vm = psutil.virtual_memory()
    stats["cpu_used_gb"] = vm.used / 1e9
    stats["cpu_total_gb"] = vm.total / 1e9
    stats["cpu_percent"] = vm.percent
    return stats
214
+
215
+
216
+ # ============================================================
217
+ # MODEL CLIENT (Individual Model Handler)
218
+ # ============================================================
219
+
220
@dataclass
class ModelMetadata:
    """
    📋 Bookkeeping record for one model: identity, placement, and
    cumulative performance counters updated by its ModelClient.
    """

    name: str                                  # internal model identifier
    task: str                                  # pipeline task type
    model_name: str                            # HuggingFace model ID
    device: str                                # "cuda" / "mps" / "cpu"
    loaded_at: Optional[datetime] = None       # set once the pipeline loads
    load_time_seconds: Optional[float] = None  # wall-clock load duration
    memory_usage_gb: Optional[float] = None    # GPU allocation after load
    inference_count: int = 0                   # number of predict() calls
    total_inference_time_ms: float = 0.0       # summed inference latency

    @property
    def avg_inference_time_ms(self) -> float:
        """Mean inference latency in ms (0.0 before any inference)."""
        calls = self.inference_count
        return self.total_inference_time_ms / calls if calls else 0.0
242
+
243
+
244
class ModelClient:
    """
    🤖 Manages a single HuggingFace model with optimized loading and inference.

    Features:
    - Lazy loading (load on first use)
    - Memory optimization (8-bit quantization)
    - Performance tracking
    - Graceful error handling
    - Automatic device placement
    """

    def __init__(
        self,
        name: str,
        model_name: str,
        task: str,
        device: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None
    ):
        """
        Initialize model client (doesn't load the model yet).

        Args:
            name: Model identifier (e.g., "penny-core-agent")
            model_name: HuggingFace model ID
            task: Task type (text-generation, translation, etc.)
            device: Target device (auto-detected if None)
            config: Additional model configuration
        """
        self.name = name
        self.model_name = model_name
        self.task = task
        self.device = device or get_optimal_device()
        self.config = config or {}
        self.pipeline = None          # populated lazily by load_pipeline()
        self._load_attempted = False  # guards against repeated doomed loads
        self.metadata = ModelMetadata(
            name=name,
            task=task,
            model_name=model_name,
            device=self.device
        )

        logger.info(f"📦 Initialized ModelClient: {name}")
        logger.debug(f"   Model: {model_name}")
        logger.debug(f"   Task: {task}")
        logger.debug(f"   Device: {self.device}")

    def load_pipeline(self) -> bool:
        """
        🔄 Loads the HuggingFace pipeline with Azure-optimized settings.

        Features:
        - 8-bit quantization for large models (saves ~50% memory)
        - Automatic device placement
        - Memory monitoring
        - Cache checking

        Returns:
            True if successful, False otherwise
        """
        if self.pipeline is not None:
            logger.debug(f"✅ {self.name} already loaded")
            return True

        # Don't retry a load that already failed; unload() does not reset this
        # flag, so a failed model stays failed for the process lifetime.
        if self._load_attempted:
            logger.warning(f"⚠️ Previous load attempt failed for {self.name}")
            return False

        # Check cache first (shared across all clients in this process)
        if self.name in _MODEL_CACHE:
            logger.info(f"♻️ Using cached pipeline for {self.name}")
            self.pipeline = _MODEL_CACHE[self.name]
            return True

        logger.info(f"🔄 Loading {self.name} from HuggingFace...")
        self._load_attempted = True

        start_time = datetime.now()

        try:
            # Import pipeline from transformers (lazy import to avoid dependency issues)
            from transformers import pipeline

            # === TEXT GENERATION (Gemma 7B, GPT-2, etc.) ===
            if self.task == "text-generation":
                logger.info("   Using 8-bit quantization for memory efficiency...")

                # 8-bit loading only makes sense on CUDA (bitsandbytes requirement)
                use_8bit = self.device == DeviceType.CUDA.value

                if use_8bit:
                    # NOTE(review): load_in_8bit= is deprecated in recent
                    # transformers in favor of quantization_config=
                    # BitsAndBytesConfig(load_in_8bit=True) — migrate when
                    # the pinned transformers version is known.
                    self.pipeline = pipeline(
                        "text-generation",
                        model=self.model_name,
                        tokenizer=self.model_name,
                        device_map="auto",
                        load_in_8bit=True,  # Reduces ~14GB to ~7GB
                        trust_remote_code=True,
                        torch_dtype=torch.float16
                    )
                else:
                    # CPU fallback: full precision, no quantization
                    self.pipeline = pipeline(
                        "text-generation",
                        model=self.model_name,
                        tokenizer=self.model_name,
                        device=-1,  # CPU
                        trust_remote_code=True,
                        torch_dtype=torch.float32
                    )

            # === TRANSLATION (NLLB-200, M2M-100, etc.) ===
            elif self.task == "translation":
                self.pipeline = pipeline(
                    "translation",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1,
                    src_lang=self.config.get("default_src_lang", "eng_Latn"),
                    tgt_lang=self.config.get("default_tgt_lang", "spa_Latn")
                )

            # === SENTIMENT ANALYSIS ===
            elif self.task == "sentiment-analysis":
                self.pipeline = pipeline(
                    "sentiment-analysis",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1,
                    truncation=True,
                    max_length=512
                )

            # === BIAS DETECTION (Zero-Shot Classification) ===
            elif self.task == "bias-detection":
                self.pipeline = pipeline(
                    "zero-shot-classification",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1
                )

            # === TEXT CLASSIFICATION (Generic) ===
            elif self.task == "text-classification":
                self.pipeline = pipeline(
                    "text-classification",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1,
                    truncation=True
                )

            # === PDF/DOCUMENT EXTRACTION (LayoutLMv3) ===
            elif self.task == "pdf-extraction":
                logger.warning("⚠️ PDF extraction requires additional OCR setup")
                logger.info("   Consider using Azure Form Recognizer as alternative")
                # Placeholder - requires pytesseract/OCR infrastructure
                self.pipeline = None
                return False

            else:
                raise ValueError(f"Unknown task type: {self.task}")

            # === SUCCESS HANDLING ===
            if self.pipeline is not None:
                # Calculate load time
                load_time = (datetime.now() - start_time).total_seconds()
                self.metadata.loaded_at = datetime.now()
                self.metadata.load_time_seconds = load_time

                # Cache the pipeline for reuse by other clients
                _MODEL_CACHE[self.name] = self.pipeline
                _LOAD_TIMES[self.name] = load_time

                # Log memory usage
                mem_stats = get_memory_stats()
                self.metadata.memory_usage_gb = mem_stats.get("gpu_allocated_gb", 0)

                logger.info(f"✅ {self.name} loaded successfully!")
                logger.info(f"   Load time: {load_time:.2f}s")

                if "gpu_allocated_gb" in mem_stats:
                    logger.info(
                        f"   GPU Memory: {mem_stats['gpu_allocated_gb']:.2f}GB / "
                        f"{mem_stats['gpu_total_gb']:.2f}GB"
                    )

                return True

            # FIX: previously fell off the end returning implicit None;
            # return an explicit bool so callers can rely on the contract.
            return False

        except Exception as e:
            logger.error(f"❌ Failed to load {self.name}: {e}", exc_info=True)
            self.pipeline = None
            return False

    def predict(
        self,
        input_data: Union[str, Dict[str, Any]],
        **kwargs
    ) -> Dict[str, Any]:
        """
        🎯 Runs inference with the loaded model pipeline.

        Features:
        - Automatic pipeline loading
        - Error handling with fallback responses
        - Performance tracking
        - Penny's personality injection (for text-generation)

        Args:
            input_data: Text or structured input for the model
            **kwargs: Task-specific parameters

        Returns:
            Model output dict with results or error information
        """
        # Track inference start time
        start_time = datetime.now()

        # Ensure pipeline is loaded (lazy loading)
        if self.pipeline is None:
            success = self.load_pipeline()
            if not success:
                return {
                    "error": f"{self.name} pipeline unavailable",
                    "detail": "Model failed to load. Check logs for details.",
                    "model": self.name,
                    # FIX: "success" key was missing here, unlike every other
                    # error path — callers checking result["success"] got a
                    # KeyError on this branch.
                    "success": False
                }

        try:
            # === TEXT GENERATION ===
            if self.task == "text-generation":
                # Inject Penny's civic identity unless explicitly skipped
                if not kwargs.get("skip_system_prompt", False):
                    full_prompt = PENNY_SYSTEM_PROMPT + input_data
                else:
                    full_prompt = input_data

                # Extract generation parameters with safe defaults
                max_new_tokens = kwargs.get("max_new_tokens", 256)
                temperature = kwargs.get("temperature", 0.7)
                top_p = kwargs.get("top_p", 0.9)
                do_sample = kwargs.get("do_sample", temperature > 0.0)

                result = self.pipeline(
                    full_prompt,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=do_sample,
                    return_full_text=False,
                    pad_token_id=self.pipeline.tokenizer.eos_token_id,
                    truncation=True
                )

                output = {
                    "generated_text": result[0]["generated_text"],
                    "model": self.name,
                    "success": True
                }

            # === TRANSLATION ===
            elif self.task == "translation":
                src_lang = kwargs.get("source_lang", "eng_Latn")
                tgt_lang = kwargs.get("target_lang", "spa_Latn")

                result = self.pipeline(
                    input_data,
                    src_lang=src_lang,
                    tgt_lang=tgt_lang,
                    max_length=512
                )

                output = {
                    "translation": result[0]["translation_text"],
                    "source_lang": src_lang,
                    "target_lang": tgt_lang,
                    "model": self.name,
                    "success": True
                }

            # === SENTIMENT ANALYSIS ===
            elif self.task == "sentiment-analysis":
                result = self.pipeline(input_data)

                output = {
                    "sentiment": result[0]["label"],
                    "confidence": result[0]["score"],
                    "model": self.name,
                    "success": True
                }

            # === BIAS DETECTION ===
            elif self.task == "bias-detection":
                candidate_labels = kwargs.get("candidate_labels", [
                    "neutral and objective",
                    "contains political bias",
                    "uses emotional language",
                    "culturally insensitive"
                ])

                result = self.pipeline(
                    input_data,
                    candidate_labels=candidate_labels,
                    multi_label=True
                )

                output = {
                    "labels": result["labels"],
                    "scores": result["scores"],
                    "model": self.name,
                    "success": True
                }

            # === TEXT CLASSIFICATION ===
            elif self.task == "text-classification":
                result = self.pipeline(input_data)

                output = {
                    "label": result[0]["label"],
                    "confidence": result[0]["score"],
                    "model": self.name,
                    "success": True
                }

            else:
                output = {
                    "error": f"Task '{self.task}' not implemented",
                    "model": self.name,
                    "success": False
                }

            # Track performance (ms) on the shared metadata record
            inference_time = (datetime.now() - start_time).total_seconds() * 1000
            self.metadata.inference_count += 1
            self.metadata.total_inference_time_ms += inference_time
            output["inference_time_ms"] = round(inference_time, 2)

            return output

        except Exception as e:
            logger.error(f"❌ Inference error in {self.name}: {e}", exc_info=True)
            return {
                "error": "Inference failed",
                "detail": str(e),
                "model": self.name,
                "success": False
            }

    def unload(self) -> None:
        """
        🗑️ Unloads the model to free memory.
        Critical for Azure environments with limited resources.
        """
        if self.pipeline is not None:
            logger.info(f"🗑️ Unloading {self.name}...")

            # Delete pipeline reference held by this client
            del self.pipeline
            self.pipeline = None

            # Remove from the shared process cache
            if self.name in _MODEL_CACHE:
                del _MODEL_CACHE[self.name]

            # Force GPU memory release
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            logger.info(f"✅ {self.name} unloaded successfully")

            # Log memory stats after unload
            mem_stats = get_memory_stats()
            if "gpu_allocated_gb" in mem_stats:
                logger.info(f"   GPU Memory: {mem_stats['gpu_allocated_gb']:.2f}GB remaining")

    def get_metadata(self) -> Dict[str, Any]:
        """
        📊 Returns model metadata and performance stats as a JSON-friendly dict.
        """
        return {
            "name": self.metadata.name,
            "task": self.metadata.task,
            "model_name": self.metadata.model_name,
            "device": self.metadata.device,
            "loaded": self.pipeline is not None,
            "loaded_at": self.metadata.loaded_at.isoformat() if self.metadata.loaded_at else None,
            "load_time_seconds": self.metadata.load_time_seconds,
            "memory_usage_gb": self.metadata.memory_usage_gb,
            "inference_count": self.metadata.inference_count,
            "avg_inference_time_ms": round(self.metadata.avg_inference_time_ms, 2)
        }
635
+
636
+
637
+ # ============================================================
638
+ # MODEL LOADER (Singleton Manager)
639
+ # ============================================================
640
+
641
class ModelLoader:
    """
    🎛️ Singleton manager for all Penny's specialized models.

    Features:
    - Centralized model configuration
    - Lazy loading (models only load when needed)
    - Memory management
    - Health monitoring
    - Unified access interface
    """

    # Shared singleton instance; set on first construction in __new__.
    _instance: Optional['ModelLoader'] = None

    def __new__(cls, *args, **kwargs):
        """Singleton pattern - only one ModelLoader instance."""
        if cls._instance is None:
            cls._instance = super(ModelLoader, cls).__new__(cls)
        return cls._instance

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize ModelLoader (only runs once due to singleton).

        Args:
            config_path: Path to model_config.json (optional; only honored
                on the very first construction — later calls are no-ops)
        """
        # The _models_loaded guard makes re-construction a no-op.
        # NOTE(review): if config loading fails below, the guard is already
        # set, so the failure is never retried for the process lifetime —
        # confirm this is the intended behavior.
        if not hasattr(self, '_models_loaded'):
            self.models: Dict[str, ModelClient] = {}
            self._models_loaded = True
            self._initialization_time = datetime.now()

            # Use provided path or default
            config_file = Path(config_path) if config_path else CONFIG_PATH

            try:
                logger.info(f"📖 Loading model configuration from {config_file}")

                if not config_file.exists():
                    logger.warning(f"⚠️ Configuration file not found: {config_file}")
                    logger.info("   Create model_config.json with your model definitions")
                    return

                with open(config_file, "r") as f:
                    config = json.load(f)

                # Initialize ModelClients (doesn't load models yet).
                # Expected schema per entry: {"model_name": ..., "task": ...,
                # "config": {...}} — "config" is optional.
                for model_id, model_info in config.items():
                    self.models[model_id] = ModelClient(
                        name=model_id,
                        model_name=model_info["model_name"],
                        task=model_info["task"],
                        config=model_info.get("config", {})
                    )

                logger.info(f"✅ ModelLoader initialized with {len(self.models)} models:")
                for model_id in self.models.keys():
                    logger.info(f"   - {model_id}")

            except json.JSONDecodeError as e:
                logger.error(f"❌ Invalid JSON in model_config.json: {e}")
            except Exception as e:
                logger.error(f"❌ Failed to initialize ModelLoader: {e}", exc_info=True)

    def get(self, model_id: str) -> Optional[ModelClient]:
        """
        🎯 Retrieves a configured ModelClient by ID.

        Args:
            model_id: Model identifier from config

        Returns:
            ModelClient instance or None if not found
        """
        return self.models.get(model_id)

    def list_models(self) -> List[str]:
        """📋 Returns list of all available model IDs."""
        return list(self.models.keys())

    def get_loaded_models(self) -> List[str]:
        """📋 Returns list of currently loaded model IDs (pipeline in memory)."""
        return [
            model_id
            for model_id, client in self.models.items()
            if client.pipeline is not None
        ]

    def unload_all(self) -> None:
        """
        🗑️ Unloads all models to free memory.
        Useful for Azure environments when switching workloads.
        """
        logger.info("🗑️ Unloading all models...")
        for model_client in self.models.values():
            model_client.unload()
        logger.info("✅ All models unloaded")

    def get_status(self) -> Dict[str, Any]:
        """
        📊 Returns comprehensive status of all models.
        Useful for health checks and monitoring.
        """
        status = {
            "initialization_time": self._initialization_time.isoformat(),
            "total_models": len(self.models),
            "loaded_models": len(self.get_loaded_models()),
            "device": get_optimal_device(),
            "memory": get_memory_stats(),
            "models": {}
        }

        # Per-model metadata keyed by model ID
        for model_id, client in self.models.items():
            status["models"][model_id] = client.get_metadata()

        return status
757
+
758
+
759
+ # ============================================================
760
+ # PUBLIC INTERFACE (Used by all *_utils.py modules)
761
+ # ============================================================
762
+
763
def load_model_pipeline(agent_name: str) -> Callable[..., Dict[str, Any]]:
    """
    🚀 Loads a model client and returns its inference function.

    This is the main entry point used by the *_utils.py modules
    (translation_utils.py, sentiment_utils.py, etc.) to reach Penny's models.

    Args:
        agent_name: Model ID from model_config.json

    Returns:
        Callable inference function

    Raises:
        ValueError: If agent_name not found in configuration

    Example:
        >>> translator = load_model_pipeline("penny-translate-agent")
        >>> result = translator("Hello world", target_lang="spa_Latn")
    """
    registry = ModelLoader()
    model_client = registry.get(agent_name)

    if model_client is None:
        raise ValueError(
            f"Agent ID '{agent_name}' not found in model configuration. "
            f"Available models: {registry.list_models()}"
        )

    # Kick off the (lazy) pipeline load now so first inference is warm
    model_client.load_pipeline()

    # Hand back a thin closure over the client's predict()
    def _run_inference(input_data, **kwargs):
        return model_client.predict(input_data, **kwargs)

    return _run_inference
801
+
802
+
803
+ # === CONVENIENCE FUNCTIONS ===
804
+
805
def get_model_status() -> Dict[str, Any]:
    """
    📊 Snapshot of every configured model's state.
    Intended for health-check and monitoring endpoints.
    """
    return ModelLoader().get_status()
812
+
813
+
814
def preload_models(model_ids: Optional[List[str]] = None) -> None:
    """
    🚀 Eagerly loads models during startup.

    Args:
        model_ids: Model IDs to preload; None preloads every configured model.
    """
    loader = ModelLoader()
    targets = loader.list_models() if model_ids is None else model_ids

    logger.info(f"🚀 Preloading {len(targets)} models...")

    for model_id in targets:
        client = loader.get(model_id)
        # Unknown IDs are skipped silently (get() returns None)
        if client:
            logger.info(f"   Loading {model_id}...")
            client.load_pipeline()

    logger.info("✅ Model preloading complete")
835
+
836
+
837
def initialize_model_system() -> bool:
    """
    🏁 Initializes the model system; call during app startup.

    Returns:
        True when the singleton loader came up cleanly, False otherwise.
    """
    logger.info("🧠 Initializing Penny's model system...")

    try:
        # Construct (or reuse) the singleton loader
        loader = ModelLoader()

        # Report compute environment
        device = get_optimal_device()
        mem_stats = get_memory_stats()

        logger.info("✅ Model system initialized")
        logger.info(f"🎮 Compute device: {device}")

        if "gpu_total_gb" in mem_stats:
            logger.info(f"💾 GPU Memory: {mem_stats['gpu_total_gb']:.1f}GB total")

        logger.info(f"📦 {len(loader.models)} models configured")

        # Critical models can be warmed here at startup, e.g.:
        # preload_models(["penny-core-agent"])

        return True

    except Exception as e:
        logger.error(f"❌ Failed to initialize model system: {e}", exc_info=True)
        return False
874
+
875
+
876
+ # ============================================================
877
+ # CLI TESTING & DEBUGGING
878
+ # ============================================================
879
+
880
if __name__ == "__main__":
    """
    🧪 Test script for model loading and inference.
    Run with: python -m app.model_loader
    """
    print("=" * 60)
    print("🧪 Testing Penny's Model System")
    print("=" * 60)

    # Initialize the singleton and list what the config declared
    loader = ModelLoader()
    print(f"\n📋 Available models: {loader.list_models()}")

    # Dump full system status (default=str handles datetime fields)
    status = get_model_status()
    print(f"\n📊 System status:")
    print(json.dumps(status, indent=2, default=str))

    # Smoke-test loading of the first configured model (if any)
    if loader.models:
        test_model_id = list(loader.models.keys())[0]
        print(f"\n🧪 Testing model: {test_model_id}")

        client = loader.get(test_model_id)
        if client:
            print(f"   Loading pipeline...")
            success = client.load_pipeline()

            if success:
                print(f"   ✅ Model loaded successfully!")
                print(f"   Metadata: {json.dumps(client.get_metadata(), indent=2, default=str)}")
            else:
                print(f"   ❌ Model loading failed")