| import os
|
| import torch
|
| import torch.nn as nn
|
| import numpy as np
|
| import json
|
| import logging
|
|
|
# Module-scoped logger, named after this module so its records can be
# filtered or configured independently of other modules.
logger = logging.getLogger(name=__name__)
|
|
|
class BMTEngine:
    """Inference engine for Bi-modal Transformer (BMT) dense video captioning.

    Wraps checkpoint/vocabulary loading plus the two BMT stages: event
    proposal generation and per-segment captioning.

    NOTE(review): ``generate_proposals``, ``generate_captions`` and
    ``run_inference`` are currently placeholders. They return fixed
    segments / templated captions and never invoke the loaded models.
    """

    def __init__(self, model_dir, device="cpu"):
        """Record artifact locations; nothing is loaded until ``load_models``.

        Args:
            model_dir: Directory containing the BMT checkpoints and vocabulary.
            device: Passed to ``torch.load`` as ``map_location``.
        """
        self.model_dir = model_dir
        self.device = device
        self.tokenizer = None  # reserved; not used anywhere in this class yet
        self.caption_model = None
        self.proposal_model = None
        # Initialised here so the attribute exists even if load_models() is
        # never called or fails before reaching the vocabulary step.
        self.vocab = None

        self.paths = {
            "caption": os.path.join(model_dir, "best_cap_model.pt"),
            "proposal": os.path.join(model_dir, "best_prop_model.pt"),
            "i3d": os.path.join(model_dir, "rgb_imagenet.pt"),
            "vggish": os.path.join(model_dir, "vggish_model.ckpt"),
            "vocab": os.path.join(model_dir, "vocabulary.json"),
        }

    def load_models(self):
        """Load the vocabulary and the captioning / proposal checkpoints.

        Missing files are logged rather than raised. A missing vocabulary
        falls back to an empty placeholder vocab. A missing checkpoint leaves
        the corresponding ``*_model`` attribute as ``None``.

        Returns:
            True if loading completed without an exception (even if some files
            were missing); False if any step raised.
        """
        try:
            self.vocab = self._load_vocab()
            self.caption_model = self._load_checkpoint("caption", "Captioning")
            self.proposal_model = self._load_checkpoint("proposal", "Proposal")
        except Exception as e:
            # Top-level boundary: keep the best-effort bool contract, but log
            # the full traceback instead of only str(e).
            logger.exception("Error loading BMT models: %s", e)
            return False
        return True

    def _load_vocab(self):
        """Return the parsed vocabulary JSON, or an empty placeholder if absent."""
        vocab_path = self.paths["vocab"]
        if not os.path.exists(vocab_path):
            logger.warning(
                "Vocabulary not found at %s. Defaulting to dummy vocab for now.",
                vocab_path,
            )
            return {"word_to_idx": {}, "idx_to_word": {}}
        # Explicit UTF-8: JSON text is UTF-8, the platform default may not be.
        with open(vocab_path, "r", encoding="utf-8") as f:
            vocab = json.load(f)
        logger.info("BMT Vocabulary loaded.")
        return vocab

    def _load_checkpoint(self, key, label):
        """Return ``torch.load`` of ``self.paths[key]``, or None (logged) if missing."""
        path = self.paths[key]
        if not os.path.exists(path):
            logger.error("%s model not found at %s", label, path)
            return None
        # SECURITY NOTE(review): torch.load unpickles the file and can execute
        # arbitrary code unless weights_only=True (the default only in newer
        # PyTorch releases). Only load checkpoints from trusted sources.
        # NOTE(review): depending on how the checkpoint was saved, this may be
        # a state/checkpoint dict rather than an nn.Module. Confirm before
        # calling it as a model.
        model = torch.load(path, map_location=self.device)
        logger.info("BMT %s module loaded.", label)
        return model

    def generate_proposals(self, video_features, audio_features=None):
        """Run the proposal-generation stage on extracted features.

        NOTE(review): placeholder. The features are ignored, fixed segments
        are returned, and the loaded proposal model is never invoked.

        Returns:
            A list of ``(start, end)`` timestamps in seconds.
        """
        if self.proposal_model is None:
            # Fallback when no proposal checkpoint has been loaded.
            return [(0.0, 5.0)]
        # TODO: run self.proposal_model on video_features / audio_features.
        return [(0.0, 5.0), (5.0, 10.0)]

    def generate_captions(self, video_features, audio_features=None, proposals=None):
        """Generate a caption for each proposed ``(start, end)`` segment.

        NOTE(review): placeholder. The features and the caption model are
        not used yet; each caption is a template describing the segment.

        Args:
            video_features: Currently unused.
            audio_features: Currently unused.
            proposals: Iterable of ``(start, end)`` pairs in seconds. This may
                be a list of tuples, a generator, or an ``(N, 2)`` NumPy
                array / tensor. If None or empty, the single segment
                ``(0.0, 10.0)`` is used.

        Returns:
            A list of dicts with keys ``"start"``, ``"end"`` (plain floats)
            and ``"caption"``.
        """
        # Materialise first: `not proposals` is ambiguous for multi-row
        # arrays/tensors, and len() would reject generators.
        proposals = list(proposals) if proposals is not None else []
        if not proposals:
            proposals = [(0.0, 10.0)]

        results = []
        for start, end in proposals:
            # Coerce NumPy/torch scalars to plain floats so the results stay
            # JSON-serialisable.
            start, end = float(start), float(end)
            caption = f"Event detected between {start:.1f}s and {end:.1f}s"
            results.append({"start": start, "end": end, "caption": caption})
        return results

    def run_inference(self, video_path, mode="video"):
        """Run the complete pipeline: features -> proposals -> captions.

        NOTE(review): placeholder. No features are extracted from
        ``video_path``, the proposal stage is skipped, and two fixed dummy
        segments are captioned instead. ``mode`` is only logged.
        """
        logger.info("Running BMT inference on %s in %s mode.", video_path, mode)
        # TODO: extract features (cf. self.paths["i3d"] / self.paths["vggish"])
        # and feed them through generate_proposals before captioning.
        dummy_proposals = [(0.0, 2.5), (3.0, 7.5)]
        return self.generate_captions(None, None, dummy_proposals)
|
|
|