File size: 2,279 Bytes
31ad805 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
from torchvision import transforms
import torch
from PIL import Image
from model import ImprovedEfficientViT
import os
import cv2
from mtcnn import MTCNN
def extract_faces(video_path, target_frames=20):
    """Sample frames from a video and return the confidently detected faces.

    Frames are visited at a fixed stride of ``total_frames // target_frames``
    (never less than 1). Each sampled frame is converted from BGR to RGB and
    passed to an MTCNN detector; every detection with confidence >= 0.9 is
    cropped out of the frame and resized to 224x224.

    Args:
        video_path: Path of the video, handed to ``cv2.VideoCapture``.
        target_frames: Approximate number of frames to sample. Because the
            stride is rounded down, more frames than this may be visited.
            Values below 1 produce an empty result.

    Returns:
        A list of 224x224 RGB face crops as returned by ``cv2.resize``. The
        list is empty when ``target_frames`` is below 1, when the video cannot
        be opened (an error line is printed in that case), or when no
        detection passes the confidence filter.
    """
    if target_frames < 1:
        # No frames requested; this also avoids dividing by zero below.
        return []

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f"Error: Could not open video {video_path}")
            return []

        # Build the detector only once we know there is a video to scan.
        detector = MTCNN()
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_interval = max(total_frames // target_frames, 1)

        face_images = []
        for frame_index in range(0, total_frames, frame_interval):
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
            ret, frame = cap.read()
            if not ret:
                continue

            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            for face in detector.detect_faces(rgb_frame):
                if face['confidence'] < 0.9:
                    continue

                x, y, w, h = face['box']
                # Fix the far edges BEFORE clamping the origin: clamping a
                # negative x/y first would shift the crop past the detected
                # box instead of trimming it at the frame border.
                x_end, y_end = x + w, y + h
                x, y = max(x, 0), max(y, 0)
                if x_end <= x or y_end <= y:
                    # Box lies entirely outside the frame (or is degenerate);
                    # a non-positive end index would otherwise wrap around.
                    continue

                face_img = rgb_frame[y:y_end, x:x_end]
                if face_img.size == 0:
                    continue
                face_images.append(cv2.resize(face_img, (224, 224)))
        return face_images
    finally:
        # Always free the capture handle, even if detection raises.
        cap.release()
from torchvision import transforms
# Preprocessing applied to each face crop before it is fed to the model:
# array -> PIL image -> 224x224 -> tensor -> normalised with mean 0.5 and
# std 0.5 (the single-element mean/std is applied to every channel).
transform_vedio = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)
def predict_vedio(video_path, model_vedio, *, target_frames=20,
                  zero_vote_threshold=3):
    """Classify a video by voting over per-face model predictions.

    Faces are extracted with ``extract_faces``, preprocessed with
    ``transform_vedio`` and passed one at a time through ``model_vedio``.
    The sigmoid of each output is thresholded at 0.5 to give a 0/1 vote.
    The video is labelled 0 when strictly more than ``zero_vote_threshold``
    faces vote 0, otherwise 1.

    Args:
        video_path: Path of the video to classify.
        model_vedio: Model exposing ``to``, ``eval``/``train`` and ``training``
            (presumably a ``torch.nn.Module``; confirm against the caller)
            that returns a single logit for a batch of one face. It is moved
            to the GPU when CUDA is available, otherwise to the CPU.
        target_frames: Approximate number of frames sampled from the video.
        zero_vote_threshold: Number of 0-votes that must be exceeded for the
            video to be labelled 0.

    Returns:
        ``{"class": 0}`` or ``{"class": 1}``. NOTE: when no face is detected
        there are no 0-votes, so with a non-negative threshold the result
        is 1.
    """
    faces = extract_faces(video_path, target_frames=target_frames)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_vedio.to(device)

    # Run in eval mode so dropout / batch-norm behave deterministically, then
    # restore whatever mode the caller had the model in.
    was_training = model_vedio.training
    model_vedio.eval()
    pred_list = []
    try:
        with torch.no_grad():
            for face in faces:
                batch = transform_vedio(face).to(device).unsqueeze(0)
                prob = torch.sigmoid(model_vedio(batch))
                pred_list.append(int(prob.item() > 0.5))
    finally:
        model_vedio.train(was_training)

    zero_votes = pred_list.count(0)
    predicted_class = 0 if zero_votes > zero_vote_threshold else 1
    return {
        "class": predicted_class
    }
|