Malaji71 commited on
Commit
8d6efc2
·
verified ·
1 Parent(s): 6715d2b

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +242 -207
models.py CHANGED
@@ -1,17 +1,22 @@
1
  """
2
- Model management for FLUX Prompt Optimizer
3
- Handles Florence-2 and Bagel model integration
4
  """
5
 
6
  import logging
7
- import requests
 
8
  import spaces
9
  import torch
10
  from typing import Optional, Dict, Any, Tuple
11
  from PIL import Image
12
- from transformers import AutoProcessor, AutoModelForCausalLM
 
13
 
14
- from config import MODEL_CONFIG, get_device_config
 
 
 
15
  from utils import clean_memory, safe_execute
16
 
17
  logger = logging.getLogger(__name__)
@@ -22,9 +27,8 @@ class BaseImageAnalyzer:
22
 
23
  def __init__(self):
24
  self.model = None
25
- self.processor = None
26
- self.device_config = get_device_config()
27
  self.is_initialized = False
 
28
 
29
  def initialize(self) -> bool:
30
  """Initialize the model"""
@@ -36,265 +40,284 @@ class BaseImageAnalyzer:
36
 
37
  def cleanup(self) -> None:
38
  """Clean up model resources"""
39
- if self.model is not None:
40
  del self.model
41
  self.model = None
42
- if self.processor is not None:
43
- del self.processor
44
- self.processor = None
45
  clean_memory()
46
 
47
 
48
- class Florence2Analyzer(BaseImageAnalyzer):
49
- """Florence-2 model for image analysis"""
50
 
51
  def __init__(self):
52
  super().__init__()
53
- self.config = MODEL_CONFIG["florence2"]
 
 
 
 
 
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  def initialize(self) -> bool:
56
- """Initialize Florence-2 model"""
57
  if self.is_initialized:
58
  return True
59
-
60
  try:
61
- logger.info("Initializing Florence-2 model...")
 
 
 
62
 
63
- model_id = self.config["model_id"]
64
 
65
- # Load processor
66
- self.processor = AutoProcessor.from_pretrained(
67
- model_id,
68
- trust_remote_code=self.config["trust_remote_code"]
 
 
 
 
 
69
  )
 
 
 
70
 
71
- # Load model
72
- self.model = AutoModelForCausalLM.from_pretrained(
73
- model_id,
74
- trust_remote_code=self.config["trust_remote_code"],
75
- torch_dtype=self.config["torch_dtype"] if self.device_config["use_gpu"] else torch.float32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  )
77
 
78
- # Move to appropriate device
79
- if self.device_config["use_gpu"]:
80
- self.model = self.model.to(self.device_config["device"])
81
- else:
82
- self.model = self.model.to("cpu")
83
-
84
- self.model.eval()
85
- self.is_initialized = True
86
 
87
- logger.info(f"Florence-2 initialized on {self.device_config['device']}")
88
- return True
 
89
 
90
- except Exception as e:
91
- logger.error(f"Florence-2 initialization failed: {e}")
92
- self.cleanup()
93
- return False
94
-
95
- @spaces.GPU(duration=60)
96
- def _gpu_inference(self, image: Image.Image, task_prompt: str) -> str:
97
- """Run inference on GPU with spaces decorator"""
98
- try:
99
- # Move model to GPU for inference
100
- if self.device_config["use_gpu"]:
101
- self.model = self.model.to("cuda")
102
 
103
- # Prepare inputs
104
- inputs = self.processor(text=task_prompt, images=image, return_tensors="pt")
 
 
 
 
105
 
106
- # Move inputs to device
107
- device = "cuda" if self.device_config["use_gpu"] else self.device_config["device"]
108
- inputs = {k: v.to(device) for k, v in inputs.items()}
109
 
