Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import cv2 | |
| import os | |
| import gradio as gr | |
| from PIL import Image | |
| from transformers import ViTImageProcessor, ViTModel | |
# Run inference on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# File holding the trained ViT_LSTM weights.
MODEL_PATH = "Vit_LSTM.pth"
# Number of output classes produced by the classifier head.
NUM_CLASSES = 5
# Upper bound on the number of frames read from each video.
MAX_FRAMES = 16
class ViT_LSTM(nn.Module):
    """Two-layer bidirectional LSTM classifier over sequences of feature vectors.

    Args:
        feature_dim: Size of each input feature vector.
        hidden_dim: Hidden size of each LSTM direction.
        num_classes: Number of logits produced by the final linear layer.
    """

    def __init__(self, feature_dim=768, hidden_dim=512, num_classes=NUM_CLASSES):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=feature_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        # Both directions are concatenated in the LSTM output, hence 2 * hidden_dim.
        self.fc = nn.Linear(2 * hidden_dim, num_classes)
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, x):
        """Return class logits for a batch of feature sequences.

        Args:
            x: Input tensor; the LSTM is built with ``batch_first=True``, so the
                layout is (batch, time, feature_dim).

        Returns:
            The output of the linear head applied to the (dropout-regularised)
            LSTM output at the last time step.
        """
        sequence_out, _ = self.lstm(x)
        final_step = sequence_out[:, -1, :]
        return self.fc(self.dropout(final_step))
# Pretrained ViT backbone and its matching image preprocessor.
_VIT_CHECKPOINT = "google/vit-base-patch16-224"
vit_processor = ViTImageProcessor.from_pretrained(_VIT_CHECKPOINT)
vit_model = ViTModel.from_pretrained(_VIT_CHECKPOINT)
vit_model = vit_model.to(DEVICE)

# Restore the trained classifier weights, move the model to the inference
# device and switch it to eval mode.
model = ViT_LSTM()
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model = model.to(DEVICE).eval()

# Class names, indexed by the classifier's argmax output.
LABELS = [
    "BaseballPitch",
    "Basketball",
    "BenchPress",
    "Biking",
    "Billiards",
]
def extract_vit_features(video_path, max_frames=MAX_FRAMES):
    """Read the leading frames of a video and embed each one with the ViT model.

    Args:
        video_path: Path of the video file handed to ``cv2.VideoCapture``.
        max_frames: Maximum number of frames to read from the start of the video.

    Returns:
        A tensor with one row per frame read, obtained by averaging the ViT
        ``last_hidden_state`` over dim 1, or ``None`` when no frame could be
        read from the video.
    """
    frames = []
    cap = cv2.VideoCapture(video_path)
    try:
        while cap.isOpened() and len(frames) < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
            # Frames are converted from BGR to RGB before being wrapped as PIL images.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(rgb_frame))
    finally:
        # Release the capture even if reading or colour conversion raises.
        cap.release()

    if not frames:
        return None

    print(f"Extracted {len(frames)} frames from video.")
    inputs = vit_processor(images=frames, return_tensors="pt")["pixel_values"].to(DEVICE)
    with torch.no_grad():
        features = vit_model(inputs).last_hidden_state.mean(dim=1)
    return features
def predict_action(video_file):
    """Predict the action label for an uploaded video.

    Args:
        video_file: The value delivered by the upload component. It may be
            ``None`` (nothing uploaded), a filesystem path (``str`` or
            ``os.PathLike``), or an object exposing the path as ``.name``.

    Returns:
        A human-readable string: the predicted label, or a message explaining
        why no prediction could be made.
    """
    # Guard against a click with no upload instead of failing on ``None.name``.
    if video_file is None or video_file == "":
        return "No video provided, please upload a video."

    # Accept both a plain path and a file-like wrapper carrying ``.name``.
    if isinstance(video_file, (str, os.PathLike)):
        video_path = os.fspath(video_file)
    else:
        video_path = video_file.name
    print(f"Received video path: {video_path}")

    features = extract_vit_features(video_path)
    if features is None:
        return "No frames extracted, please upload a valid video."

    # Add a leading batch dimension: the classifier expects a batch of sequences.
    features = features.unsqueeze(0)
    with torch.no_grad():
        output = model(features)
    predicted_class = torch.argmax(output, dim=1).item()
    return f"Predicted Action: {LABELS[predicted_class]}"
# Gradio interface: a file upload, a prediction text box and a trigger button.
with gr.Blocks() as demo:
    gr.Markdown("# Action Recognition with ViT-LSTM")
    gr.Markdown("Upload a short video to predict the action.")

    video_input = gr.File(label="Upload a video")
    output_text = gr.Textbox(label="Prediction")
    predict_btn = gr.Button("Predict Action")

    # Clicking the button runs predict_action on the upload and shows its result.
    predict_btn.click(
        fn=predict_action,
        inputs=video_input,
        outputs=output_text,
    )

demo.launch()