import os
# Configure Hugging Face caches to use the writable /cache volume in Spaces
os.environ["HF_HOME"] = "/cache"
os.environ["TRANSFORMERS_CACHE"] = "/cache"
os.environ["HF_DATASETS_CACHE"] = "/cache"
from flask import Flask, request, jsonify
import numpy as np
import torch
import av
import cv2
import tempfile
import shutil
import logging
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
# Initialize Flask app
app = Flask(__name__)
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Globals for model, processor, and transforms
device = "cuda" if torch.cuda.is_available() else "cpu"
model = None
processor = None
transform = None
def load_model():
    """Lazily initialize the global model, processor, and frame transform.

    The VideoMAE checkpoint is loaded on the first call only; every later
    call returns the cached globals unchanged.

    Returns:
        tuple: ``(model, processor, transform)`` ready for inference.
    """
    global model, processor, transform
    # Guard clause: already initialized, nothing to do.
    if model is not None:
        return model, processor, transform

    model_name = "OPear/videomae-large-finetuned-UCF-Crime"
    logger.info(f"Loading model {model_name} on device {device}")
    # Downloads land in /cache thanks to the env vars set at import time.
    model = VideoMAEForVideoClassification.from_pretrained(model_name).to(device)
    processor = VideoMAEImageProcessor.from_pretrained(model_name)
    transform = Compose([Resize((224, 224)), ToTensor()])
    logger.info("Model and processor loaded successfully")
    return model, processor, transform
def sample_frame_indices(clip_len=16, frame_sample_rate=1, seg_len=0):
    """Sample exactly `clip_len` frame indices from a clip of `seg_len` frames.

    Args:
        clip_len: Number of indices to return (model expects 16).
        frame_sample_rate: Unused; kept for signature compatibility.
        seg_len: Total number of frames in the source video.

    Returns:
        np.ndarray of `clip_len` int indices, each in [0, seg_len - 1],
        in non-decreasing order.

    Raises:
        ValueError: If `seg_len` is not positive. (Previously a
        non-positive `seg_len` silently produced all -1 indices via
        np.clip with min > max, which corrupted downstream decoding.)
    """
    if seg_len <= 0:
        raise ValueError(f"seg_len must be positive, got {seg_len}")
    if seg_len <= clip_len:
        # Short clip: spread indices across the whole clip; duplicates are
        # expected and let short videos still yield `clip_len` frames.
        indices = np.linspace(0, seg_len - 1, num=clip_len, dtype=int)
    else:
        # Long clip: sample a random window ending at `end_idx`.
        end_idx = np.random.randint(clip_len, seg_len)
        start_idx = max(0, end_idx - clip_len)
        indices = np.linspace(start_idx, end_idx - 1, num=clip_len, dtype=int)
    return np.clip(indices, 0, seg_len - 1)
def process_video(video_path):
    """Extract 16 uniformly-sampled RGB frames from the video at `video_path`.

    Tries PyAV first and falls back to OpenCV on decode failure.

    Args:
        video_path: Path to a video file on disk.

    Returns:
        tuple: ``(frames, indices)`` where `frames` is a (16, H, W, 3) uint8
        array and `indices` the sampled frame indices, or ``(None, None)``
        if the video cannot be opened or 16 frames cannot be extracted.
    """
    try:
        container = av.open(video_path)
        video_stream = container.streams.video[0]
        if video_stream.frames > 0:
            seg_len = video_stream.frames
        else:
            # Some containers don't expose a frame count; probe with OpenCV.
            probe = cv2.VideoCapture(video_path)
            seg_len = int(probe.get(cv2.CAP_PROP_FRAME_COUNT))
            probe.release()  # was leaked in the original
    except Exception as e:
        logger.error(f"Error opening video: {e}")
        return None, None
    indices = sample_frame_indices(clip_len=16, seg_len=seg_len)
    frames = []
    # Try PyAV decode first.
    try:
        container.seek(0)
        for i, frame in enumerate(container.decode(video=0)):
            if i > indices[-1]:
                break
            if i in indices:
                frames.append(frame.to_ndarray(format="rgb24"))
    except Exception as e:
        logger.warning(f"PyAV decoding failed, falling back to OpenCV: {e}")
    finally:
        container.close()
    if len(frames) < len(indices):
        # Discard any partial PyAV result: the fallback re-reads every index
        # (including duplicates for short clips), so appending on top of the
        # partial list would yield the wrong frame count/order.
        frames = []
        cap = cv2.VideoCapture(video_path)
        for i in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
            ret, frame = cap.read()
            if ret:
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        cap.release()
    if len(frames) != 16:
        logger.error(f"Expected 16 frames, got {len(frames)}")
        return None, None
    return np.stack(frames), indices
def predict_video(frames):
    """Run VideoMAE inference on a stack of 16 RGB frames.

    Args:
        frames: (16, H, W, 3) uint8 array of RGB frames.

    Returns:
        str: The predicted class label, or "Unknown" for an unmapped id.
    """
    model, processor, transform = load_model()
    # Per-frame PIL -> resized [0, 1] tensor, stacked into (1, 16, C, H, W).
    clip = [transform(Image.fromarray(frame)) for frame in frames]
    batch = torch.stack(clip).unsqueeze(0)
    # Frames are already scaled to [0, 1], so skip the processor's rescale.
    inputs = processor(list(batch[0]), return_tensors="pt", do_rescale=False)
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    pred_id = logits.argmax(-1).item()
    return model.config.id2label.get(pred_id, "Unknown")
@app.route('/classify-video', methods=['POST'])
def classify_video():
    """POST endpoint: classify an uploaded video.

    Expects a multipart form with a 'video' file field. Responds with JSON
    ``{'prediction': <label>}`` on success, or ``{'error': ...}`` with
    status 400 (bad input) or 500 (processing failure).
    """
    if 'video' not in request.files:
        return jsonify({'error': 'No video file provided'}), 400
    file = request.files['video']
    if file.filename == '':
        return jsonify({'error': 'Empty filename'}), 400
    # The client controls `filename`; keep only the final path component so
    # a crafted name (e.g. '../../x') cannot escape the temp directory.
    safe_name = os.path.basename(file.filename.replace('\\', '/'))
    if safe_name in ('', '.', '..'):
        safe_name = 'upload.bin'
    temp_dir = tempfile.mkdtemp()
    path = os.path.join(temp_dir, safe_name)
    try:
        file.save(path)
        frames, _ = process_video(path)
        if frames is None:
            return jsonify({'error': 'Failed to extract frames'}), 400
        prediction = predict_video(frames)
        return jsonify({'prediction': prediction})
    except Exception as e:
        logger.exception(f"Error during processing: {e}")
        return jsonify({'error': str(e)}), 500
    finally:
        # Always remove the upload, even when processing fails.
        shutil.rmtree(temp_dir, ignore_errors=True)
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: report that the service is up."""
    payload = {'status': 'healthy'}
    return jsonify(payload), 200
if __name__ == '__main__':
    # Warm the model cache before accepting traffic so the first request
    # doesn't pay the checkpoint download/load cost.
    logger.info("Starting application and loading model...")
    load_model()
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 7860)), debug=False)