shreyas27 committed on
Commit
8836c5c
·
verified ·
1 Parent(s): a49d68b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -50
app.py CHANGED
@@ -5,47 +5,36 @@ import numpy as np
5
  import decord
6
  from decord import VideoReader
7
  import logging
8
- import os # Import os for path checking
9
 
10
- # --- Configure Logging ---
11
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
12
- logger = logging.getLogger(__name__) # Corrected from _name to __name__
13
 
14
- # --- Initialize decord bridge to PyTorch ---
15
- # This allows VideoReader.get_batch() to return PyTorch tensors directly
16
- # It's crucial for efficient GPU processing with Decord.
17
  try:
18
  decord.bridge.set_bridge('torch')
19
  logger.info("Decord bridge successfully set to PyTorch.")
20
  except RuntimeError as e:
21
  logger.warning(f"Failed to set decord bridge to PyTorch: {e}. "
22
  "Ensure decord is compiled with PyTorch support (e.g., pip install decord[torch]). "
23
- "Falling back to default bridge (numpy/cpu) if not set correctly later, "
24
- "which might require manual tensor conversion.")
25
- # If the bridge cannot be set, VideoReader might default to NumPy, requiring
26
- # explicit .to(device) and potentially .permute() on the NumPy array before processing.
27
- # However, the current code assumes a torch tensor output from get_batch().
28
- pass # Continue, as the code attempts to move to device later
29
 
30
-
31
- # --- Determine device (GPU if available, otherwise CPU) ---
32
  if torch.cuda.is_available():
33
  device = torch.device("cuda")
34
  logger.info("CUDA is available. Using GPU.")
35
- # decord context for GPU - use GPU 0 by default
36
  decord_ctx = decord.gpu(0)
37
- logger.info(f"Decord will use GPU: {decord_ctx}")
38
  else:
39
  device = torch.device("cpu")
40
  logger.info("CUDA not available. Using CPU.")
41
- decord_ctx = decord.cpu(0) # decord context for CPU
42
  logger.info(f"Decord will use CPU: {decord_ctx}")
43
 
44
- # --- Load Model and Processor ---
45
  try:
46
  logger.info(f"Loading VideoMAEForVideoClassification model: OPear/videomae-large-finetuned-UCF-Crime to device: {device}")
47
  model = VideoMAEForVideoClassification.from_pretrained("OPear/videomae-large-finetuned-UCF-Crime").to(device)
48
- model.eval() # Set model to evaluation mode for inference
49
  logger.info("Model loaded successfully.")
50
 
51
  logger.info("Loading VideoMAEImageProcessor: MCG-NJU/videomae-base")
@@ -53,15 +42,12 @@ try:
53
  logger.info("Processor loaded successfully.")
54
  except Exception as e:
55
  logger.error(f"FATAL: Error loading model or processor during startup: {e}", exc_info=True)
56
- # Re-raise the exception to prevent the app from starting if essential components fail to load
57
  raise
58
 
59
- # --- Video Classification Function ---
60
  def classify_video(video_filepath):
61
  logger.info(f"--- New classification request ---")
62
  logger.info(f"Received video_filepath: '{video_filepath}' (type: {type(video_filepath)})")
63
 
64
- # Basic input validation for Gradio's video component output
65
  if not video_filepath or not os.path.exists(video_filepath):
66
  logger.error(f"Error: video_filepath is None, empty, or file does not exist: '{video_filepath}'")
67
  return "Error: No valid video file received by the server. Please ensure the file exists and try uploading again."
@@ -76,44 +62,33 @@ def classify_video(video_filepath):
76
  logger.error(f"Error: Video at '{video_filepath}' is empty or could not be read (duration is 0).")
77
  return "Error: The video is empty or cannot be processed. It might be corrupted or in an unsupported format."
78
 
79
- num_frames_to_sample = 16 # Standard for VideoMAE
80
 
81
  if duration < num_frames_to_sample:
82
  logger.warning(f"Video duration ({duration} frames) is less than the desired {num_frames_to_sample} frames. Sampling all {duration} available frames.")
83
  indices = np.arange(duration)
84
  else:
85
- # Sample `num_frames_to_sample` evenly spaced frames
86
  indices = np.linspace(0, duration - 1, num_frames_to_sample, dtype=int)
87
 
88
  logger.info(f"Selected frame indices for sampling: {indices}")
89
 
