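# visionq/agents/caption_agent.py
# Last commit: 57327cd ("fixes") by NanG01
"""
CaptionAgent - scene captioning by CLIP zero-shot label matching.

Returns the closest phrase from a fixed list of scene descriptions, scored
with an OpenCLIP ViT-B-32 model (laion400m_e32 weights); falls back to the
Hugging Face ``openai/clip-vit-base-patch32`` model if OpenCLIP fails to load.
UPDATED MODULE - replaced BLIP. (Earlier notes said TinyCLIP, but the model
actually loaded is ViT-B-32.)
"""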
import os
# Set before torch/open_clip (and the lazily imported transformers) are loaded,
# presumably to keep TensorFlow out of the process and silence its C++ logging.
# NOTE(review): confirm each variable is honoured by the installed library versions.
os.environ.update(
    {
        "TRANSFORMERS_NO_TF": "1",
        "HF_HUB_DISABLE_TF": "1",
        "TF_CPP_MIN_LOG_LEVEL": "3",
    }
)
from PIL import Image
import torch
import open_clip
class CaptionAgent:
    """Describe a video frame by picking the closest of a fixed set of scene phrases.

    This is not free-form captioning. Each frame is scored against
    ``self.scene_labels`` with a CLIP model (zero-shot matching), and the
    best-matching phrase is returned.

    Backends, tried in order:
      1. OpenCLIP (``open_clip.create_model_and_transforms``); by default
         ``ViT-B-32`` with ``laion400m_e32`` weights.
      2. Hugging Face ``transformers`` CLIP (``openai/clip-vit-base-patch32``),
         used only if step 1 raises.

    ``self.tokenizer`` is ``None`` exactly when the transformers backend is in use.
    """

    # Default candidate phrases. They are copied into a per-instance list so a
    # caller may edit ``self.scene_labels`` on one agent without affecting others.
    DEFAULT_SCENE_LABELS = (
        "a person sitting at a desk",
        "a person working on a computer",
        "an empty room",
        "a person standing",
        "multiple people in a room",
        "a workspace with electronics",
        "a bedroom",
        "a kitchen",
        "a living room",
        "outdoor scenery",
        "a person reading",
        "a person using a phone",
        "furniture and objects",
        "indoor environment",
        "people interacting",
    )

    # Returned when a frame cannot be described (missing frame or any error).
    FALLBACK_CAPTION = "a scene with various objects"

    # (labels_tuple, text_features) cache for the OpenCLIP path. It is a class-level
    # default so instances created without __init__ still work.
    _text_cache = None

    def __init__(self, *, model_name="ViT-B-32", pretrained="laion400m_e32",
                 scene_labels=None):
        """Load a CLIP model, preferring OpenCLIP and falling back to transformers.

        Args:
            model_name: OpenCLIP architecture name (default ``"ViT-B-32"``).
            pretrained: OpenCLIP pretrained-weights tag (default ``"laion400m_e32"``).
            scene_labels: Optional iterable of candidate phrases; defaults to
                ``DEFAULT_SCENE_LABELS``.

        Raises:
            Whatever the transformers fallback raises if neither backend loads.
        """
        print("[CaptionAgent] Loading CLIP model...")
        try:
            self.model, _, self.preprocess = open_clip.create_model_and_transforms(
                model_name,
                pretrained=pretrained,
            )
            self.tokenizer = open_clip.get_tokenizer(model_name)
            self.model.eval()
            print(f"[CaptionAgent] Using OpenCLIP ({model_name}, {pretrained})")
        except Exception as e:
            print(f"[CaptionAgent] OpenCLIP failed: {e}")
            # Imported lazily so transformers is only required for the fallback.
            from transformers import CLIPModel, CLIPProcessor

            self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
            self.preprocess = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            self.tokenizer = None  # marks the transformers backend for describe()
            self.model.eval()
            print("[CaptionAgent] Using standard CLIP (fallback)")
        self._text_cache = None
        self.scene_labels = list(
            self.DEFAULT_SCENE_LABELS if scene_labels is None else scene_labels
        )

    def describe(self, frame_bgr):
        """Return the scene label that best matches an OpenCV-style frame.

        Args:
            frame_bgr: Image array in OpenCV channel order. This is presumably a
                uint8 ``(H, W, 3)`` BGR frame; ``(H, W, 4)`` BGRA and grayscale
                ``(H, W)`` / ``(H, W, 1)`` arrays are also accepted.

        Returns:
            One element of ``self.scene_labels``, or ``FALLBACK_CAPTION`` when
            ``frame_bgr`` is ``None`` or conversion/inference raises.
        """
        if frame_bgr is None:
            # Typically a failed capture read; report it instead of crashing.
            print("[CaptionAgent] No frame provided")
            return self.FALLBACK_CAPTION
        try:
            image = Image.fromarray(self._bgr_to_rgb(frame_bgr))
            if image.mode != "RGB":
                image = image.convert("RGB")  # grayscale -> 3-channel for CLIP
            if self.tokenizer is not None:
                top_idx = self._match_openclip(image)
            else:
                top_idx = self._match_transformers(image)
            return self.scene_labels[top_idx]
        except Exception as e:
            print(f"[CaptionAgent] Error during inference: {e}")
            return self.FALLBACK_CAPTION

    @staticmethod
    def _bgr_to_rgb(frame):
        """Return ``frame`` reordered from OpenCV BGR(A) to RGB channel order.

        Single-channel input comes back as a 2-D array, which PIL treats as
        grayscale. An alpha channel, if present, is dropped.
        """
        if frame.ndim == 3 and frame.shape[2] == 1:
            frame = frame[:, :, 0]
        if frame.ndim == 2:
            return frame
        # Channels 2, 1, 0 are R, G, B for both BGR and BGRA layouts.
        return frame[:, :, 2::-1]

    @staticmethod
    def _top_label_index(image_features, text_features):
        """Return the row index of ``text_features`` most cosine-similar to the image.

        ``image_features`` is expected to hold a single image (batch of one),
        as produced by ``_match_openclip``.
        """
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        similarity = image_features @ text_features.T
        # argmax of cosine similarity equals argmax of softmax(100 * similarity).
        return int(similarity.argmax().item())

    def _openclip_text_features(self):
        """Return OpenCLIP text embeddings for ``self.scene_labels``, cached.

        The labels are the same on every frame unless a caller edits
        ``self.scene_labels``. The embeddings are therefore computed once and
        reused, and rebuilt only when the label tuple changes.
        """
        labels = tuple(self.scene_labels)
        if self._text_cache is None or self._text_cache[0] != labels:
            with torch.no_grad():
                features = self.model.encode_text(self.tokenizer(list(labels)))
            self._text_cache = (labels, features)
        return self._text_cache[1]

    def _match_openclip(self, image):
        """Return the index of the best label for a PIL ``image`` (OpenCLIP backend)."""
        image_input = self.preprocess(image).unsqueeze(0)  # add a batch dimension of 1
        with torch.no_grad():
            image_features = self.model.encode_image(image_input)
        return self._top_label_index(image_features, self._openclip_text_features())

    def _match_transformers(self, image):
        """Return the index of the best label for a PIL ``image`` (transformers backend)."""
        inputs = self.preprocess(
            text=self.scene_labels,
            images=image,
            return_tensors="pt",
            padding=True,
        )
        with torch.no_grad():
            logits = self.model(**inputs).logits_per_image
        # Softmax is monotonic, so the argmax of the raw logits is the same label.
        return int(logits.argmax().item())