# Page header copied along with this file (not code):
#   rajendrakumarv — "Update app.py (#7)" — commit c224e46 (verified)
import os
# Point every Hugging Face cache at the writable /cache volume used in Spaces.
# These assignments sit above the transformers import, presumably so the HF
# libraries see them when they first read their configuration.
# NOTE(review): recent transformers releases deprecate TRANSFORMERS_CACHE in
# favor of HF_HOME / HF_HUB_CACHE — confirm against the pinned version.
os.environ.update(
    HF_HOME="/cache",
    TRANSFORMERS_CACHE="/cache",
    HF_DATASETS_CACHE="/cache",
)
from flask import Flask, request, jsonify
import numpy as np
import torch
import av
import cv2
import tempfile
import shutil
import logging
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
# The Flask application object that the routes in this module are registered on.
app = Flask(import_name=__name__)
# Configure root logging at INFO level and create this module's logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
# Module-level state shared by the handlers. The device is chosen once at
# import time; model/processor/transform stay None until load_model() fills them.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = processor = transform = None
def load_model():
    """Load the VideoMAE model, its image processor and the frame transform.

    On first use the three objects are created and cached in the module-level
    globals ``model``, ``processor`` and ``transform``; later calls return the
    cached objects without reloading.

    Returns:
        tuple: ``(model, processor, transform)``.
    """
    global model, processor, transform
    # Check all three so a cache that is somehow only partly filled is rebuilt
    # rather than handing callers a None processor/transform.
    if model is None or processor is None or transform is None:
        model_name = "OPear/videomae-large-finetuned-UCF-Crime"
        logger.info(f"Loading model {model_name} on device {device}")
        # Downloads will go to /cache automatically (see the HF_* env vars).
        loaded_model = VideoMAEForVideoClassification.from_pretrained(model_name).to(device)
        loaded_processor = VideoMAEImageProcessor.from_pretrained(model_name)
        loaded_transform = Compose([
            Resize((224, 224)),
            ToTensor(),
        ])
        # Publish the globals only once everything has loaded. If the processor
        # download failed after `model` had been assigned, every later call would
        # skip loading and return processor=None.
        model, processor, transform = loaded_model, loaded_processor, loaded_transform
        logger.info("Model and processor loaded successfully")
    return model, processor, transform
def sample_frame_indices(clip_len=16, frame_sample_rate=1, seg_len=0):
    """Choose ``clip_len`` frame indices from a video of ``seg_len`` frames.

    The sampling window is ``clip_len * frame_sample_rate`` frames.

    * If the video is no longer than the window, the indices are spread evenly
      over the whole video; indices repeat when ``seg_len < clip_len``.
    * Otherwise the window is placed at a random position (drawn with
      ``np.random.randint``) and every ``frame_sample_rate``-th frame inside it
      is taken. With the default rate of 1 this is ``clip_len`` consecutive
      frames.

    Args:
        clip_len: number of indices to return (positive int).
        frame_sample_rate: stride between sampled frames inside the window
            (positive int).
        seg_len: total number of frames in the video.

    Returns:
        np.ndarray: ``clip_len`` non-decreasing integer indices, all within
        ``[0, seg_len - 1]``.

    Raises:
        ValueError: if ``seg_len``, ``clip_len`` or ``frame_sample_rate`` is not
            positive. (A non-positive ``seg_len`` used to yield negative indices.)
    """
    if seg_len <= 0:
        raise ValueError(f"seg_len must be positive, got {seg_len}")
    if clip_len <= 0 or frame_sample_rate <= 0:
        raise ValueError("clip_len and frame_sample_rate must be positive")
    window = int(clip_len * frame_sample_rate)
    if seg_len <= window:
        indices = np.linspace(0, seg_len - 1, num=clip_len, dtype=int)
    else:
        end_idx = np.random.randint(window, seg_len)
        start_idx = end_idx - window
        # Strided pick inside [start_idx, end_idx); for rate 1 this equals the
        # consecutive run start_idx .. end_idx - 1.
        indices = (start_idx + np.arange(clip_len) * frame_sample_rate).astype(int)
    return np.clip(indices, 0, seg_len - 1)
def _count_frames(video_path, video_stream):
    """Return the stream's frame count, falling back to OpenCV's estimate.

    Used when the container metadata reports no frames. The result can still be
    non-positive if neither source knows; callers must check it.
    """
    if video_stream.frames > 0:
        return video_stream.frames
    cap = cv2.VideoCapture(video_path)
    try:
        return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()


def _decode_frames_pyav(container, wanted):
    """Decode, with PyAV, the frames whose index is in the set ``wanted``.

    Returns a dict ``{frame_index: RGB ndarray}``. A decoding error is logged and
    whatever was decoded so far is returned, so the caller can fall back.
    """
    decoded = {}
    last = max(wanted)
    try:
        container.seek(0)
        for i, frame in enumerate(container.decode(video=0)):
            if i > last:
                break
            if i in wanted:
                decoded[i] = frame.to_ndarray(format="rgb24")
    except Exception as e:
        logger.warning(f"PyAV decoding failed, falling back to OpenCV: {e}")
    return decoded


