"""
CaptionIQ — Backend Caption Orchestration Engine.

Transparent orchestration across VGG16, VGG19 and an optional BLIP model.
"""
import os
import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input as vgg16_preprocess
from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input as vgg19_preprocess
from tensorflow.keras.models import Model as KerasModel

# Keras compatibility patch for legacy h5 loading.
# Strips a ``quantization_config`` key from op/layer configs before delegating
# to the original ``Operation.from_config``. Presumably the saved h5 files carry
# this key and the installed Keras rejects it — confirm against the Keras
# version used to save the models.
import keras.src.ops.operation as _keras_op
from keras.src.ops.numpy import NotEqual as _NotEqual

# Raw function behind the original classmethod, so it can be re-invoked with
# whichever subclass ``cls`` the patched classmethod receives.
_orig_from_config = _keras_op.Operation.from_config.__func__


def _patched_from_config(cls, config):  # noqa: N805
    """Drop ``quantization_config`` from ``config`` and defer to the original ``from_config``."""
    config = dict(config)  # work on a copy so the caller's config dict is not mutated
    config.pop("quantization_config", None)
    return _orig_from_config(cls, config)


# Must be re-wrapped as a classmethod: assigning the bare function would make
# ``SomeOp.from_config(config)`` bind ``cls=config`` and fail with a TypeError
# for the missing ``config`` argument.
_keras_op.Operation.from_config = classmethod(_patched_from_config)  # type: ignore[misc]
from src.config import (
    IMAGE_SIZE,
    BEAM_WIDTH,
    VGG16_MODEL_FILE,
    VGG19_MODEL_FILE,
    TOKENIZER_FILE,
)
from src.model import BahdanauAttention
from src.utils import load_tokenizer
from src.inference import (
    beam_search,
    beam_search_ensemble,
    beam_search_with_attention,
    beam_search_ensemble_with_attention,
)

