import cv2
import torch
import numpy as np
from PIL import Image
from torchvision import models, transforms
from torchvision.models import GoogLeNet_Weights
from config import DEVICE, FRAME_RATE
from tqdm import tqdm
from services.model_loader import batch_inference


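# Load an ImageNet-pretrained GoogLeNet, move it to the configured device and
# keep it in eval mode; it is used only for inference as a feature backbone.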
weights = GoogLeNet_Weights.DEFAULT
googlenet = models.googlenet(weights=weights).to(DEVICE).eval()


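# Truncated GoogLeNet: all layers up to the global average pool followed by a
# flatten, so each frame maps to a 1024-dimensional feature vector (the fc
# classifier head and the auxiliary branches are left out).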
feature_extractor = torch.nn.Sequential(
    googlenet.conv1,
    googlenet.maxpool1,
    googlenet.conv2,
    googlenet.conv3,
    googlenet.maxpool2,
    googlenet.inception3a,
    googlenet.inception3b,
    googlenet.maxpool3,
    googlenet.inception4a,
    googlenet.inception4b,
    googlenet.inception4c,
    googlenet.inception4d,
    googlenet.inception4e,
    googlenet.maxpool4,
    googlenet.inception5a,
    googlenet.inception5b,
    googlenet.avgpool,
    torch.nn.Flatten()
)
feature_extractor = feature_extractor.eval()


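# Standard ImageNet preprocessing: resize to 224x224, convert to a tensor and
# normalize with the ImageNet channel means/stds the backbone was trained on.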
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])


def extract_features(video_path):
    """Decode a video, preprocess every frame and run it through GoogLeNet.

    Returns a tensor of per-frame features (one 1024-d descriptor per frame)
    and the list of frame indices that were successfully read.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    indices = []
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    print(f"Total frames in video: {total_frames}")

    # Read frames sequentially; cap.read() advances the decoder one frame at a
    # time, so no explicit seeking is needed.
    for idx in tqdm(range(total_frames)):
        ret, frame = cap.read()
        if not ret:
            break

        # OpenCV decodes frames as BGR; convert to RGB before preprocessing.
        frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        frame = transform(frame)

        frames.append(frame)
        indices.append(idx)

    cap.release()

    if not frames:
        raise ValueError(f"No frames could be read from {video_path}")

    frames = torch.stack(frames).to(DEVICE)
    print("Preprocessed frame batch shape:", frames.shape)
    frames = batch_inference(model=feature_extractor, input=frames, batch_size=32)
    print("GoogLeNet feature batch shape:", frames.shape)

    return frames, indices


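# Minimal usage sketch (illustrative only): the video path below is a
# hypothetical placeholder, and batch_inference is assumed to return the
# stacked feature tensor for the whole batch.
if __name__ == "__main__":
    features, frame_indices = extract_features("videos/sample.mp4")
    print(f"Extracted {features.shape[0]} feature vectors of size {features.shape[1]}")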