90
- # .get_batch() will return PyTorch tensors on the specified decord_ctx (e.g., GPU)
91
- # if decord.bridge.set_bridge('torch') was successful.
92
  video_frames_tensor = vr.get_batch(indices)
93
 
94
- # Ensure frames are on the correct device, useful if decord bridge isn't set or context is CPU
95
  video_frames_tensor = video_frames_tensor.to(device)
96
  logger.info(f"Video frames successfully extracted and moved to device. Shape: {video_frames_tensor.shape}, Device: {video_frames_tensor.device}")
97
 
98
- # The processor expects a list of frames (PyTorch tensors in this case).
99
- # Assuming decord returns frames in (N, H, W, C) format (numpy default for 3D),
100
- # or (N, C, H, W) if bridge is torch and correctly configured.
101
- # VideoMAEImageProcessor expects (N, H, W, C) for its input list,
102
- # and it will handle the permutation to (N, C, H, W) internally if needed.
103
  inputs = processor(list(video_frames_tensor), return_tensors="pt")
104
 
105
- # Move processed inputs (e.g., pixel_values) to the same device as the model
106
  inputs = {k: v.to(device) for k, v in inputs.items()}
107
  logger.info(f"Frames processed by ImageProcessor and input tensors moved to device: {device}")
108
 
109
- with torch.no_grad(): # Disable gradient calculation for inference
110
  outputs = model(**inputs)
111
  logits = outputs.logits
112
  predicted_class_idx = logits.argmax(-1).item()
113
 
114
  logger.info(f"Model inference complete. Predicted class index: {predicted_class_idx}")
115
 
116
- # Get the human-readable label
117
  predicted_label = model.config.id2label[predicted_class_idx]
118
  logger.info(f"Predicted label: '{predicted_label}'")
119
 
@@ -123,38 +98,26 @@ def classify_video(video_filepath):
123
  logger.error(f"Error during video classification for '{video_filepath}': {e}", exc_info=True)
124
  return f"Error during classification: {str(e)}. Please check the video format, ensure decord dependencies are met, or review server logs for more details."
125
 
126
- # --- Gradio Interface Setup ---
127
  video_input_component = gr.Video(
128
  label="Upload Crime Video",
129
- # type="filepath" is removed as it's deprecated or not needed in newer Gradio versions
130
- # Gradio's gr.Video typically returns a filepath by default when uploaded.
131
  )
132
  text_output_component = gr.Textbox(
133
  label="Classification Result"
134
  )
135
 
136
- # Example video paths (replace with actual paths on your server if running locally
137
- # or ensure these paths are accessible in the deployment environment).
138
- # For deployment, often you provide a sample video file alongside your app.py.
139
  example_video_paths = [
140
- # "examples/crime_video_1.mp4",
141
- # "examples/crime_video_2.mp4",
142
- # Add actual paths here if you have example videos
143
  ]
144
 
145
-
146
  iface = gr.Interface(
147
  fn=classify_video,
148
  inputs=video_input_component,
149
  outputs=text_output_component,
150
  title="Video Crime Classification (GPU Accelerated)",
151
  description="Upload a video to classify the type of crime depicted using a VideoMAE model fine-tuned on UCF-Crime. Processing runs on GPU if available.",
152
- examples=example_video_paths, # Provide actual paths if you use examples
153
- allow_flagging="never" # Disables the "Flag" button
154
  )
155
 
156
- # --- Launch Gradio Application ---
157
  if __name__ == "__main__":
158
  logger.info("Starting Gradio application...")
159
- # server_name="0.0.0.0" makes the app accessible externally, useful for deployment
160
- iface.launch(server_name="0.0.0.0")
 
5
  import decord
6
  from decord import VideoReader
7
  import logging
8
+ import os
9
 
 
10
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
11
+ logger = logging.getLogger(__name__)
12
 
 
 
 
13
  try:
14
  decord.bridge.set_bridge('torch')
15
  logger.info("Decord bridge successfully set to PyTorch.")
16
  except RuntimeError as e:
17
  logger.warning(f"Failed to set decord bridge to PyTorch: {e}. "
18
  "Ensure decord is compiled with PyTorch support (e.g., pip install decord[torch]). "
19
+ "Processing might fall back to CPU-based NumPy arrays if not correctly configured, "
20
+ "which will then be moved to the target device.")
21
+ pass
 
 
 
22
 
 
 
23
  if torch.cuda.is_available():
24
  device = torch.device("cuda")
25
  logger.info("CUDA is available. Using GPU.")
 
