| import logging |
| import torch |
| import shutil |
| import gradio as gr |
| import numpy as np |
| from decord import VideoReader, cpu |
| from pathlib import Path |
| from typing import Any, Tuple, Dict, Optional, AsyncGenerator, List |
| import asyncio |
| from dataclasses import dataclass |
| from datetime import datetime |
| import cv2 |
| import copy |
|
|
| from llava.model.builder import load_pretrained_model |
| from llava.mm_utils import tokenizer_image_token |
| from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN |
| from llava.conversation import conv_templates, SeparatorStyle |
|
|
| from config import TRAINING_VIDEOS_PATH, STAGING_PATH, PRELOAD_CAPTIONING_MODEL, CAPTIONING_MODEL, USE_MOCK_CAPTIONING_MODEL, DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, VIDEOS_TO_SPLIT_PATH, DEFAULT_PROMPT_PREFIX |
| from utils import extract_scene_info, is_image_file, is_video_file |
| from finetrainers_utils import copy_files_to_training_dir, prepare_finetrainers_dataset |
|
|
# Module-level logger named after this module; handlers and levels are left to
# whatever logging configuration the application installs.
logger = logging.getLogger(__name__)
|
|
@dataclass
class CaptioningProgress:
    """Progress record for a single captioning job (one video clip or one image)."""

    # File name of the item being captioned (process_image stores the image name here too).
    video_name: str
    # Frame count of the source video; process_image sets this to 1.
    total_frames: int
    # Number of frames handled so far.
    processed_frames: int
    # Values assigned in this file: "initializing", "generating caption",
    # "completed" and "error".
    status: str
    # Time at which the record was created.
    started_at: datetime
    # Set when the job completes or fails; None until then.
    completed_at: Optional[datetime] = None
    # str() of the exception when status is "error"; None otherwise.
    error: Optional[str] = None
|
|
class CaptioningService:
    """Captions videos and images with a model loaded via ``load_pretrained_model``.

    The attributes below live on the class and are therefore shared by every
    instance; ``__new__`` populates them when ``PRELOAD_CAPTIONING_MODEL`` is set.
    """

    # Cached instance handed back by __new__ once set (only set in preload mode).
    _instance = None
    # Model, tokenizer and image processor stored by the background preload task.
    _model = None
    _tokenizer = None
    _image_processor = None
    # asyncio task running _background_load_model; None when no preload is pending.
    _model_loading = None
    # Event loop on which the preload task was created.
    _loop = None
