# Toun / bmt_engine.py
# babaTEEpe's picture
# Upload 17 files
# 513d6d1 verified
import os
import torch
import torch.nn as nn
import numpy as np
import json
import logging
logger = logging.getLogger(__name__)


class BMTEngine:
    """
    Inference Engine for Bi-modal Transformer (BMT) Dense Video Captioning.
    Handles Proposal Generation and Captioning.

    NOTE(review): proposal generation and captioning are still placeholders;
    the loaded checkpoints are stored but not yet used for inference.
    """

    def __init__(self, model_dir, device="cpu"):
        """
        Args:
            model_dir: Directory containing the BMT checkpoints and vocabulary.
            device: Passed as ``map_location`` to ``torch.load`` (e.g. "cpu").
        """
        self.model_dir = model_dir
        self.device = device
        self.tokenizer = None
        # Set by load_models(); initialised here so the attribute always exists.
        self.vocab = None
        self.caption_model = None
        self.proposal_model = None
        # Expected file layout inside model_dir. The i3d/vggish entries are
        # not read by this class yet.
        self.paths = {
            "caption": os.path.join(model_dir, "best_cap_model.pt"),
            "proposal": os.path.join(model_dir, "best_prop_model.pt"),
            "i3d": os.path.join(model_dir, "rgb_imagenet.pt"),
            "vggish": os.path.join(model_dir, "vggish_model.ckpt"),
            "vocab": os.path.join(model_dir, "vocabulary.json"),
        }

    def _load_checkpoint(self, path):
        """Load the checkpoint at *path* onto ``self.device``.

        If the loaded object is an ``nn.Module`` it is switched to eval mode so
        dropout / batch-norm layers behave deterministically at inference time.
        Any other object (e.g. a bare state dict) is returned unchanged.
        """
        # NOTE(security): torch.load unpickles the file and can execute
        # arbitrary code -- only load checkpoints from a trusted model_dir.
        # PyTorch >= 2.6 defaults to weights_only=True, under which a fully
        # pickled nn.Module is rejected; confirm the checkpoint format.
        checkpoint = torch.load(path, map_location=self.device)
        if isinstance(checkpoint, nn.Module):
            checkpoint.eval()
        return checkpoint

    def load_models(self):
        """
        Load the vocabulary and the captioning/proposal checkpoints.

        Missing files are tolerated. A missing vocabulary falls back to an
        empty dummy vocab. A missing checkpoint is logged as an error and
        leaves the corresponding attribute unchanged (``None`` on a fresh
        engine).

        Returns:
            True, unless an exception is raised while reading one of the files.
            In that case the error and its traceback are logged and False is
            returned.
        """
        try:
            vocab_path = self.paths["vocab"]
            if os.path.exists(vocab_path):
                # Explicit UTF-8 so the result does not depend on the OS locale.
                with open(vocab_path, "r", encoding="utf-8") as f:
                    self.vocab = json.load(f)
                logger.info("BMT Vocabulary loaded.")
            else:
                logger.warning(
                    "Vocabulary not found at %s. Defaulting to dummy vocab for now.",
                    vocab_path,
                )
                self.vocab = {"word_to_idx": {}, "idx_to_word": {}}

            # The captioning checkpoint may be a bare state dict. A model class
            # is still needed to turn it into a runnable module (TODO).
            caption_path = self.paths["caption"]
            if os.path.exists(caption_path):
                self.caption_model = self._load_checkpoint(caption_path)
                logger.info("BMT Captioning module loaded.")
            else:
                logger.error("Captioning model not found at %s", caption_path)

            proposal_path = self.paths["proposal"]
            if os.path.exists(proposal_path):
                self.proposal_model = self._load_checkpoint(proposal_path)
                logger.info("BMT Proposal module loaded.")
            else:
                logger.error("Proposal model not found at %s", proposal_path)
        except Exception as e:
            # Boundary handler: report failure to the caller, keep the traceback.
            logger.exception("Error loading BMT models: %s", e)
            return False
        return True

    def generate_proposals(self, video_features, audio_features=None):
        """
        Runs the Proposal Generation module on extracted features.
        Returns a list of (start, end) timestamps.

        Placeholder: the features are currently ignored. Returns
        ``[(0.0, 5.0)]`` when no proposal model is loaded, otherwise the fixed
        pair ``[(0.0, 5.0), (5.0, 10.0)]``.
        """
        # Real BMT uses a Transformer encoder/decoder for this.
        if self.proposal_model is None:
            return [(0.0, 5.0)]  # Fallback to single segment if model missing
        # TODO: pass features through self.proposal_model.
        return [(0.0, 5.0), (5.0, 10.0)]

    def generate_captions(self, video_features, audio_features=None, proposals=None):
        """
        Generates captions for given video/audio features and proposed segments.

        Args:
            video_features: Currently unused (placeholder).
            audio_features: Currently unused (placeholder).
            proposals: Iterable of (start, end) pairs in seconds. This may be a
                list of tuples, an (N, 2) numpy array or an (N, 2) tensor.
                None or empty falls back to a single (0.0, 10.0) segment.

        Returns:
            A list of dicts with float ``start``/``end`` and a ``caption`` string.
        """
        # Materialise once: `not array` raises for multi-element numpy arrays
        # and tensors, and this also lets generators be used.
        segments = [] if proposals is None else list(proposals)
        if not segments:
            segments = [(0.0, 10.0)]

        results = []
        for start, end in segments:
            # Plain floats keep the output JSON-serialisable for numpy/torch inputs.
            start, end = float(start), float(end)
            # TODO: slice features for [start, end], run self.caption_model,
            # decode tokens.
            caption = f"Event detected between {start:.1f}s and {end:.1f}s"  # Placeholder
            results.append({
                "start": start,
                "end": end,
                "caption": caption,
            })
        return results

    def run_inference(self, video_path, mode="video"):
        """
        Complete pipeline: Features -> Proposals -> Captions

        Placeholder: video_path is only logged. Feature extraction is not yet
        implemented (planned for bmt_utils.py). Captions are produced for two
        fixed segments, (0.0, 2.5) and (3.0, 7.5).
        """
        logger.info("Running BMT inference on %s in %s mode.", video_path, mode)
        dummy_proposals = [(0.0, 2.5), (3.0, 7.5)]
        return self.generate_captions(None, None, dummy_proposals)