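"""Gradio demo that labels short social-media videos as risky or non-risky health
behavior by fusing caption features from BERT with frame features from ResNet-50."""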
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from torchvision import models, transforms
import torch.nn as nn
import os
import json
import cv2
from PIL import Image
import gradio as gr
class MultimodalRiskBehaviorModel(nn.Module):
    def __init__(self, text_model_name="bert-base-uncased", hidden_dim=512, dropout=0.3):
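        """Binary risk classifier that fuses a BERT-style text encoder with a
        ResNet-50 frame encoder through a small MLP head."""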
        super(MultimodalRiskBehaviorModel, self).__init__()
        # Text model using AutoModelForSequenceClassification
        self.text_model_name = text_model_name
        self.text_model = AutoModelForSequenceClassification.from_pretrained(text_model_name, num_labels=2)
        # Visual model (ResNet50)
        self.visual_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        visual_feature_dim = self.visual_model.fc.in_features
        self.visual_model.fc = nn.Identity()
        # Fusion and classification layer setup
        text_feature_dim = self.text_model.config.hidden_size
        self.fc1 = nn.Linear(text_feature_dim + visual_feature_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, 1)
    def forward(self, encoding, frames):
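        """Encode caption tokens and video frames, fuse them, and return a
        (batch_size, 1) tensor of sigmoid risk probabilities."""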
        # Resolve the device from the model's own parameters rather than relying on a global
        device = next(self.parameters()).device
        input_ids = encoding['input_ids'].squeeze(1).to(device)
        attention_mask = encoding['attention_mask'].squeeze(1).to(device)
        # Extract text features: use the final-layer [CLS] hidden state, whose size
        # (hidden_size) matches fc1; the 2-dim classification logits would cause a
        # shape mismatch at the fusion layer
        text_outputs = self.text_model(
            input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True
        )
        text_features = text_outputs.hidden_states[-1][:, 0, :]
        # Extract visual features: encode every frame with ResNet-50, then average over frames
        frames = frames.to(device)
        batch_size, num_frames, channels, height, width = frames.size()
        frames = frames.view(batch_size * num_frames, channels, height, width)
        visual_features = self.visual_model(frames)
        visual_features = visual_features.view(batch_size, num_frames, -1).mean(dim=1)
        # Combine and classify
        combined_features = torch.cat((text_features, visual_features), dim=1)
        x = self.dropout(torch.relu(self.fc1(combined_features)))
        output = torch.sigmoid(self.fc2(x))
        return output
    def save_pretrained(self, save_directory):
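        """Write the full state dict and a minimal JSON config to `save_directory`."""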
        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(save_directory, 'pytorch_model.bin'))
        config = {
            "text_model_name": self.text_model_name,
            "hidden_dim": self.fc1.out_features
        }
        with open(os.path.join(save_directory, 'config.json'), 'w') as f:
            json.dump(config, f)
    @classmethod
    def from_pretrained(cls, load_directory, map_location=None):
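        """Load a model saved with `save_pretrained` from a local directory; if the
        path does not exist, treat `load_directory` as a Hugging Face Hub model id
        and initialize only the text branch (fusion layers stay randomly initialized)."""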
        if os.path.exists(load_directory):
            config_path = os.path.join(load_directory, 'config.json')
            state_dict_path = os.path.join(load_directory, 'pytorch_model.bin')
            with open(config_path, 'r') as f:
                config_dict = json.load(f)
            model = cls(text_model_name=config_dict["text_model_name"], hidden_dim=config_dict["hidden_dim"])
            state_dict = torch.load(state_dict_path, map_location=map_location)
            model.load_state_dict(state_dict)
        else:
            hf_model = AutoModelForSequenceClassification.from_pretrained(load_directory, num_labels=2)
            model = cls(text_model_name=hf_model.config.name_or_path, hidden_dim=hf_model.config.hidden_size)
            model.text_model = hf_model
        return model
tokenizer = AutoTokenizer.from_pretrained('Souha-BH/BERT_Resnet50')
model = MultimodalRiskBehaviorModel.from_pretrained('Souha-BH/BERT_Resnet50')  # on a CPU-only machine, pass map_location='cpu'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Function to load frames from a video
def load_frames_from_video(video_path, transform, num_frames=10):
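    """Read up to `num_frames` frames from `video_path`, apply `transform` to each,
    and return a (1, num_frames, channels, height, width) tensor."""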
    cap = cv2.VideoCapture(video_path)
    frames = []
    frame_count = 0
    while frame_count < num_frames:  # limit the number of frames for efficiency
        success, frame = cap.read()
        if not success:
            break
        # Convert the BGR NumPy frame to a PIL image and apply the transforms
        frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        frame = transform(frame)
        frames.append(frame)
        frame_count += 1
    cap.release()
    if not frames:
        # torch.stack on an empty list raises an opaque error; fail with a clear message instead
        raise ValueError(f"Could not read any frames from {video_path}")
    # Stack frames and add a batch dimension: (1, num_frames, channels, height, width)
    frames = torch.stack(frames)
    frames = frames.unsqueeze(0)
    return frames
def predict_video(model, video_path, text_input, tokenizer, transform):
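    """Classify one (video, caption) pair; returns 1.0 or 0.0, or an error string if anything fails."""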
    try:
        # Set model to evaluation mode
        model.eval()
        # Tokenize the text input
        encoding = tokenizer(
            text_input, padding='max_length', truncation=True, max_length=128, return_tensors='pt'
        )
        encoding = {key: val.to(device) for key, val in encoding.items()}
        # Load frames from the video
        frames = load_frames_from_video(video_path, transform)
        frames = frames.to(device)
        # Log input shapes and devices
        print(f"Encoding device: {next(iter(encoding.values())).device}, Frames shape: {frames.shape}")
        # Forward pass through the model
        with torch.no_grad():
            output = model(encoding, frames)
        # The model already applies sigmoid, so just threshold the probability at 0.5
        prediction = (output.squeeze(-1) > 0.5).float()
        return prediction.item()
    except Exception as e:
        print(f"Prediction error: {e}")
        return "Error during prediction"
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # ImageNet mean/std, matching the pretrained ResNet-50 weights
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Define the demo video URLs and their captions
video_paths = [
    'https://drive.google.com/uc?export=download&id=1iWq1q1LM-jmf4iZxOqZTw4FaIBekJowM',
    'https://drive.google.com/uc?export=download&id=1_egBaC1HD2kIZgRRKsnCtsWG94vg1c7n',
    'https://drive.google.com/uc?export=download&id=12cGxBEkfU5Q1Ezg2jRk6zGyn2hoR3JLj'
]
video_captions = [
    "Everytime i start a diet كل مرة أحاول أبدأ ريجيم 😓 #dietmemes #funnyvideos #animetiktok",
    "New sandwich from burger king 🍔👑 #mukbang #asmr #asmrmukbang #asmrsounds #eat #food #Foodie moe eats #yummy #cheese #chicken #burger #fries #burgerking @Burger King",
    "all workout guides l!nked in bi0 // honestly huge moment 😂 I’ve been so focused on growing my upper body that this feels like it finally shows! shorts from @KEEPTHATPUMP #upperbody #upperbodyworkout #glutegains #glutegrowth #gluteexercise #workout #strengthtraining #gym #trending #fyp"
]
def predict_risk(video_index):
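    """Look up the demo video at `video_index` and return a human-readable risk label."""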
    video_path = video_paths[video_index]
    text_input = video_captions[video_index]
    # Make prediction
    prediction = predict_video(model, video_path, text_input, tokenizer, transform)
    if isinstance(prediction, str):
        # Surface the error sentinel instead of mislabeling it as "Not Risky"
        return prediction
    # Return the corresponding label
    return "Risky Health Behavior" if prediction == 1 else "Not Risky Health Behavior"
# Interface setup
with gr.Blocks() as interface:
    gr.Markdown("# Risk Behavior Prediction")
    gr.Markdown("Select a video to classify its behavior as risky or not.")

    # Input option selector
    video_selector = gr.Radio(["Video 1", "Video 2", "Video 3"], label="Choose a Video")

    # Return the selected video's URL (rendered by the gr.Video component) and its caption
    def show_selected_video(choice):
        idx = int(choice.split()[-1]) - 1
        return video_paths[idx], f"**Caption:** {video_captions[idx]}"

    video_player = gr.Video(width=320, height=240)
    caption_box = gr.Markdown()

    video_selector.change(
        fn=show_selected_video,
        inputs=video_selector,
        outputs=[video_player, caption_box]
    )

    # Prediction button and output
    predict_button = gr.Button("Predict Risk")
    output_text = gr.Textbox(label="Prediction")

    predict_button.click(
        fn=lambda idx: predict_risk(int(idx.split()[-1]) - 1),
        inputs=video_selector,
        outputs=output_text
    )

# Launch the app
interface.launch()