|
|
| def __new__(cls, model_name=CAPTIONING_MODEL): |
| if cls._instance is not None: |
| return cls._instance |
| |
| instance = super().__new__(cls) |
| if PRELOAD_CAPTIONING_MODEL: |
| cls._instance = instance |
| try: |
| cls._loop = asyncio.get_running_loop() |
| except RuntimeError: |
| cls._loop = asyncio.new_event_loop() |
| asyncio.set_event_loop(cls._loop) |
| |
| if not USE_MOCK_CAPTIONING_MODEL and cls._model_loading is None: |
| cls._model_loading = cls._loop.create_task(cls._background_load_model(model_name)) |
| return instance |
|
|
| def __init__(self, model_name=CAPTIONING_MODEL): |
| if hasattr(self, 'model_name'): |
| return |
| |
| self.model_name = model_name |
| self.tokenizer = None |
| self.model = None |
| self.image_processor = None |
| self.active_tasks: Dict[str, CaptioningProgress] = {} |
| self._should_stop = False |
| self._model_loaded = False |
|
|
| @classmethod |
| async def _background_load_model(cls, model_name): |
| """Background task to load the model""" |
| try: |
| logger.info("Starting background model loading...") |
| if not cls._loop: |
| cls._loop = asyncio.get_running_loop() |
| |
| def load_model(): |
| try: |
| tokenizer, model, image_processor, _ = load_pretrained_model( |
| model_name, None, "llava_qwen", |
| torch_dtype="bfloat16", device_map="auto" |
| ) |
| model.eval() |
| return tokenizer, model, image_processor |
| except Exception as e: |
| logger.error(f"Error in load_model: {str(e)}") |
| raise |
|
|
| result = await cls._loop.run_in_executor(None, load_model) |
| |
| cls._tokenizer, cls._model, cls._image_processor = result |
| logger.info("Background model loading completed successfully!") |
| |
| except Exception as e: |
| logger.error(f"Background model loading failed: {str(e)}") |
| cls._model_loading = None |
| raise |
|
|
| async def ensure_model_loaded(self): |
| """Ensure model is loaded before processing""" |
| if USE_MOCK_CAPTIONING_MODEL: |
| logger.info("Using mock model, skipping model loading") |
| self.__class__._model_loading = None |
| self._model_loaded = True |
| return |
|
|
| if not self._model_loaded: |
| try: |
| if PRELOAD_CAPTIONING_MODEL and self.__class__._model_loading: |
| logger.info("Waiting for background model loading to complete...") |
| if self.__class__._loop and self.__class__._loop != asyncio.get_running_loop(): |
| logger.warning("Different event loop detected, creating new loading task") |
| self.__class__._model_loading = None |
| await self._load_model_sync() |
| else: |
| await self.__class__._model_loading |
| self.model = self.__class__._model |
| self.tokenizer = self.__class__._tokenizer |
| self.image_processor = self.__class__._image_processor |
| else: |
| await self._load_model_sync() |
| |
| self._model_loaded = True |
| logger.info("Model loading completed!") |
| except Exception as e: |
| logger.error(f"Error loading model: {str(e)}") |
| raise |
| |
| async def _load_model_sync(self): |
| """Synchronously load the model""" |
| logger.info("Loading model synchronously...") |
| current_loop = asyncio.get_running_loop() |
| |
| def load_model(): |
| return load_pretrained_model( |
| self.model_name, None, "llava_qwen", |
| torch_dtype="bfloat16", device_map="auto" |
| ) |
| |
| self.tokenizer, self.model, self.image_processor, _ = await current_loop.run_in_executor( |
| None, load_model |
| ) |
| self.model.eval() |
| |
| def _load_video(self, video_path: Path, max_frames_num: int = 64, fps: int = 1, force_sample: bool = True) -> tuple[np.ndarray, str, float]: |
| """Load and preprocess video frames with strict limits |
| |
| Args: |
| video_path: Path to video file |
| max_frames_num: Maximum number of frames to extract (default: 64) |
| fps: Frames per second to sample (default: 1) |
| force_sample: Whether to force uniform sampling (default: True) |
| |
| Returns: |
| Tuple of (frames, frame_times_str, video_time) |
| """ |
| video_path_str = str(video_path) |
| logger.debug(f"Loading video: {video_path_str}") |
| |
| |
| if max_frames_num == 0: |
| return np.zeros((1, 336, 336, 3)), "", 0 |
| |
| vr = VideoReader(video_path_str, ctx=cpu(0), num_threads=1) |
| total_frame_num = len(vr) |
| video_time = total_frame_num / vr.get_avg_fps() |
| |
| |
| fps = round(vr.get_avg_fps()/fps) |
| frame_idx = [i for i in range(0, len(vr), fps)] |
| frame_time = [i/fps for i in frame_idx] |
| |
| |
| if len(frame_idx) > max_frames_num or force_sample: |
| sample_fps = max_frames_num |
| uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int) |
| frame_idx = uniform_sampled_frames.tolist() |
| frame_time = [i/vr.get_avg_fps() for i in frame_idx] |
| |
| frame_time_str = ",".join([f"{i:.2f}s" for i in frame_time]) |
| |
| try: |
| frames = vr.get_batch(frame_idx).asnumpy() |
| logger.debug(f"Loaded {len(frames)} frames with shape {frames.shape}") |
| return frames, frame_time_str, video_time |
| except Exception as e: |
| logger.error(f"Error loading video frames: {str(e)}") |
| raise |
|
|
| async def process_video(self, video_path: Path, prompt: str, prompt_prefix: str = "") -> AsyncGenerator[tuple[CaptioningProgress, Optional[str]], None]: |
| try: |
| video_name = video_path.name |
| logger.info(f"Starting processing of video: {video_name}") |
| |
| |
| logger.debug(f"Loading video metadata for {video_name}") |
| loop = asyncio.get_event_loop() |
| vr = await loop.run_in_executor(None, lambda: VideoReader(str(video_path), ctx=cpu(0))) |
| total_frames = len(vr) |
| |
| progress = CaptioningProgress( |
| video_name=video_name, |
| total_frames=total_frames, |
| processed_frames=0, |
| status="initializing", |
| started_at=datetime.now() |
| ) |
| self.active_tasks[video_name] = progress |
| yield progress, None |
|
|
| |
| parent_caption = "" |
| if "___" in video_path.stem: |
| parent_name, _ = extract_scene_info(video_path.stem) |
| parent_txt_path = VIDEOS_TO_SPLIT_PATH / f"{parent_name}.txt" |
| if parent_txt_path.exists(): |
| parent_caption = parent_txt_path.read_text().strip() |
|
|
| |
| await self.ensure_model_loaded() |
|
|
| if USE_MOCK_CAPTIONING_MODEL: |
| |
| clip_caption = f"This is a test caption for {video_name}" |
|
|
| |
| if parent_caption: |
| full_caption = f"{clip_caption}\n{parent_caption}" |
| else: |
| full_caption = clip_caption |
|
|
| if prompt_prefix and not full_caption.startswith(prompt_prefix): |
| full_caption = f"{prompt_prefix}{full_caption}" |
| |
| |
| txt_path = video_path.with_suffix('.txt') |
| txt_path.write_text(full_caption) |
| |
| logger.debug(f"Mock mode: Saved caption to {txt_path}") |
|
|
| progress.status = "completed" |
| progress.processed_frames = total_frames |
| progress.completed_at = datetime.now() |
| yield progress, full_caption |
| else: |
| |
| max_frames_num = 64 |
| frames, frame_times_str, video_time = await loop.run_in_executor( |
| None, |
| lambda: self._load_video(video_path, max_frames_num, fps=1, force_sample=True) |
| ) |
| |
| |
| processed_frames = await loop.run_in_executor( |
| None, |
| lambda: self.image_processor.preprocess( |
| frames, |
| return_tensors="pt" |
| )["pixel_values"] |
| ) |
| |
| |
| progress.processed_frames = len(frames) |
| progress.status = "generating caption" |
| yield progress, None |
|
|
| |
| video_tensor = processed_frames.to('cuda').bfloat16() |
| |
| |
| conv_template = "qwen_1_5" |
| time_instruction = (f"The video lasts for {video_time:.2f} seconds, and {len(frames)} " |
| f"frames are uniformly sampled from it. These frames are located at {frame_times_str}.") |
| |
| full_question = DEFAULT_IMAGE_TOKEN + f"{time_instruction}\n{prompt}" |
| |
| conv = copy.deepcopy(conv_templates[conv_template]) |
| conv.append_message(conv.roles[0], full_question) |
| conv.append_message(conv.roles[1], None) |
| prompt_question = conv.get_prompt() |
| |
| |
| max_new_tokens = 512 |
|
|
| input_ids = await loop.run_in_executor( |
| None, |
| lambda: tokenizer_image_token(prompt_question, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda') |
| ) |
| |
| |
| with torch.no_grad(): |
| output = await loop.run_in_executor( |
| None, |
| lambda: self.model.generate( |
| input_ids, |
| images=[video_tensor], |
| modalities=["video"], |
| do_sample=False, |
| temperature=0, |
| max_new_tokens=max_new_tokens, |
| ) |
| ) |
|
|
| clip_caption = await loop.run_in_executor( |
| None, |
| lambda: self.tokenizer.batch_decode(output, skip_special_tokens=True)[0].strip() |
| ) |
|
|
| |
| if time_instruction in clip_caption: |
| clip_caption = clip_caption.split(time_instruction)[1].strip() |
| if prompt in clip_caption: |
| clip_caption = clip_caption.split(prompt)[1].strip() |
|
|
| |
| if parent_caption: |
| full_caption = f"{clip_caption}\n{parent_caption}" |
| else: |
| full_caption = clip_caption |
|
|
| if prompt_prefix and not full_caption.startswith(prompt_prefix): |
| full_caption = f"{prompt_prefix}{full_caption}" |
|
|
| |
| txt_path = video_path.with_suffix('.txt') |
| txt_path.write_text(full_caption) |
| |
| progress.status = "completed" |
| progress.completed_at = datetime.now() |
| yield progress, full_caption |
|
|
| except Exception as e: |
| progress.status = "error" |
| progress.error = str(e) |
| progress.completed_at = datetime.now() |
| yield progress, None |
| raise |
|
|
| async def process_image(self, image_path: Path, prompt: str, prompt_prefix: str = "") -> AsyncGenerator[tuple[CaptioningProgress, Optional[str]], None]: |
| """Process a single image for captioning""" |
| try: |
| image_name = image_path.name |
| logger.info(f"Starting processing of image: {image_name}") |
| |
| progress = CaptioningProgress( |
| video_name=image_name, |
| total_frames=1, |
| processed_frames=0, |
| status="initializing", |
| started_at=datetime.now() |
| ) |
| self.active_tasks[image_name] = progress |
| yield progress, None |
|
|
| |
| await self.ensure_model_loaded() |
|
|
| if USE_MOCK_CAPTIONING_MODEL: |
| progress.status = "completed" |
| progress.processed_frames = 1 |
| progress.completed_at = datetime.now() |
| print("yielding fake") |
| yield progress, "This is a test image caption" |
| return |
|
|
| |
| loop = asyncio.get_event_loop() |
| image = await loop.run_in_executor( |
| None, |
| lambda: cv2.imread(str(image_path)) |
| ) |
| if image is None: |
| raise ValueError(f"Could not read image: {str(image_path)}") |
| |
| |
| image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) |
| |
| |
| processed_image = await loop.run_in_executor( |
| None, |
| lambda: self.image_processor.preprocess( |
| image, |
| return_tensors="pt" |
| )["pixel_values"] |
| ) |
| |
| progress.processed_frames = 1 |
| progress.status = "generating caption" |
| yield progress, None |
|
|
| |
| image_tensor = processed_image.to('cuda').bfloat16() |
| full_prompt = f"<image>{prompt}" |
|
|
| input_ids = await loop.run_in_executor( |
| None, |
| lambda: tokenizer_image_token(full_prompt, self.tokenizer, return_tensors="pt").unsqueeze(0).to('cuda') |
| ) |
| |
| with torch.no_grad(): |
| output = await loop.run_in_executor( |
| None, |
| lambda: self.model.generate( |
| input_ids, |
| images=[image_tensor], |
| modalities=["image"], |
| do_sample=False, |
| temperature=0, |
| max_new_tokens=4096, |
| ) |
| ) |
|
|
| caption = await loop.run_in_executor( |
| None, |
| lambda: self.tokenizer.batch_decode(output, skip_special_tokens=True)[0].strip() |
| ) |
| |
| progress.status = "completed" |
| progress.completed_at = datetime.now() |
| gr.Info(f"Successfully generated caption for {image_name}") |
| yield progress, caption |
|
|
| except Exception as e: |
| progress.status = "error" |
| progress.error = str(e) |
| progress.completed_at = datetime.now() |
| yield progress, None |
| raise gr.Error(f"Error processing image: {str(e)}") |
|
|
| |
| async def start_caption_generation(self, custom_prompt: str, prompt_prefix: str) -> AsyncGenerator[List[List[str]], None]: |
| """Iterates over clips to auto-generate captions asynchronously.""" |
| try: |
| logger.info("Starting auto-caption generation") |
| |
| |
| default_prompt = DEFAULT_CAPTIONING_BOT_INSTRUCTIONS |
| prompt = custom_prompt.strip() or default_prompt |
| logger.debug(f"Using prompt: {prompt}") |
|
|
| |
| video_files = list(STAGING_PATH.glob("*.mp4")) |
| image_files = [f for f in STAGING_PATH.glob("*") if is_image_file(f)] |
| all_files = video_files + image_files |
| |
| |
| files_to_process = [] |
| for file_path in all_files: |
| caption_path = file_path.with_suffix('.txt') |
| needs_caption = ( |
| not caption_path.exists() or |
| caption_path.stat().st_size == 0 or |
| caption_path.read_text().strip() == "" |
| ) |
| if needs_caption: |
| files_to_process.append(file_path) |
| |
| logger.info(f"Found {len(files_to_process)} files needing captions") |
| |
| if not files_to_process: |
| logger.info("No files need captioning") |
| yield [] |
| return |
|
|
| self._should_stop = False |
| self.active_tasks.clear() |
| status_update: Dict[str, Dict[str, Any]] = {} |
|
|
| for file_path in all_files: |
| if self._should_stop: |
| break |
|
|
| try: |
| print(f"we are in file_path {str(file_path)}") |
| |
| if is_video_file(file_path): |
| process_gen = self.process_video(file_path, prompt, prompt_prefix) |
| else: |
| process_gen = self.process_image(file_path, prompt, prompt_prefix) |
| print("got process_gen = ", process_gen) |
| async for progress, caption in process_gen: |
| print(f"process_gen contains this caption = {caption}") |
| if caption and prompt_prefix and not caption.startswith(prompt_prefix): |
| caption = f"{prompt_prefix}{caption}" |
| |
| |
| if caption: |
| txt_path = file_path.with_suffix('.txt') |
| txt_path.write_text(caption) |
| |
| logger.debug(f"Progress update: {progress.status}") |
| |
| |
| status_update[file_path.name] = { |
| "status": progress.status, |
| "frames": progress.processed_frames, |
| "total": progress.total_frames |
| } |
|
|
| |
| rows = [] |
| for file_name, info in status_update.items(): |
| status = info["status"] |
| if status == "processing": |
| percent = (info["frames"] / info["total"]) * 100 |
| status = f"Analyzing... {percent:.1f}% ({info['frames']}/{info['total']} frames)" |
| elif status == "generating caption": |
| status = "Generating caption..." |
| elif status == "error": |
| status = f"Error: {progress.error}" |
| elif status == "completed": |
| status = "Completed" |
| |
| rows.append([file_name, status]) |
|
|
| yield rows |
| await asyncio.sleep(0.1) |
| |
|
|
| except Exception as e: |
| logger.error(f"Error processing file {file_path}: {str(e)}", exc_info=True) |
| rows = [[str(file_path.name), f"Error: {str(e)}"]] |
| yield rows |
| continue |
|
|
| logger.info("Auto-caption generation completed, cyping assets to the training dir..") |
|
|
| copy_files_to_training_dir(prompt_prefix) |
| except Exception as e: |
| logger.error(f"Error in start_caption_generation: {str(e)}") |
| yield [[str(e), "error"]] |
| raise |
|
|
    def stop_captioning(self):
        """Ask the captioning loop to stop.

        Only sets a flag. ``start_caption_generation`` checks it at the top of
        each file iteration, so a file that is already being processed is not
        interrupted.
        """
        logger.info("Stopping all captioning tasks")
        self._should_stop = True
|
|
| def close(self): |
| """Clean up resources""" |
| logger.info("Cleaning up captioning service resources") |
| if hasattr(self, 'model'): |
| del self.model |
| if hasattr(self, 'tokenizer'): |
| del self.tokenizer |
| if hasattr(self, 'image_processor'): |
| del self.image_processor |
| torch.cuda.empty_cache() |