Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| import shutil | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import Dataset, DataLoader | |
| from pathlib import Path | |
| from huggingface_hub import hf_hub_download | |
| from marlin_pytorch import Marlin | |
# ─── Paths ───────────────────────────────────────────────────────────────────
# Both weight files are stored in the same directory as this module.
BASE = os.path.dirname(os.path.abspath(__file__))
MARLIN_PATH = os.path.join(BASE, "marlin_vit_large_ytf.encoder.pt")
LSTM_PATH = os.path.join(BASE, "best_combined_model_lstm.pt")

# ─── Download missing weight files from the HF Hub ───────────────────────────
# Each entry: (local target, Hub repo id, filename inside the repo, log label).
# A file that already exists locally is never downloaded again.
# NOTE(review): the remote LSTM filename contains a space before ".pt"; it is
# kept exactly as written -- confirm it matches the file name in the Hub repo.
for _local_path, _repo_id, _remote_name, _label in (
    (
        MARLIN_PATH,
        "ControlNet/MARLIN",
        "marlin_vit_large_ytf.encoder.pt",
        "MARLIN encoder",
    ),
    (
        LSTM_PATH,
        "salal047/engagement-lstm",
        "best_combined_model_lstm .pt",
        "LSTM checkpoint",
    ),
):
    if os.path.exists(_local_path):
        continue
    print(f"β¬οΈ Downloading {_label} from HuggingFace...")
    downloaded = hf_hub_download(repo_id=_repo_id, filename=_remote_name)
    # hf_hub_download returns a path to the downloaded file; copy it next to
    # this module under the local name used by the loading code.
    shutil.copy(downloaded, _local_path)
    print(f"β {_label} downloaded.")
# ─── Device ──────────────────────────────────────────────────────────────────
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"π₯οΈ Using device: {device}")

# ─── Hyperparameters ─────────────────────────────────────────────────────────
hidden_size = 512    # width of the LSTM hidden state
dropout_rate = 0.5   # dropout value handed to the LSTM
num_layers = 1       # number of stacked LSTM layers
num_classes = 3      # number of output logits
# ─── LSTM Classifier ─────────────────────────────────────────────────────────
class LSTMClassifier(nn.Module):
    """Summarise a feature sequence as the final hidden state of an LSTM.

    Despite its name this module produces no class scores: ``forward`` returns
    the final hidden state of the last LSTM layer, which a separate head is
    expected to turn into logits.

    Args:
        input_size: Number of features per time step.
        hidden_size: Size of the LSTM hidden state (and of the output vector).
        num_layers: Number of stacked LSTM layers.
        dropout_rate: Dropout applied between stacked layers. Ignored when
            ``num_layers`` is 1, because there is no layer boundary to apply
            it to.
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout_rate):
        super().__init__()
        # nn.LSTM applies dropout only between stacked layers and emits a
        # UserWarning when a non-zero rate is given to a single-layer network.
        # Forward the rate only when it can take effect; the parameters (and
        # therefore the state_dict keys) are identical either way.
        inter_layer_dropout = dropout_rate if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )

    def forward(self, x):
        """Return the last layer's final hidden state for each sequence.

        Args:
            x: Batch-first input of shape ``(batch, seq_len, input_size)``;
                it is cast to float32 before entering the LSTM.

        Returns:
            Tensor of shape ``(batch, hidden_size)``.
        """
        _, (final_hidden, _) = self.lstm(x.float())
        # final_hidden is (num_layers, batch, hidden_size); keep the top layer.
        return final_hidden[-1]
# ─── Combined Model ──────────────────────────────────────────────────────────
class CombinedModel(nn.Module):
    """Chain a sequence encoder and a classification head into one module.

    Args:
        lstm: Module applied first to the input features.
        mlp_classifier: Module applied to the encoder output to produce logits.
    """

    def __init__(self, lstm, mlp_classifier):
        super().__init__()
        # The attribute names are part of the module's interface: they are the
        # sub-module names registered on this nn.Module.
        self.LSTM_Model = lstm
        self.classifier = mlp_classifier

    def forward(self, features):
        """Encode ``features`` and return the classifier's output."""
        return self.classifier(self.LSTM_Model(features))
# ─── Build models ────────────────────────────────────────────────────────────
# Sequence encoder over per-clip features.
# NOTE(review): input_size=1024 is presumably the MARLIN ViT-Large feature
# width -- confirm against the encoder's output.
model_face = LSTMClassifier(
    input_size=1024,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout_rate=dropout_rate,
).to(device)

