"""Engagement classification for video clips.

On import this module fetches (if missing) a MARLIN ViT-L video encoder and an
LSTM+MLP classifier checkpoint from the HuggingFace Hub, loads both onto the
available device, and exposes ``predict``, which labels every ``.mp4`` file in
a folder with one of ``class_names``.
"""

import os
import time
import shutil
import cv2
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from huggingface_hub import hf_hub_download
from marlin_pytorch import Marlin

# ─── Paths ────────────────────────────────────────────────────────────────────
BASE = os.path.dirname(os.path.abspath(__file__))
MARLIN_PATH = os.path.join(BASE, "marlin_vit_large_ytf.encoder.pt")
LSTM_PATH = os.path.join(BASE, "best_combined_model_lstm.pt")


def _download_if_missing(repo_id, filename, dest, label):
    """Fetch ``filename`` from HF Hub repo ``repo_id`` into ``dest`` unless it exists.

    Args:
        repo_id: HuggingFace Hub repository id.
        filename: File name inside that repository.
        dest: Local path the file is copied to.
        label: Human-readable name used only in the progress messages.
    """
    if os.path.exists(dest):
        return
    print(f"⬇️ Downloading {label} from HuggingFace...")
    # hf_hub_download returns a local path (in the HF cache); copy it to the
    # fixed path next to this module so the existence check above short-circuits
    # on later runs.
    downloaded = hf_hub_download(repo_id=repo_id, filename=filename)
    shutil.copy(downloaded, dest)
    print(f"✅ {label} downloaded.")


# ─── Download model files from HF Hub if not present ─────────────────────────
_download_if_missing(
    repo_id="ControlNet/MARLIN",
    filename="marlin_vit_large_ytf.encoder.pt",
    dest=MARLIN_PATH,
    label="MARLIN encoder",
)
# NOTE(review): the Hub filename below has a space before ".pt" while the local
# copy (LSTM_PATH) does not. Kept byte-for-byte in case the Hub file really is
# named that way — confirm against the salal047/engagement-lstm repository.
_download_if_missing(
    repo_id="salal047/engagement-lstm",
    filename="best_combined_model_lstm .pt",
    dest=LSTM_PATH,
    label="LSTM checkpoint",
)

# ─── Device ───────────────────────────────────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🖥️ Using device: {device}")

# ─── Hyperparameters ──────────────────────────────────────────────────────────
# These must match the architecture the checkpoint was trained with, otherwise
# load_state_dict below fails on shape mismatches.
hidden_size = 512
dropout_rate = 0.5
num_layers = 1
num_classes = 3