# Custom ops/layers that must be resolvable when the trained caption models
# are deserialised (passed as ``custom_objects`` to ``load_model``).
_CUSTOM_OBJECTS = dict(
    NotEqual=_NotEqual,
    BahdanauAttention=BahdanauAttention,
)
class CaptionEngine:
    """Backend orchestration service for caption generation.

    Trained VGG caption decoders, their ImageNet feature extractors and the
    optional BLIP captioner are loaded lazily on first use and cached on the
    instance.
    """

    # Backbones that have their own trained decoder and feature extractor.
    _SUPPORTED_BACKBONES = ("vgg16", "vgg19")
    # Hugging Face model id of the optional BLIP captioner.
    _BLIP_MODEL_ID = "Salesforce/blip-image-captioning-base"
    # BLIP captions are not scored by this engine, so fixed heuristic
    # confidences are reported: one for BLIP-only requests, a lower one when
    # BLIP overrides a weak VGG caption in hybrid mode.
    _BLIP_ONLY_CONFIDENCE = 0.82
    _BLIP_OVERRIDE_CONFIDENCE = 0.68
    # Final words that make a caption look truncated.
    _WEAK_ENDINGS = frozenset({
        "a", "an", "the", "in", "on", "at", "of", "to", "for",
        "with", "by", "from", "and", "or", "but", "as",
    })
    # Quality thresholds used by ``_is_low_quality``.
    _MIN_CAPTION_WORDS = 5
    _MIN_CONFIDENCE = 0.35
    # Class-level default; set to True on an instance once BLIP fails to
    # import/load, so the (potentially slow) load is not retried per request.
    _blip_unavailable = False

    def __init__(self):
        self.tokenizer = load_tokenizer(TOKENIZER_FILE) if os.path.exists(TOKENIZER_FILE) else None
        self._vgg_models = {}   # backbone name -> loaded caption decoder
        self._extractors = {}   # backbone name -> (feature extractor, preprocess fn)
        self._blip_bundle = None

    # ------------------------------------------------------------------ #
    # Readiness / loading
    # ------------------------------------------------------------------ #
    def is_ready(self, backbone_mode: str):
        """Check whether ``backbone_mode`` can be served.

        Returns:
            ``(True, None)`` when the tokenizer and the required model file(s)
            exist, otherwise ``(False, reason)`` with a human-readable reason.
        """
        if self.tokenizer is None:
            return False, "Tokenizer not found. Train/preprocess first."
        if backbone_mode == "ensemble":
            if not os.path.exists(VGG16_MODEL_FILE) or not os.path.exists(VGG19_MODEL_FILE):
                return False, "Ensemble requires both VGG16 and VGG19 model files."
        elif backbone_mode == "vgg16":
            if not os.path.exists(VGG16_MODEL_FILE):
                return False, "VGG16 model file not found."
        elif backbone_mode == "vgg19":
            if not os.path.exists(VGG19_MODEL_FILE):
                return False, "VGG19 model file not found."
        else:
            return False, f"Unknown backbone mode: {backbone_mode}"
        return True, None

    @classmethod
    def _check_backbone(cls, backbone: str):
        """Raise ``ValueError`` unless ``backbone`` is ``"vgg16"`` or ``"vgg19"``."""
        if backbone not in cls._SUPPORTED_BACKBONES:
            raise ValueError(
                f"Unknown backbone: {backbone!r} (expected one of {cls._SUPPORTED_BACKBONES})."
            )

    def _load_vgg_model(self, backbone: str):
        """Return the cached trained decoder for ``backbone``, loading it on first use.

        Returns ``None`` when the model file does not exist.

        Raises:
            ValueError: ``backbone`` is not a supported backbone (previously an
                unknown name silently loaded the VGG19 model).
        """
        self._check_backbone(backbone)
        if backbone in self._vgg_models:
            return self._vgg_models[backbone]
        model_file = VGG16_MODEL_FILE if backbone == "vgg16" else VGG19_MODEL_FILE
        if not os.path.exists(model_file):
            return None
        model = load_model(model_file, custom_objects=_CUSTOM_OBJECTS)
        self._vgg_models[backbone] = model
        return model

    def _require_vgg_model(self, backbone: str):
        """Like ``_load_vgg_model`` but raise ``FileNotFoundError`` instead of returning ``None``."""
        model = self._load_vgg_model(backbone)
        if model is None:
            raise FileNotFoundError(f"Trained {backbone.upper()} caption model file not found.")
        return model

    def _load_feature_extractor(self, backbone: str):
        """Return ``(extractor, preprocess_fn)`` for ``backbone``, building it on first use.

        The extractor is the ImageNet-weighted VGG network truncated at its
        ``block5_pool`` layer.
        """
        if backbone in self._extractors:
            return self._extractors[backbone]
        self._check_backbone(backbone)
        if backbone == "vgg16":
            base_model = VGG16(weights="imagenet")
            preprocess_fn = vgg16_preprocess
        else:
            base_model = VGG19(weights="imagenet")
            preprocess_fn = vgg19_preprocess
        extractor = KerasModel(
            inputs=base_model.input,
            outputs=base_model.get_layer("block5_pool").output,
        )
        self._extractors[backbone] = (extractor, preprocess_fn)
        return extractor, preprocess_fn

    def _extract_feature(self, image, backbone: str):
        """Encode ``image`` as a flattened spatial grid of ``block5_pool`` features."""
        extractor, preprocess_fn = self._load_feature_extractor(backbone)
        # Convert before resizing: Pillow always uses NEAREST resampling for
        # "P"/"1" mode images, so resizing first degrades palette inputs.
        img = image.convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
        img_array = np.expand_dims(img_to_array(img), axis=0)
        img_array = preprocess_fn(img_array)
        feature = extractor.predict(img_array, verbose=0)[0]  # (7, 7, 512)
        h, w, c = feature.shape
        return feature.reshape(h * w, c)  # (49, 512)

    # ------------------------------------------------------------------ #
    # Quality heuristics / BLIP
    # ------------------------------------------------------------------ #
    def _is_low_quality(self, caption: str, confidence: float):
        """Heuristically decide whether a VGG caption should be replaced.

        A caption is low quality when it has fewer than five words, ends on a
        dangling function word (trailing punctuation ignored), or its
        confidence is below 0.35.
        """
        words = caption.split()
        if len(words) < self._MIN_CAPTION_WORDS:
            return True
        if words[-1].lower().rstrip(".,;:!?") in self._WEAK_ENDINGS:
            return True
        if confidence < self._MIN_CONFIDENCE:
            return True
        return False

    def _load_blip_bundle(self):
        """Return the cached BLIP bundle, or ``None`` if BLIP is unavailable.

        BLIP is optional. If ``torch``/``transformers`` cannot be imported or
        the pretrained weights fail to load, the failure is logged once and
        remembered, so later calls return ``None`` without retrying.
        """
        if self._blip_bundle is not None:
            return self._blip_bundle
        if self._blip_unavailable:
            return None
        import logging

        logger = logging.getLogger(__name__)
        try:
            import torch
            from transformers import BlipProcessor, BlipForConditionalGeneration
        except Exception:
            logger.warning("BLIP disabled: torch/transformers could not be imported.", exc_info=True)
            self._blip_unavailable = True
            return None
        try:
            processor = BlipProcessor.from_pretrained(self._BLIP_MODEL_ID)
            model = BlipForConditionalGeneration.from_pretrained(self._BLIP_MODEL_ID)
            model.eval()
        except Exception:
            logger.warning("BLIP disabled: failed to load %s.", self._BLIP_MODEL_ID, exc_info=True)
            self._blip_unavailable = True
            return None
        self._blip_bundle = {"processor": processor, "model": model, "torch": torch}
        return self._blip_bundle

    def _generate_blip_caption(self, image):
        """Caption ``image`` with BLIP.

        Returns ``None`` if BLIP is unavailable or decodes to an empty string.
        """
        bundle = self._load_blip_bundle()
        if bundle is None:
            return None
        processor, model, torch = bundle["processor"], bundle["model"], bundle["torch"]
        rgb = image.convert("RGB")
        with torch.no_grad():
            inputs = processor(images=rgb, return_tensors="pt")
            output = model.generate(
                **inputs,
                max_new_tokens=24,
                num_beams=5,
                length_penalty=1.0,
                early_stopping=True,
            )
        caption = processor.decode(output[0], skip_special_tokens=True).strip()
        return caption if caption else None

    # ------------------------------------------------------------------ #
    # Shared helpers
    # ------------------------------------------------------------------ #
    def _prepare_ensemble(self, image):
        """Load both decoders and features; return ``(models, features, max_length)``.

        ``max_length`` is the shorter of the two decoders' sequence lengths.
        This assumes each decoder's second input is the token sequence — confirm
        against ``src.model``.
        """
        models = [self._require_vgg_model(b) for b in self._SUPPORTED_BACKBONES]
        features = [self._extract_feature(image, b) for b in self._SUPPORTED_BACKBONES]
        max_length = min(m.input_shape[1][1] for m in models)
        return models, features, max_length

    def _prepare_single(self, image, backbone: str):
        """Load one decoder and its feature; return ``(model, feature, max_length)``."""
        model = self._require_vgg_model(backbone)
        feature = self._extract_feature(image, backbone)
        return model, feature, model.input_shape[1][1]

    @staticmethod
    def _merge_blip_first(blip_caption, confidence, vgg_candidates, beam_width):
        """Put the BLIP caption first, then VGG candidates not identical to it.

        The result is truncated to ``max(3, beam_width)`` entries.
        """
        merged = [(blip_caption, confidence)]
        merged.extend((cap, score) for cap, score in vgg_candidates if cap != blip_caption)
        return merged[:max(3, beam_width)]

    @staticmethod
    def _package_result(candidates, model_used, **extra):
        """Build the public result dict; the first candidate is the reported caption."""
        if candidates:
            top_caption, top_conf = candidates[0][0], candidates[0][1]
        else:
            top_caption, top_conf = "", 0.0
        result = {
            "caption": top_caption,
            "confidence": top_conf,
            "model_used": model_used,
            "candidates": candidates,
        }
        result.update(extra)
        return result

    def _generate_vgg_candidates(self, image, backbone_mode: str, beam_width: int):
        """Run VGG beam search (single backbone or two-backbone ensemble)."""
        if backbone_mode == "ensemble":
            models, features, max_length = self._prepare_ensemble(image)
            return beam_search_ensemble(models, self.tokenizer, features, max_length, beam_width)
        model, feature, max_length = self._prepare_single(image, backbone_mode)
        return beam_search(model, self.tokenizer, feature, max_length, beam_width)

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #
    def generate_caption(self, image, caption_mode: str, backbone_mode: str, beam_width: int = None):
        """Generate a caption for ``image``.

        Args:
            image: PIL-style image (must support ``convert``/``resize``).
            caption_mode: ``"vgg_only"`` | ``"hybrid"`` | ``"blip_only"``.
                ``hybrid`` puts BLIP's caption first only when the top VGG
                caption looks low quality. Any other value behaves like
                ``vgg_only``.
            backbone_mode: ``"vgg16"`` | ``"vgg19"`` | ``"ensemble"``.
            beam_width: Beam width; ``BEAM_WIDTH`` is used when falsy.

        Returns:
            dict with ``caption``, ``confidence``, ``model_used`` and
            ``candidates`` (list of ``(caption, score)``; the first entry is
            the reported caption).

        Raises:
            FileNotFoundError: a required trained VGG model file is missing.
            ValueError: ``backbone_mode`` is not a known backbone.
        """
        beam_width = beam_width or BEAM_WIDTH
        # Materialise once: the candidates are iterated more than once below.
        vgg_candidates = list(self._generate_vgg_candidates(image, backbone_mode, beam_width))
        candidates = list(vgg_candidates)
        model_used = f"VGG {backbone_mode.upper()}"

        if caption_mode == "blip_only":
            blip_caption = self._generate_blip_caption(image)
            if blip_caption:
                candidates = self._merge_blip_first(
                    blip_caption, self._BLIP_ONLY_CONFIDENCE, vgg_candidates, beam_width
                )
                model_used = "BLIP"
            # else: BLIP unavailable -> fall back to the VGG result.
        elif caption_mode == "hybrid" and vgg_candidates:
            top_caption, top_conf = vgg_candidates[0][0], vgg_candidates[0][1]
            if self._is_low_quality(top_caption, top_conf):
                blip_caption = self._generate_blip_caption(image)
                if blip_caption:
                    candidates = self._merge_blip_first(
                        blip_caption, self._BLIP_OVERRIDE_CONFIDENCE, vgg_candidates, beam_width
                    )
                    model_used = "Hybrid (BLIP override)"

        return self._package_result(candidates, model_used)

    def generate_caption_with_attention(
        self, image, backbone_mode: str, beam_width: int = None
    ):
        """Generate captions and per-word attention maps.

        Returns:
            dict with ``caption``, ``confidence``, ``model_used``,
            ``candidates`` and ``attention_maps``. ``attention_maps`` is
            whatever the ``src.inference`` ``*_with_attention`` helpers return
            — documented there as ``[(word, np.ndarray 7x7), ...]``.

        Raises:
            FileNotFoundError: a required trained VGG model file is missing.
            ValueError: ``backbone_mode`` is not a known backbone.
        """
        beam_width = beam_width or BEAM_WIDTH
        if backbone_mode == "ensemble":
            models, features, max_length = self._prepare_ensemble(image)
            candidates, attn_maps = beam_search_ensemble_with_attention(
                models, self.tokenizer, features, max_length, beam_width
            )
        else:
            model, feature, max_length = self._prepare_single(image, backbone_mode)
            candidates, attn_maps = beam_search_with_attention(
                model, self.tokenizer, feature, max_length, beam_width
            )
        # "ensemble".upper() yields the same "VGG ENSEMBLE" label as before.
        model_used = f"VGG {backbone_mode.upper()}"
        return self._package_result(candidates, model_used, attention_maps=attn_maps)

    def generate_all_backbones(self, image, beam_width: int = None):
        """Run VGG16, VGG19 and BLIP concurrently; BLIP is labelled 'Ensemble' in the UI.

        If BLIP is unavailable, the 'ensemble' entry falls back to VGG19 beam
        search. A failing job yields an entry with empty caption and an
        ``error`` message instead of failing the whole call.

        Returns:
            dict ``{"vgg16" | "vgg19" | "ensemble" -> result_dict}``.
        """
        from concurrent.futures import ThreadPoolExecutor, as_completed

        beam_width = beam_width or BEAM_WIDTH

        def _run_vgg(mode):
            candidates = self._generate_vgg_candidates(image, mode, beam_width)
            return self._package_result(candidates, f"VGG {mode.upper()}")

        def _run_blip_as_ensemble():
            blip_cap = self._generate_blip_caption(image)
            if blip_cap:
                candidates = [(blip_cap, self._BLIP_ONLY_CONFIDENCE)]
            else:
                # Fallback to VGG19 if BLIP unavailable. Note: this may load
                # the VGG19 model concurrently with the "vgg19" job; the result
                # is the same, only work is duplicated.
                candidates = self._generate_vgg_candidates(image, "vgg19", beam_width)
            return self._package_result(candidates, "Ensemble")  # display label

        jobs = (
            ("vgg16", _run_vgg, ("vgg16",)),
            ("vgg19", _run_vgg, ("vgg19",)),
            ("ensemble", _run_blip_as_ensemble, ()),
        )
        results = {}
        with ThreadPoolExecutor(max_workers=len(jobs)) as executor:
            futures = {executor.submit(fn, *args): key for key, fn, args in jobs}
            for future in as_completed(futures):
                key = futures[future]
                try:
                    results[key] = future.result()
                except Exception as e:
                    # Report the failure for this backbone only.
                    results[key] = {
                        "caption": "", "confidence": 0.0,
                        "model_used": key, "candidates": [],
                        "error": str(e),
                    }
        return results