26
  decord_ctx = decord.gpu(0)
27
+ logger.info(f"Decord will attempt to use GPU: {decord_ctx}")
28
  else:
29
  device = torch.device("cpu")
30
  logger.info("CUDA not available. Using CPU.")
31
+ decord_ctx = decord.cpu(0)
32
  logger.info(f"Decord will use CPU: {decord_ctx}")
33
 
 
34
  try:
35
  logger.info(f"Loading VideoMAEForVideoClassification model: OPear/videomae-large-finetuned-UCF-Crime to device: {device}")
36
  model = VideoMAEForVideoClassification.from_pretrained("OPear/videomae-large-finetuned-UCF-Crime").to(device)
37
+ model.eval()
38
  logger.info("Model loaded successfully.")
39
 
40
  logger.info("Loading VideoMAEImageProcessor: MCG-NJU/videomae-base")
 
42
  logger.info("Processor loaded successfully.")
43
  except Exception as e:
44
  logger.error(f"FATAL: Error loading model or processor during startup: {e}", exc_info=True)
 
45
  raise
46
 
 
47
  def classify_video(video_filepath):
48
  logger.info(f"--- New classification request ---")
49
  logger.info(f"Received video_filepath: '{video_filepath}' (type: {type(video_filepath)})")
50
 
 
51
  if not video_filepath or not os.path.exists(video_filepath):
52
  logger.error(f"Error: video_filepath is None, empty, or file does not exist: '{video_filepath}'")
53
  return "Error: No valid video file received by the server. Please ensure the file exists and try uploading again."
 
62
  logger.error(f"Error: Video at '{video_filepath}' is empty or could not be read (duration is 0).")
63
  return "Error: The video is empty or cannot be processed. It might be corrupted or in an unsupported format."
64
 
65
+ num_frames_to_sample = 16
66
 
67
  if duration < num_frames_to_sample:
68
  logger.warning(f"Video duration ({duration} frames) is less than the desired {num_frames_to_sample} frames. Sampling all {duration} available frames.")
69
  indices = np.arange(duration)
70
  else:
 
71
  indices = np.linspace(0, duration - 1, num_frames_to_sample, dtype=int)
72
 
73
  logger.info(f"Selected frame indices for sampling: {indices}")
74
 
 
 
75
  video_frames_tensor = vr.get_batch(indices)
76
 
 
77
  video_frames_tensor = video_frames_tensor.to(device)
78
  logger.info(f"Video frames successfully extracted and moved to device. Shape: {video_frames_tensor.shape}, Device: {video_frames_tensor.device}")
79
 
 
 
 
 
 
80
  inputs = processor(list(video_frames_tensor), return_tensors="pt")
81
 
 
82
  inputs = {k: v.to(device) for k, v in inputs.items()}
83
  logger.info(f"Frames processed by ImageProcessor and input tensors moved to device: {device}")
84
 
85
+ with torch.no_grad():
86
  outputs = model(**inputs)
87
  logits = outputs.logits
88
  predicted_class_idx = logits.argmax(-1).item()
89
 
90
  logger.info(f"Model inference complete. Predicted class index: {predicted_class_idx}")
91
 
 
92
  predicted_label = model.config.id2label[predicted_class_idx]
93
  logger.info(f"Predicted label: '{predicted_label}'")
94
 
 
98
  logger.error(f"Error during video classification for '{video_filepath}': {e}", exc_info=True)
99
  return f"Error during classification: {str(e)}. Please check the video format, ensure decord dependencies are met, or review server logs for more details."
100
 
 
101
  video_input_component = gr.Video(
102
  label="Upload Crime Video",
 
 
103
  )
104
  text_output_component = gr.Textbox(
105
  label="Classification Result"
106
  )
107
 
 
 
 
108
  example_video_paths = [
 
 
 
109
  ]
110
 
 
111
  iface = gr.Interface(
112
  fn=classify_video,
113
  inputs=video_input_component,
114
  outputs=text_output_component,
115
  title="Video Crime Classification (GPU Accelerated)",
116
  description="Upload a video to classify the type of crime depicted using a VideoMAE model fine-tuned on UCF-Crime. Processing runs on GPU if available.",
117
+ examples=example_video_paths,
118
+ allow_flagging="never"
119
  )
120
 
 
121
  if __name__ == "__main__":
122
  logger.info("Starting Gradio application...")
123
+ iface.launch(server_name="0.0.0.0")