# ─── LSTM Classifier ──────────────────────────────────────────────────────────
class LSTMClassifier(nn.Module):
    """LSTM sequence encoder returning the top layer's final hidden state.

    Input is batch-first: (batch, seq_len, input_size).
    Output is (batch, hidden_size).
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout_rate):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            # nn.LSTM applies dropout only *between* stacked layers. With a
            # single layer a non-zero value is a no-op that triggers a
            # UserWarning, so pass 0 in that case. Parameters and state_dict
            # keys are unaffected.
            dropout=dropout_rate if num_layers > 1 else 0.0,
        )

    def forward(self, x):
        """Return h_n of the last LSTM layer for the batch ``x``."""
        x = x.float()
        _, (hn, _) = self.lstm(x)
        return hn[-1]


# ─── Combined Model ───────────────────────────────────────────────────────────
class CombinedModel(nn.Module):
    """LSTM encoder followed by an MLP head that produces class logits."""

    def __init__(self, lstm, mlp_classifier):
        super().__init__()
        # Attribute names are part of the checkpoint contract: the loader below
        # restores weights into these submodules separately.
        self.LSTM_Model = lstm
        self.classifier = mlp_classifier

    def forward(self, features):
        """Map a (batch, seq_len, feature_dim) tensor to (batch, num_classes) logits."""
        x_out = self.LSTM_Model(features)
        logits = self.classifier(x_out)
        return logits


# ─── Build models ─────────────────────────────────────────────────────────────
model_face = LSTMClassifier(
    input_size=1024,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout_rate=dropout_rate,
).to(device)

mlp_classifier = nn.Sequential(
    nn.Linear(hidden_size, 512),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(512, num_classes),
).to(device)

# Load MARLIN encoder
print("⬇️ Loading MARLIN model...")
Marlin_Model = Marlin.from_file("marlin_vit_large_ytf", MARLIN_PATH)
Marlin_Model.to(device)
# Inference only: disable any train-time behaviour (e.g. dropout) in the encoder.
Marlin_Model.eval()
print("✅ MARLIN loaded.")

# Load LSTM checkpoint
combined_model = CombinedModel(model_face, mlp_classifier).to(device)
print("⬇️ Loading LSTM checkpoint...")
checkpoint = torch.load(LSTM_PATH, map_location=device, weights_only=True)
combined_model.LSTM_Model.load_state_dict(checkpoint['model_face_state_dict'])
combined_model.classifier.load_state_dict(checkpoint['mlp_classifier_state_dict'])
combined_model.eval()
print("✅ LSTM checkpoint loaded.")

# Index i corresponds to output logit i of the classifier.
class_names = ["Not-Engaged", "Engaged", "Highly-Engaged"]


# ─── Dataset ──────────────────────────────────────────────────────────────────
class EngageNetDataset(Dataset):
    """All ``*.mp4`` files directly inside ``folder_path``.

    Each item is ``(video_path_str, file_stem)``, where the stem is used as the
    user id.
    """

    def __init__(self, folder_path):
        # Sorted so predictions come back in a deterministic order; raw glob
        # order depends on the filesystem.
        self.videos_path = sorted(Path(folder_path).glob("*.mp4"))

    def __len__(self):
        return len(self.videos_path)

    def __getitem__(self, idx):
        user_id = self.videos_path[idx].stem
        return str(self.videos_path[idx]), user_id


# ─── Predict ──────────────────────────────────────────────────────────────────
def _extract_features(video_paths):
    """Run the MARLIN encoder on each video and return a list of torch tensors."""
    features = []
    for video_path in video_paths:
        feat = Marlin_Model.extract_video(video_path)
        if isinstance(feat, np.ndarray):
            feat = torch.tensor(feat)
        features.append(feat)
    return features


def _classify(features):
    """Return a 1-D tensor of predicted class indices, one per feature sequence.

    Each element of ``features`` is one video's (seq_len, feature_dim) sequence.
    """
    if len({tuple(f.shape) for f in features}) == 1:
        logits = combined_model(torch.stack(features).to(device))
    else:
        # Videos yielded sequences of different lengths and cannot be stacked
        # into one tensor. Classify each on its own as a batch of size 1, which
        # gives the same result an all-equal batch would for that video.
        logits = torch.cat(
            [combined_model(f.unsqueeze(0).to(device)) for f in features]
        )
    return torch.argmax(logits, dim=1)


def _remove_files(paths):
    """Best-effort deletion of processed videos; failures are reported, not raised."""
    for f in paths:
        try:
            os.remove(f)
        except OSError as err:
            print(f"⚠️ Could not remove {f}: {err}")


def predict(path, batch_size=8):
    """Classify every ``.mp4`` in folder ``path`` and delete each processed video.

    Videos are only deleted after their batch has been classified. If
    extraction or inference raises, the current batch's files are left in place.

    Args:
        path: Folder containing the ``.mp4`` files.
        batch_size: Number of videos classified per forward pass.

    Returns:
        ``(pred_names, user_ids)``: two parallel lists. ``pred_names[i]`` is an
        entry of ``class_names`` and ``user_ids[i]`` is the stem of the
        corresponding video file.
    """
    print("______________________ PREDICT is CALLED ____________________")
    dataset = EngageNetDataset(path)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    combined_model.eval()
    combined_model.to(device)

    all_pred_names = []
    all_user_ids = []
    start_time = time.perf_counter()

    with torch.no_grad():
        for batch_data, batch_ids in loader:
            features = _extract_features(batch_data)
            pred_indices = _classify(features)
            all_pred_names.extend(class_names[i] for i in pred_indices.tolist())
            all_user_ids.extend(batch_ids)

            # Clean up processed videos
            _remove_files(batch_data)

    print(f"⏱️ Predict time: {time.perf_counter() - start_time:.4f}s")
    return all_pred_names, all_user_ids