110
- # Generate response
111
- with torch.no_grad():
112
- if self.device_config["use_gpu"]:
113
- with torch.cuda.amp.autocast(dtype=torch.float16):
114
- generated_ids = self.model.generate(
115
- input_ids=inputs["input_ids"],
116
- pixel_values=inputs["pixel_values"],
117
- max_new_tokens=self.config["max_new_tokens"],
118
- num_beams=3,
119
- do_sample=False
120
- )
121
- else:
122
- generated_ids = self.model.generate(
123
- input_ids=inputs["input_ids"],
124
- pixel_values=inputs["pixel_values"],
125
- max_new_tokens=self.config["max_new_tokens"],
126
- num_beams=3,
127
- do_sample=False
128
- )
129
 
130
- # Decode response
131
- generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
132
- parsed = self.processor.post_process_generation(
133
- generated_text,
134
- task=task_prompt,
135
- image_size=(image.width, image.height)
 
 
136
  )
137
 
138
- # Extract caption
139
- if task_prompt in parsed:
140
- return parsed[task_prompt]
141
- else:
142
- return str(parsed) if parsed else ""
143
-
144
  except Exception as e:
145
- logger.error(f"Florence-2 GPU inference failed: {e}")
146
- return ""
147
- finally:
148
- # Move model back to CPU to free GPU memory
149
- if self.device_config["use_gpu"]:
150
- self.model = self.model.to("cpu")
151
- clean_memory()
152
-
153
- def analyze_image(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
154
- """Analyze image using Florence-2"""
155
  if not self.is_initialized:
156
  success = self.initialize()
157
  if not success:
158
- return "Model initialization failed", {"error": "Florence-2 not available"}
159
 
160
  try:
161
- # Define analysis tasks
162
- tasks = {
163
- "detailed": "<DETAILED_CAPTION>",
164
- "more_detailed": "<MORE_DETAILED_CAPTION>",
165
- "caption": "<CAPTION>"
166
- }
167
 
168
- results = {}
 
 
169
 
170
- # Run analysis for each task
171
- for task_name, task_prompt in tasks.items():
172
- if self.device_config["use_gpu"]:
173
- result = self._gpu_inference(image, task_prompt)
174
- else:
175
- result = self._cpu_inference(image, task_prompt)
176
- results[task_name] = result
177
 
178
- # Choose best result
179
- if results["more_detailed"]:
180
- main_description = results["more_detailed"]
181
- elif results["detailed"]:
182
- main_description = results["detailed"]
183
- else:
184
- main_description = results["caption"] or "A photograph"
 
 
185
 
186
  # Prepare metadata
187
  metadata = {
188
- "model": "Florence-2",
189
  "device": self.device_config["device"],
190
- "all_results": results,
191
- "confidence": 0.85 # Florence-2 generally reliable
 
 
192
  }
193
 
194
- logger.info(f"Florence-2 analysis complete: {len(main_description)} chars")
195
- return main_description, metadata
196
 
197
  except Exception as e:
198
- logger.error(f"Florence-2 analysis failed: {e}")
199
- return "Analysis failed", {"error": str(e)}
200
 
201
- def _cpu_inference(self, image: Image.Image, task_prompt: str) -> str:
202
- """Run inference on CPU"""
203
  try:
204
- inputs = self.processor(text=task_prompt, images=image, return_tensors="pt")
205
-
206
- with torch.no_grad():
207
- generated_ids = self.model.generate(
208
- input_ids=inputs["input_ids"],
209
- pixel_values=inputs["pixel_values"],
210
- max_new_tokens=self.config["max_new_tokens"],
211
- num_beams=2, # Reduced for CPU
212
- do_sample=False
213
- )
214
-
215
- generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
216
- parsed = self.processor.post_process_generation(
217
- generated_text,
218
- task=task_prompt,
219
- image_size=(image.width, image.height)
220
- )
221
 
222
- if task_prompt in parsed:
223
- return parsed[task_prompt]
224
- else:
225
- return str(parsed) if parsed else ""
226
 
 
 
227
  except Exception as e:
228
- logger.error(f"Florence-2 CPU inference failed: {e}")
229
- return ""
230
 
231
 
232
- class BagelAnalyzer(BaseImageAnalyzer):
233
- """Bagel-7B model analyzer via API"""
234
 
235
  def __init__(self):
236
  super().__init__()
237
- self.config = MODEL_CONFIG["bagel"]
238
- self.session = requests.Session()
239
 
240
  def initialize(self) -> bool:
241
- """Initialize Bagel analyzer (API-based)"""
242
- try:
243
- # Test API connectivity
244
- test_response = self.session.get(
245
- self.config["api_url"],
246
- timeout=self.config["timeout"]
247
- )
248
-
249
- if test_response.status_code == 200:
250
- self.is_initialized = True
251
- logger.info("Bagel API connection established")
252
- return True
253
- else:
254
- logger.error(f"Bagel API not accessible: {test_response.status_code}")
255
- return False
256
-
257
- except Exception as e:
258
- logger.error(f"Bagel initialization failed: {e}")
259
- return False
260
 
261
  def analyze_image(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
262
- """Analyze image using Bagel-7B API"""
263
- if not self.is_initialized:
264
- success = self.initialize()
265
- if not success:
266
- return "Bagel API not available", {"error": "API connection failed"}
267
-
268
  try:
269
- # Convert image to base64 or prepare for API call
270
- # Note: This is a placeholder - actual implementation would depend on Bagel API format
 
 
 
 
 
 
 
 
 
 
 
271
 
272
- # For now, return a placeholder response
273
- # In real implementation, you would:
274
- # 1. Convert image to required format
275
- # 2. Make API call to Bagel endpoint
276
- # 3. Parse response
277
 
278
- description = "Detailed image analysis via Bagel-7B (API implementation needed)"
279
  metadata = {
280
- "model": "Bagel-7B",
281
- "method": "API",
282
- "confidence": 0.8
 
 
 
283
  }
284
 
285
- logger.info("Bagel analysis complete (placeholder)")
286
  return description, metadata
287
 
288
  except Exception as e:
289
- logger.error(f"Bagel analysis failed: {e}")
290
- return "Analysis failed", {"error": str(e)}
291
 
292
 
293
  class ModelManager:
294
- """Manager for handling multiple analysis models"""
295
 
296
- def __init__(self, preferred_model: str = None):
297
- self.preferred_model = preferred_model or MODEL_CONFIG["primary_model"]
298
  self.analyzers = {}
299
  self.current_analyzer = None
300
 
@@ -303,27 +326,38 @@ class ModelManager:
303
  model_name = model_name or self.preferred_model
304
 
305
  if model_name not in self.analyzers:
306
- if model_name == "florence2":
307
- self.analyzers[model_name] = Florence2Analyzer()
308
- elif model_name == "bagel":
309
  self.analyzers[model_name] = BagelAnalyzer()
 
 
310
  else:
311
- logger.error(f"Unknown model: {model_name}")
312
- return None
 
313
 
314
  return self.analyzers[model_name]
315
 
316
  def analyze_image(self, image: Image.Image, model_name: str = None) -> Tuple[str, Dict[str, Any]]:
317
  """Analyze image with specified or preferred model"""
 
318
  analyzer = self.get_analyzer(model_name)
319
  if analyzer is None:
320
  return "No analyzer available", {"error": "Model not found"}
321
 
322
  success, result = safe_execute(analyzer.analyze_image, image)
323
- if success:
 
324
  return result
325
  else:
326
- return "Analysis failed", {"error": result}
 
 
 
 
 
 
 
 
327
 
328
  def cleanup_all(self) -> None:
329
  """Clean up all model resources"""
@@ -331,19 +365,20 @@ class ModelManager:
331
  analyzer.cleanup()
332
  self.analyzers.clear()
333
  clean_memory()
 
334
 
335
 
336
  # Global model manager instance
337
- model_manager = ModelManager()
338
 
339
 
340
  def analyze_image(image: Image.Image, model_name: str = None) -> Tuple[str, Dict[str, Any]]:
341
  """
342
- Convenience function for image analysis
343
 
344
  Args:
345
  image: PIL Image to analyze
346
- model_name: Optional model name ("florence2" or "bagel")
347
 
348
  Returns:
349
  Tuple of (description, metadata)
@@ -354,8 +389,8 @@ def analyze_image(image: Image.Image, model_name: str = None) -> Tuple[str, Dict
354
  # Export main components
355
  __all__ = [
356
  "BaseImageAnalyzer",
357
- "Florence2Analyzer",
358
- "BagelAnalyzer",
359
  "ModelManager",
360
  "model_manager",
361
  "analyze_image"
 
1
  """
2
+ Model management for Frame 0 Laboratory for MIA
3
+ BAGEL 7B integration for advanced image analysis
4
  """
5
 
6
  import logging
7
+ import os
8
+ import subprocess
9
  import spaces
10
  import torch
11
  from typing import Optional, Dict, Any, Tuple
12
  from PIL import Image
13
+ from huggingface_hub import snapshot_download
14
+ from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
15
 
16
+ from config import (
17
+ BAGEL_CONFIG, get_device_config, get_bagel_device_map,
18
+ BAGEL_PROMPTS, FLASH_ATTN_INSTALL
19
+ )
20
  from utils import clean_memory, safe_execute
21
 
22
  logger = logging.getLogger(__name__)
 
27
 
28
  def __init__(self):
29
  self.model = None
 
 
30
  self.is_initialized = False
31
+ self.device_config = get_device_config()
32
 
33
  def initialize(self) -> bool:
34
  """Initialize the model"""
 
40
 
41
  def cleanup(self) -> None:
42
  """Clean up model resources"""
43
+ if hasattr(self, 'model') and self.model is not None:
44
  del self.model
45
  self.model = None
 
 
 
46
  clean_memory()
47
 
48
 
49
+ class BagelAnalyzer(BaseImageAnalyzer):
50
+ """BAGEL 7B model for advanced image analysis"""
51
 
52
  def __init__(self):
53
  super().__init__()
54
+ self.inferencer = None
55
+ self.tokenizer = None
56
+ self.vae_model = None
57
+ self.vae_transform = None
58
+ self.vit_transform = None
59
+ self._install_flash_attn()
60
 
61
+ def _install_flash_attn(self):
62
+ """Install flash attention dynamically"""
63
+ try:
64
+ logger.info("Installing flash attention...")
65
+ result = subprocess.run(
66
+ FLASH_ATTN_INSTALL["command"],
67
+ env=FLASH_ATTN_INSTALL["env"],
68
+ shell=FLASH_ATTN_INSTALL["shell"],
69
+ capture_output=True,
70
+ text=True
71
+ )
72
+ if result.returncode == 0:
73
+ logger.info("Flash attention installed successfully")
74
+ else:
75
+ logger.warning(f"Flash attention installation warning: {result.stderr}")
76
+ except Exception as e:
77
+ logger.warning(f"Flash attention installation failed: {e}")
78
+
79
+ def _download_model(self) -> bool:
80
+ """Download BAGEL model if not present"""
81
+ try:
82
+ logger.info("Downloading BAGEL model...")
83
+ snapshot_download(
84
+ cache_dir=BAGEL_CONFIG["cache_dir"],
85
+ local_dir=BAGEL_CONFIG["local_model_path"],
86
+ repo_id=BAGEL_CONFIG["model_repo"],
87
+ local_dir_use_symlinks=False,
88
+ resume_download=True,
89
+ allow_patterns=BAGEL_CONFIG["download_patterns"],
90
+ )
91
+ logger.info("BAGEL model downloaded successfully")
92
+ return True
93
+ except Exception as e:
94
+ logger.error(f"BAGEL model download failed: {e}")
95
+ return False
96
+
97
  def initialize(self) -> bool:
98
+ """Initialize BAGEL model"""
99
  if self.is_initialized:
100
  return True
101
+
102
  try:
103
+ # Download model if needed
104
+ if not os.path.exists(BAGEL_CONFIG["local_model_path"]):
105
+ if not self._download_model():
106
+ return False
107
 
108
+ logger.info("Initializing BAGEL model...")
109
 
110
+ # Import BAGEL components after flash attention installation
111
+ from data.data_utils import add_special_tokens, pil_img2rgb
112
+ from data.transforms import ImageTransform
113
+ from inferencer import InterleaveInferencer
114
+ from modeling.autoencoder import load_ae
115
+ from modeling.bagel.qwen2_navit import NaiveCache
116
+ from modeling.bagel import (
117
+ BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM,
118
+ SiglipVisionConfig, SiglipVisionModel
119
  )
120
+ from modeling.qwen2 import Qwen2Tokenizer
121
+
122
+ model_path = BAGEL_CONFIG["local_model_path"]
123
 
124
+ # Load configurations
125
+ llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
126
+ llm_config.qk_norm = True
127
+ llm_config.tie_word_embeddings = False
128
+ llm_config.layer_module = "Qwen2MoTDecoderLayer"
129
+
130
+ vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
131
+ vit_config.rope = False
132
+ vit_config.num_hidden_layers -= 1
133
+
134
+ # Load VAE
135
+ self.vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))
136
+
137
+ # Create BAGEL config
138
+ config = BagelConfig(
139
+ visual_gen=True,
140
+ visual_und=True,
141
+ llm_config=llm_config,
142
+ vit_config=vit_config,
143
+ vae_config=vae_config,
144
+ vit_max_num_patch_per_side=70,
145
+ connector_act='gelu_pytorch_tanh',
146
+ latent_patch_size=2,
147
+ max_latent_size=64,
148
  )
149
 
150
+ # Initialize model with empty weights
151
+ with init_empty_weights():
152
+ language_model = Qwen2ForCausalLM(llm_config)
153
+ vit_model = SiglipVisionModel(vit_config)
154
+ self.model = Bagel(language_model, vit_model, config)
155
+ self.model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)
 
 
156
 
157
+ # Load tokenizer
158
+ self.tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
159
+ self.tokenizer, new_token_ids, _ = add_special_tokens(self.tokenizer)
160
 
161
+ # Setup transforms
162
+ vae_size = BAGEL_CONFIG["vae_transform_size"]
163
+ vit_size = BAGEL_CONFIG["vit_transform_size"]
164
+ self.vae_transform = ImageTransform(vae_size[0], vae_size[1], vae_size[2])
165
+ self.vit_transform = ImageTransform(vit_size[0], vit_size[1], vit_size[2])
 
 
 
 
 
 
 
166
 
167
+ # Setup device mapping
168
+ device_map = infer_auto_device_map(
169
+ self.model,
170
+ max_memory={i: BAGEL_CONFIG["max_memory_per_gpu"] for i in range(torch.cuda.device_count())},
171
+ no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
172
+ )
173
 
174
+ # Apply custom device mapping for critical modules
175
+ custom_mapping = get_bagel_device_map(self.device_config["gpu_count"])
176
+ device_map.update(custom_mapping)
177
 
178
+ # Load model with checkpoints
179
+ self.model = load_checkpoint_and_dispatch(
180
+ self.model,
181
+ checkpoint=os.path.join(model_path, "ema.safetensors"),
182
+ device_map=device_map,
183
+ offload_buffers=BAGEL_CONFIG["offload_buffers"],
184
+ dtype=BAGEL_CONFIG["dtype"],
185
+ force_hooks=BAGEL_CONFIG["force_hooks"],
186
+ ).eval()
 
 
 
 
 
 
 
 
 
 
187
 
188
+ # Initialize inferencer
189
+ self.inferencer = InterleaveInferencer(
190
+ model=self.model,
191
+ vae_model=self.vae_model,
192
+ tokenizer=self.tokenizer,
193
+ vae_transform=self.vae_transform,
194
+ vit_transform=self.vit_transform,
195
+ new_token_ids=new_token_ids,
196
  )
197
 
198
+ self.is_initialized = True
199
+ logger.info("BAGEL model initialized successfully")
200
+ return True
201
+
 
 
202
  except Exception as e:
203
+ logger.error(f"BAGEL initialization failed: {e}")
204
+ self.cleanup()
205
+ return False
206
+
207
+ @spaces.GPU(duration=120)
208
+ def analyze_image(self, image: Image.Image, prompt_type: str = "detailed_description") -> Tuple[str, Dict[str, Any]]:
209
+ """Analyze image using BAGEL model"""
 
 
 
210
  if not self.is_initialized:
211
  success = self.initialize()
212
  if not success:
213
+ return "BAGEL model not available", {"error": "Initialization failed"}
214
 
215
  try:
216
+ # Get appropriate prompt
217
+ system_prompt = BAGEL_PROMPTS.get(prompt_type, BAGEL_PROMPTS["detailed_description"])
 
 
 
 
218
 
219
+ # Prepare image for BAGEL
220
+ if image.mode != 'RGB':
221
+ image = image.convert('RGB')
222
 
223
+ # Run inference through BAGEL
224
+ logger.info("Running BAGEL inference...")
 
 
 
 
 
225
 
226
+ # Use inferencer to analyze the image
227
+ response = self.inferencer.inference_image_understanding(
228
+ image=image,
229
+ prompt=system_prompt,
230
+ max_new_tokens=BAGEL_CONFIG["max_new_tokens"],
231
+ temperature=BAGEL_CONFIG["temperature"],
232
+ top_p=BAGEL_CONFIG["top_p"],
233
+ do_sample=BAGEL_CONFIG["do_sample"]
234
+ )
235
 
236
  # Prepare metadata
237
  metadata = {
238
+ "model": "BAGEL-7B",
239
  "device": self.device_config["device"],
240
+ "confidence": 0.9, # BAGEL is highly reliable
241
+ "prompt_type": prompt_type,
242
+ "gpu_count": self.device_config.get("gpu_count", 1),
243
+ "processing_mode": "GPU" if self.device_config["use_gpu"] else "CPU"
244
  }
245
 
246
+ logger.info(f"BAGEL analysis complete: {len(response)} characters")
247
+ return response, metadata
248
 
249
  except Exception as e:
250
+ logger.error(f"BAGEL analysis failed: {e}")
251
+ return "Analysis failed", {"error": str(e), "model": "BAGEL-7B"}
252
 
253
+ def cleanup(self) -> None:
254
+ """Clean up BAGEL resources"""
255
  try:
256
+ if hasattr(self, 'inferencer') and self.inferencer is not None:
257
+ del self.inferencer
258
+ self.inferencer = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
 
260
+ if hasattr(self, 'vae_model') and self.vae_model is not None:
261
+ del self.vae_model
262
+ self.vae_model = None
 
263
 
264
+ super().cleanup()
265
+ logger.info("BAGEL resources cleaned up")
266
  except Exception as e:
267
+ logger.warning(f"BAGEL cleanup warning: {e}")
 
268
 
269
 
270
+ class FallbackAnalyzer(BaseImageAnalyzer):
271
+ """Simple fallback analyzer when BAGEL is not available"""
272
 
273
  def __init__(self):
274
  super().__init__()
 
 
275
 
276
  def initialize(self) -> bool:
277
+ """Fallback is always ready"""
278
+ self.is_initialized = True
279
+ return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
 
281
  def analyze_image(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
282
+ """Provide basic image description"""
 
 
 
 
 
283
  try:
284
+ # Basic image analysis
285
+ width, height = image.size
286
+ mode = image.mode
287
+
288
+ # Simple descriptive text based on image properties
289
+ aspect_ratio = width / height
290
+
291
+ if aspect_ratio > 1.5:
292
+ orientation = "landscape"
293
+ elif aspect_ratio < 0.75:
294
+ orientation = "portrait"
295
+ else:
296
+ orientation = "square"
297
 
298
+ description = f"A {orientation} photograph with {mode} color mode, {width}x{height} pixels. Professional image suitable for detailed analysis and prompt generation."
 
 
 
 
299
 
 
300
  metadata = {
301
+ "model": "Fallback",
302
+ "device": "cpu",
303
+ "confidence": 0.5,
304
+ "image_size": f"{width}x{height}",
305
+ "color_mode": mode,
306
+ "orientation": orientation
307
  }
308
 
 
309
  return description, metadata
310
 
311
  except Exception as e:
312
+ logger.error(f"Fallback analysis failed: {e}")
313
+ return "Basic image detected", {"error": str(e), "model": "Fallback"}
314
 
315
 
316
  class ModelManager:
317
+ """Manager for handling image analysis models"""
318
 
319
+ def __init__(self, preferred_model: str = "bagel"):
320
+ self.preferred_model = preferred_model
321
  self.analyzers = {}
322
  self.current_analyzer = None
323
 
 
326
  model_name = model_name or self.preferred_model
327
 
328
  if model_name not in self.analyzers:
329
+ if model_name == "bagel":
 
 
330
  self.analyzers[model_name] = BagelAnalyzer()
331
+ elif model_name == "fallback":
332
+ self.analyzers[model_name] = FallbackAnalyzer()
333
  else:
334
+ logger.warning(f"Unknown model: {model_name}, using fallback")
335
+ model_name = "fallback"
336
+ self.analyzers[model_name] = FallbackAnalyzer()
337
 
338
  return self.analyzers[model_name]
339
 
340
  def analyze_image(self, image: Image.Image, model_name: str = None) -> Tuple[str, Dict[str, Any]]:
341
  """Analyze image with specified or preferred model"""
342
+ # Try preferred model first
343
  analyzer = self.get_analyzer(model_name)
344
  if analyzer is None:
345
  return "No analyzer available", {"error": "Model not found"}
346
 
347
  success, result = safe_execute(analyzer.analyze_image, image)
348
+
349
+ if success and result[1].get("error") is None:
350
  return result
351
  else:
352
+ # Fallback to simple analyzer if main model fails
353
+ logger.warning(f"Primary model failed, using fallback: {result}")
354
+ fallback_analyzer = self.get_analyzer("fallback")
355
+ fallback_success, fallback_result = safe_execute(fallback_analyzer.analyze_image, image)
356
+
357
+ if fallback_success:
358
+ return fallback_result
359
+ else:
360
+ return "All analyzers failed", {"error": "Complete analysis failure"}
361
 
362
  def cleanup_all(self) -> None:
363
  """Clean up all model resources"""
 
365
  analyzer.cleanup()
366
  self.analyzers.clear()
367
  clean_memory()
368
+ logger.info("All analyzers cleaned up")
369
 
370
 
371
  # Global model manager instance
372
+ model_manager = ModelManager(preferred_model="bagel")
373
 
374
 
375
  def analyze_image(image: Image.Image, model_name: str = None) -> Tuple[str, Dict[str, Any]]:
376
  """
377
+ Convenience function for image analysis using BAGEL
378
 
379
  Args:
380
  image: PIL Image to analyze
381
+ model_name: Optional model name ("bagel" or "fallback")
382
 
383
  Returns:
384
  Tuple of (description, metadata)
 
389
  # Export main components
390
  __all__ = [
391
  "BaseImageAnalyzer",
392
+ "BagelAnalyzer",
393
+ "FallbackAnalyzer",
394
  "ModelManager",
395
  "model_manager",
396
  "analyze_image"