def _decode_frames_opencv(video_path, wanted):
    """Decode, by seeking with OpenCV, the frames whose index is in ``wanted``.

    Returns a dict ``{frame_index: RGB ndarray}``. Frames that cannot be read
    are left out.
    """
    decoded = {}
    cap = cv2.VideoCapture(video_path)
    try:
        for i in sorted(wanted):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i)
            ret, frame = cap.read()
            if ret:
                decoded[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    finally:
        cap.release()
    return decoded


def process_video(video_path):
    """Extract a clip of 16 RGB frames from the video at ``video_path``.

    Frame indices come from ``sample_frame_indices``. Frames are decoded with
    PyAV. If that yields an incomplete set, the whole set is re-decoded with
    OpenCV, so a clip never mixes frames from two decoders.

    Returns:
        tuple: ``(frames, indices)``. ``frames`` is ``np.stack`` of one RGB frame
        per sampled index (repeated indices give repeated frames), and
        ``indices`` is the sampled index array. Returns ``(None, None)`` on any
        failure; the reason is logged.
    """
    try:
        container = av.open(video_path)
    except Exception as e:
        logger.error(f"Error opening video: {e}")
        return None, None
    try:
        try:
            seg_len = _count_frames(video_path, container.streams.video[0])
        except Exception as e:
            logger.error(f"Error opening video: {e}")
            return None, None
        if seg_len <= 0:
            logger.error(f"Could not determine frame count for {video_path}")
            return None, None
        indices = sample_frame_indices(clip_len=16, seg_len=seg_len)
        # Videos shorter than 16 frames produce repeated indices. Decode each
        # distinct frame once and repeat it when assembling the clip below.
        wanted = {int(i) for i in indices}
        decoded = _decode_frames_pyav(container, wanted)
    finally:
        # Release the container on every path (it was never closed before).
        container.close()
    if len(decoded) < len(wanted):
        # Start over with OpenCV instead of appending to partial PyAV output.
        decoded = _decode_frames_opencv(video_path, wanted)
    frames = [decoded[int(i)] for i in indices if int(i) in decoded]
    if len(frames) != 16:
        logger.error(f"Expected 16 frames, got {len(frames)}")
        return None, None
    return np.stack(frames), indices
def predict_video(frames):
    """Classify a clip of RGB frames and return the predicted label.

    Args:
        frames: sequence of frames accepted by ``PIL.Image.fromarray``, e.g. the
            stacked output of ``process_video`` (presumably 16 uint8 arrays of
            shape (H, W, 3) — TODO confirm for other callers).

    Returns:
        str: ``model.config.id2label`` for the arg-max logit, or ``"Unknown"`` if
        that id has no label.
    """
    model, processor, transform = load_model()
    # Each frame becomes a 224x224 CHW float tensor in [0, 1] (Resize + ToTensor).
    # The processor gets these tensors directly, so do_rescale=False stops it
    # from scaling a second time.
    # NOTE(review): VideoMAEImageProcessor presumably applies its own
    # resize/crop/normalize too, so the fixed (224, 224) resize first squashes
    # non-square frames. Confirm this matches the checkpoint's fine-tuning
    # preprocessing.
    clip = [transform(Image.fromarray(f)) for f in frames]
    inputs = processor(clip, return_tensors="pt", do_rescale=False)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    pred_id = logits.argmax(-1).item()
    return model.config.id2label.get(pred_id, "Unknown")
@app.route('/classify-video', methods=['POST'])
def classify_video():
    """Classify a video uploaded as the multipart form field ``video``.

    Responses:
        * 200 with ``{'prediction': label}`` on success.
        * 400 with ``{'error': ...}`` when no file or filename is given, or when
          frames cannot be extracted.
        * 500 with ``{'error': str(exception)}`` on any other failure.

    The upload is written into a fresh temp directory that is always removed.
    """
    if 'video' not in request.files:
        return jsonify({'error': 'No video file provided'}), 400
    file = request.files['video']
    # `not` also rejects a missing (None) filename, not just ''.
    if not file.filename:
        return jsonify({'error': 'Empty filename'}), 400
    # Never build the save path from the client's filename. "../x", or an
    # absolute name (which makes os.path.join discard temp_dir), would let a
    # request write files outside the temp directory. Keep at most a plain ASCII
    # extension, which may help the decoders recognise the container format.
    ext = os.path.splitext(os.path.basename(file.filename.replace('\\', '/')))[1]
    if not (ext[1:].isascii() and ext[1:].isalnum()):
        ext = ''
    temp_dir = tempfile.mkdtemp()
    path = os.path.join(temp_dir, 'upload' + ext)
    try:
        file.save(path)
        frames, _ = process_video(path)
        if frames is None:
            return jsonify({'error': 'Failed to extract frames'}), 400
        prediction = predict_video(frames)
        return jsonify({'prediction': prediction})
    except Exception as e:
        logger.exception(f"Error during processing: {e}")
        return jsonify({'error': str(e)}), 500
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
@app.route('/health', methods=['GET'])
def health_check():
    """Always answer ``{'status': 'healthy'}`` with HTTP 200.

    Does not check whether the model has been loaded.
    """
    status_body = jsonify({'status': 'healthy'})
    return status_body, 200
if __name__ == '__main__':
    # Load the model before serving, so the first request does not trigger the
    # lazy load inside load_model(). The port comes from $PORT, default 7860.
    logger.info("Starting application and loading model...")
    load_model()
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 7860)), debug=False)