# video_class / app.py
# (Hugging Face Space page residue kept as a comment so the file parses:
#  "shreyas27's picture — Update app.py — 30e3ad8 verified")
import gradio as gr
from transformers import VideoMAEImageProcessor, VideoMAEForVideoClassification
import torch
import numpy as np
import decord
from decord import VideoReader
import logging
import os
# --- Constants ---
NUM_FRAMES_TO_SAMPLE = 16
# --- Logging Setup ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# --- Decord Bridge Setup ---
try:
decord.bridge.set_bridge('torch')
logger.info("Decord bridge successfully set to PyTorch.")
except RuntimeError as e:
logger.warning(f"Failed to set decord bridge to PyTorch: {e}. "
"Ensure decord is compiled with PyTorch support (e.g., pip install decord[torch]). "
"Processing might fall back to CPU-based NumPy arrays if not correctly configured, "
"which will then be moved to the target device.")
# --- Device Configuration ---
if torch.cuda.is_available():
device = torch.device("cuda")
logger.info("CUDA is available. Using GPU.")
try:
decord_ctx = decord.gpu(torch.cuda.current_device() if hasattr(torch, 'cuda') and torch.cuda.is_available() else 0)
logger.info(f"Decord will attempt to use GPU context: {decord_ctx}")
except Exception as e:
logger.warning(f"Could not set decord context to GPU, falling back to CPU for decord. Error: {e}")
decord_ctx = decord.cpu(0)
logger.info(f"Decord will use CPU context: {decord_ctx}")
else:
device = torch.device("cpu")
logger.info("CUDA not available. Using CPU.")
decord_ctx = decord.cpu(0)
logger.info(f"Decord will use CPU context: {decord_ctx}")
# --- Model and Processor Loading ---
model = None
processor = None
try:
logger.info(f"Loading VideoMAEForVideoClassification model: OPear/videomae-large-finetuned-UCF-Crime to device: {device}")
model = VideoMAEForVideoClassification.from_pretrained("OPear/videomae-large-finetuned-UCF-Crime").to(device)
model.eval()
logger.info("Model loaded successfully.")
logger.info("Loading VideoMAEImageProcessor: MCG-NJU/videomae-base")
processor = VideoMAEImageProcessor.from_pretrained("MCG-NJU/videomae-base")
logger.info("Processor loaded successfully.")
except Exception as e:
logger.error(f"FATAL: Error loading model or processor during startup: {e}", exc_info=True)
raise
# --- Video Classification Function ---
def classify_video(video_filepath):
logger.info(f"--- New classification request for: '{video_filepath}' ---")
if model is None or processor is None:
logger.error("Model or processor is not loaded. Cannot classify.")
return "Error: Model or processor not available. Check server logs for loading errors."
if not video_filepath:
logger.warning("video_filepath is None or empty. This shouldn't happen with Gradio's Video component.")
return "Error: No video file was provided or the path is empty."
if not isinstance(video_filepath, str) or not os.path.exists(video_filepath):
logger.error(f"Error: Video file does not exist or path is invalid: '{video_filepath}' (Type: {type(video_filepath)})")
if not isinstance(video_filepath, str):
return f"Error: Expected a file path string, but received {type(video_filepath)}. Issue with Gradio input handling."
return "Error: Video file not found on server. Please ensure the file was uploaded correctly."
vr = None
try:
logger.info(f"Attempting to read video '{video_filepath}' using decord with context: {decord_ctx}")
vr = VideoReader(video_filepath, ctx=decord_ctx)
total_frames = len(vr)
if total_frames == 0:
logger.error(f"Error: Video at '{video_filepath}' is empty (total_frames is 0).")
return "Error: The video is empty, corrupted, or in an unsupported format."
height, width, channels = -1, -1, -1
try:
video_shape_tuple = vr.get_shape()
if len(video_shape_tuple) == 4:
height = video_shape_tuple[1]
width = video_shape_tuple[2]
channels = video_shape_tuple[3]
logger.info(f"Video dimensions obtained via get_shape(): H={height}, W={width}, C={channels}")
else:
logger.warning(f"vr.get_shape() returned an unexpected tuple format: {video_shape_tuple}. Will try fallback.")
except AttributeError:
logger.warning("'VideoReader' object has no attribute 'get_shape'. Falling back to reading the first frame for dimensions.")
except Exception as e_get_shape:
logger.warning(f"Error calling vr.get_shape(): {e_get_shape}. Will try fallback to first frame.")
if height == -1:
logger.info("Attempting to get dimensions from the first video frame.")
if total_frames > 0:
try:
first_frame = vr[0]
frame_s = first_frame.shape
if len(frame_s) == 3:
height = frame_s[0]
width = frame_s[1]
channels = frame_s[2]
logger.info(f"Video dimensions obtained from first frame shape: H={height}, W={width}, C={channels}")
else:
logger.error(f"First frame has an unexpected shape: {frame_s}. Cannot determine H, W, C.")
except Exception as e_frame_shape:
logger.error(f"Could not determine video dimensions from the first frame: {e_frame_shape}", exc_info=True)
else:
logger.error("Cannot get frame dimensions as video has no frames.")
if height == -1 or width == -1 or channels == -1:
logger.error(f"Failed to determine video dimensions for '{video_filepath}'. Aborting classification.")
return "Error: Could not determine video dimensions. The video might be corrupted, or there might be an issue with the Decord library version/installation."
logger.info(f"Video loaded. Total frames: {total_frames}. Determined Dimensions (H, W, C): ({height}, {width}, {channels})")
if total_frames < NUM_FRAMES_TO_SAMPLE:
logger.warning(f"Video duration ({total_frames} frames) is less than the desired {NUM_FRAMES_TO_SAMPLE} frames. Sampling all {total_frames} available frames.")
indices = np.arange(total_frames)
else:
indices = np.linspace(0, total_frames - 1, NUM_FRAMES_TO_SAMPLE, dtype=int)
logger.info(f"Selected frame indices for sampling: {indices.tolist()}")
video_frames_batch = vr.get_batch(indices)
logger.info(f"Data type from vr.get_batch: {type(video_frames_batch)}")
if isinstance(video_frames_batch, torch.Tensor):
logger.info(f"Decord returned PyTorch tensor. Shape: {video_frames_batch.shape}, Dtype: {video_frames_batch.dtype}, Device: {video_frames_batch.device}")
video_frames_tensor = video_frames_batch.to(device)
elif hasattr(video_frames_batch, 'asnumpy'): # Check if it's a Decord NDArray (or similar with asnumpy)
logger.info(f"Decord returned an object with 'asnumpy' method (e.g., decord.NDArray). Original shape: {video_frames_batch.shape}. Converting to NumPy array first.")
numpy_frames = video_frames_batch.asnumpy()
logger.info(f"Converted to NumPy array with shape: {numpy_frames.shape}, Dtype: {numpy_frames.dtype}")
video_frames_tensor = torch.from_numpy(numpy_frames).to(device)
elif isinstance(video_frames_batch, np.ndarray): # If it's already a NumPy array
logger.info(f"Decord returned a NumPy ndarray. Shape: {video_frames_batch.shape}, Dtype: {video_frames_batch.dtype}. Converting to PyTorch Tensor.")
video_frames_tensor = torch.from_numpy(video_frames_batch).to(device)
else:
logger.error(f"Video data from vr.get_batch is of unexpected type: {type(video_frames_batch)}. Cannot convert to PyTorch Tensor.")
return f"Error: Unexpected video data format ({type(video_frames_batch)} from Decord)."
logger.info(f"Video frames tensor prepared. Shape: {video_frames_tensor.shape}, Device: {video_frames_tensor.device}, Dtype: {video_frames_tensor.dtype}")
frames_list = list(video_frames_tensor)
logger.info(f"Processing {len(frames_list)} frames with VideoMAEImageProcessor.")
inputs = processor(frames_list, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
if 'pixel_values' in inputs:
logger.info(f"Frames processed. Input tensor 'pixel_values' shape: {inputs['pixel_values'].shape}, Device: {inputs['pixel_values'].device}")
else:
logger.warning(f"Processed inputs do not contain 'pixel_values'. Keys found: {list(inputs.keys())}")
with torch.no_grad():
logger.info("Performing model inference...")
outputs = model(**inputs)
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()
logger.info(f"Model inference complete. Predicted class index: {predicted_class_idx}")
predicted_label = model.config.id2label.get(predicted_class_idx, f"Unknown label (Index: {predicted_class_idx})")
logger.info(f"Predicted label: '{predicted_label}'")
return f"Predicted class: {predicted_label}"
except decord.DECORDError as e:
logger.error(f"Decord error while processing video '{video_filepath}': {e}", exc_info=True)
return f"Error processing video with decord: {str(e)}. The video file might be corrupted or in an unsupported format."
except RuntimeError as e:
logger.error(f"PyTorch runtime error during classification for '{video_filepath}': {e}", exc_info=True)
if "CUDA out of memory" in str(e):
return "Error: CUDA out of memory. The video might be too resource-intensive for the available GPU memory. Try a shorter or lower-resolution video."
return f"A runtime error occurred during processing: {str(e)}. Check server logs for details."
except AttributeError as e: # Catch other AttributeErrors that might occur
logger.error(f"Attribute error during video processing for '{video_filepath}': {e}", exc_info=True)
return f"An attribute error occurred: {str(e)}. This might indicate an issue with an object's properties or library versions."
except Exception as e:
logger.error(f"Unhandled error during video classification for '{video_filepath}': {e}", exc_info=True)
return f"An unexpected error occurred: {str(e)}. Please check server logs for details."
finally:
if vr is not None:
del vr
logger.info(f"VideoReader object for '{video_filepath}' dereferenced.")
# --- Gradio UI Components ---
video_input_component = gr.Video(
label="Upload Crime Video",
sources=["upload"]
)
text_output_component = gr.Textbox(
label="Classification Result"
)
example_video_paths = []
# --- Gradio Interface Setup ---
if 'classify_video' in globals() and callable(classify_video):
logger.info(f"'classify_video' is defined and callable. Type: {type(classify_video)}")
else:
logger.error("'classify_video' is NOT defined or is not callable. This will cause Gradio errors.")
import sys
sys.exit("Critical error: classify_video function not available for Gradio Interface.")
iface = gr.Interface(
fn=classify_video,
inputs=video_input_component,
outputs=text_output_component,
title="Video Crime Classification (GPU Accelerated)",
description=(
"Upload a video to classify the type of crime depicted. "
"Uses VideoMAE model (OPear/videomae-large-finetuned-UCF-Crime) fine-tuned on UCF-Crime. "
"Processor: MCG-NJU/videomae-base. Processing runs on GPU if available."
),
examples=example_video_paths if example_video_paths else None,
allow_flagging="never",
)
# --- Launch the Gradio App ---
if __name__ == "__main__":
logger.info("Attempting to launch Gradio application...")
try:
iface.launch(server_name="0.0.0.0")
logger.info("Gradio application launched. Access it via the URL provided in the console.")
except Exception as e:
logger.error(f"FATAL: Failed to launch Gradio interface: {e}", exc_info=True)