# Classification head mapping the LSTM hidden state to class logits.
mlp_classifier = nn.Sequential(
    nn.Linear(hidden_size, 512),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(512, num_classes),
).to(device)

# Load MARLIN encoder
print("β¬οΈ Loading MARLIN model...")
Marlin_Model = Marlin.from_file("marlin_vit_large_ytf", MARLIN_PATH)
Marlin_Model.to(device)
# The encoder is used for inference only; put it in eval mode explicitly so
# train-time layers such as dropout cannot perturb the extracted features.
Marlin_Model.eval()
print("β MARLIN loaded.")

# Load LSTM checkpoint
combined_model = CombinedModel(model_face, mlp_classifier).to(device)
print("β¬οΈ Loading LSTM checkpoint...")
# weights_only=True restricts unpickling to tensors and plain containers.
checkpoint = torch.load(LSTM_PATH, map_location=device, weights_only=True)
# The checkpoint stores the encoder and the head under separate keys.
combined_model.LSTM_Model.load_state_dict(checkpoint['model_face_state_dict'])
combined_model.classifier.load_state_dict(checkpoint['mlp_classifier_state_dict'])
combined_model.eval()
print("β LSTM checkpoint loaded.")

# Human-readable labels; position i names the class with logit index i.
class_names = ["Not-Engaged", "Engaged", "Highly-Engaged"]
# ─── Dataset ─────────────────────────────────────────────────────────────────
class EngageNetDataset(Dataset):
    """Index the ``*.mp4`` files located directly inside a folder.

    Items are ``(video_path, user_id)`` string pairs, where ``user_id`` is the
    file name without its extension. No video data is read by the dataset.

    Args:
        folder_path: Directory to scan (non-recursive) for ``*.mp4`` files.
    """

    def __init__(self, folder_path):
        # Path.glob yields entries in filesystem order, which is arbitrary;
        # sort so that indices (and the order of predictions) are reproducible.
        self.videos_path = sorted(Path(folder_path).glob("*.mp4"))

    def __len__(self):
        """Return the number of videos found."""
        return len(self.videos_path)

    def __getitem__(self, idx):
        """Return ``(path_as_str, file_stem)`` for the video at ``idx``."""
        video = self.videos_path[idx]
        return str(video), video.stem
# ─── Predict ─────────────────────────────────────────────────────────────────
def predict(path, batch_size=8):
    """Classify the engagement level of every ``*.mp4`` video in a folder.

    For each video, features are extracted with the MARLIN encoder, passed
    through the combined LSTM + classifier model, and the arg-max class is
    mapped to its name in ``class_names``.

    WARNING: every processed video file is deleted from disk once its batch
    has been classified. Deletion is best-effort: a failure is reported on
    stdout and does not abort the run.

    Args:
        path: Folder containing the ``*.mp4`` files to classify.
        batch_size: Number of videos classified per forward pass.

    Returns:
        A ``(pred_names, user_ids)`` pair of equal-length lists: the predicted
        class name and the file stem of each video, in dataset order. Both
        lists are empty when the folder contains no videos.
    """
    print("______________________ PREDICT is CALLED ____________________")
    dataset = EngageNetDataset(path)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    combined_model.eval()
    combined_model.to(device)

    all_pred_names = []
    all_user_ids = []
    start_time = time.time()

    with torch.no_grad():
        for batch_paths, batch_ids in loader:
            # NOTE(review): extract_video is assumed to return one
            # (num_clips, feature_dim) array/tensor per video -- confirm.
            features = []
            for video_path in batch_paths:
                feat = Marlin_Model.extract_video(video_path)
                if isinstance(feat, np.ndarray):
                    feat = torch.tensor(feat)
                features.append(feat)

            if len({tuple(f.shape) for f in features}) == 1:
                # Uniform shapes: classify the whole batch in one pass.
                logits = combined_model(torch.stack(features).to(device))
            else:
                # Videos of different lengths give feature sequences that
                # cannot be stacked; classify them one at a time instead of
                # crashing in torch.stack.
                logits = torch.cat(
                    [combined_model(f.unsqueeze(0).to(device)) for f in features]
                )

            pred_indices = torch.argmax(logits, dim=1).tolist()
            all_pred_names.extend(class_names[i] for i in pred_indices)
            all_user_ids.extend(batch_ids)

            # Clean up processed videos
            for video_path in batch_paths:
                try:
                    os.remove(video_path)
                except OSError as err:
                    # Keep going, but make the failure visible.
                    print(f"Could not delete {video_path}: {err}")

    print(f"β±οΈ Predict time: {time.time() - start_time:.4f}s")
    return all_pred_names, all_user_ids