Spaces:
Sleeping
Sleeping
| import base64 | |
| from io import BytesIO | |
| import io | |
| import os | |
| import sys | |
| import cv2 | |
| from matplotlib import pyplot as plt | |
| import numpy as np | |
| import torch | |
| import tempfile | |
| from PIL import Image | |
| from torchvision.transforms.functional import to_pil_image | |
| from torchvision import transforms | |
| from PIL import ImageOps | |
| import os.path as osp | |
| from torchcam.methods import CAM | |
| from torchcam import methods as torchcam_methods | |
| from torchcam.utils import overlay_mask | |
# Absolute path of the directory containing this file ("<file>/.." normalised by
# abspath). It is appended to sys.path so the project-local modules imported just
# below (preprocessing, utils) can be resolved from this directory; the append must
# therefore stay before those imports.
root_path = osp.abspath(osp.join(__file__, osp.pardir))
sys.path.append(root_path)
from preprocessing.dataset_creation import EyeDentityDatasetCreation
from utils import get_model

# Names of the torchcam CAM methods. NOTE(review): not referenced anywhere in this
# chunk -- confirm whether a caller elsewhere still uses it.
CAM_METHODS = ["CAM"]
def load_model(model_configs, device="cpu"):
    """Instantiate the configured network and restore its trained weights.

    Args:
        model_configs: Mapping passed whole to ``get_model``; its ``"model_path"``
            entry is resolved with ``os.path.join(root_path, ...)`` (an absolute
            path is therefore used unchanged).
        device: Device used both as ``map_location`` when loading the checkpoint
            and as the final placement of the model.

    Returns:
        The model moved to ``device`` and switched to eval mode.
    """
    checkpoint_file = os.path.join(root_path, model_configs["model_path"])
    weights = torch.load(checkpoint_file, map_location=device)
    network = get_model(model_configs=model_configs)
    network.load_state_dict(weights)
    network = network.to(device)
    return network.eval()
def extract_frames(video_path):
    """Decode every frame of a video file into memory as RGB arrays.

    Prints diagnostic messages (file info, video properties, progress every
    10 frames) while it works.

    Args:
        video_path: Path to the video file.

    Returns:
        A list with one array per decoded frame, converted from OpenCV's BGR
        channel order to RGB. An empty list is returned when the file does not
        exist or OpenCV cannot open it.
    """
    # Debug: Check if file exists and get info
    print(f"π DEBUG: Attempting to extract frames from: {video_path}")
    print(f"π DEBUG: File exists: {os.path.exists(video_path)}")
    if os.path.exists(video_path):
        file_size = os.path.getsize(video_path)
        print(f"π DEBUG: File size: {file_size} bytes")
        print(f"π DEBUG: File permissions: {oct(os.stat(video_path).st_mode)}")
    else:
        print(f"β DEBUG: File does not exist at path: {video_path}")
        return []

    print("π DEBUG: Creating VideoCapture object...")
    vidcap = cv2.VideoCapture(video_path)
    try:
        is_opened = vidcap.isOpened()
        print(f"π DEBUG: VideoCapture opened successfully: {is_opened}")
        if not is_opened:
            print("β DEBUG: Failed to open video with OpenCV")
            return []

        frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        print(f"π DEBUG: Video properties - Frames: {frame_count}, FPS: {fps}, Size: {width}x{height}")

        frames = []
        frame_index = 0
        success, image = vidcap.read()
        print(f"π DEBUG: First frame read success: {success}")
        while success:
            # OpenCV decodes to BGR; downstream code works in RGB.
            frames.append(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            success, image = vidcap.read()
            frame_index += 1
            if frame_index % 10 == 0:
                print(f"π DEBUG: Extracted {frame_index} frames so far...")
    finally:
        # Release the capture even if reading/conversion raises part-way through.
        vidcap.release()

    print(f"β DEBUG: Successfully extracted {len(frames)} frames from video")
    return frames
def resize_frame(frame, max_width=640, max_height=480):
    """Scale a frame to fit a ``max_width`` x ``max_height`` box, keeping its aspect ratio.

    The frame is scaled by the largest factor that fits both bounds, so frames
    smaller than the box are enlarged as well as larger ones shrunk.

    Args:
        frame: A PIL image, or a numpy array (converted with ``Image.fromarray``).
        max_width: Bounding-box width in pixels.
        max_height: Bounding-box height in pixels.

    Returns:
        The resized PIL image (LANCZOS resampling). A frame with a zero-sized
        dimension is returned unchanged (after array conversion).
    """
    if isinstance(frame, np.ndarray):
        frame = Image.fromarray(frame)

    width, height = frame.size
    if width == 0 or height == 0:
        # Nothing to scale, and the ratios below would divide by zero.
        return frame

    scale = min(max_width / width, max_height / height)
    # Truncate as before, but never down to 0 px: a very elongated frame
    # (e.g. 10000x1) would otherwise yield a zero dimension, which Pillow's
    # resize rejects with "height and width must be > 0".
    new_width = max(1, int(width * scale))
    new_height = max(1, int(height * scale))
    return frame.resize((new_width, new_height), Image.Resampling.LANCZOS)
def is_image(file_extension):
    """Tell whether *file_extension* names a supported still-image format.

    Matching ignores case; pass the bare extension (``"png"``) -- a value with
    a leading dot such as ``".png"`` does not match.
    """
    image_extensions = ("png", "jpg", "jpeg", "bmp", "tiff", "webp")
    normalized = file_extension.lower()
    return any(normalized == ext for ext in image_extensions)
def is_video(file_extension):
    """Tell whether *file_extension* names a supported video container.

    Matching ignores case; pass the bare extension (``"mp4"``) -- a value with
    a leading dot such as ``".mp4"`` does not match.
    """
    video_extensions = ("mp4", "avi", "mov", "mkv", "webm", "flv", "wmv")
    normalized = file_extension.lower()
    return any(normalized == ext for ext in video_extensions)
def get_configs(blink_detection=False, *, upscale="-", upscale_method_or_model="-"):
    """Build the configuration dict consumed by ``EyeDentityDatasetCreation``.

    Args:
        blink_detection: Stored as ``feature_extraction_configs["blink_detection"]``.
        upscale: Super-resolution upscale setting. The default ``"-"`` disables
            super-resolution, in which case ``sr_configs`` is ``None``.
        upscale_method_or_model: Super-resolution method/model name; only used
            when ``upscale`` is not ``"-"``.

    Returns:
        A dict with ``"sr_configs"`` (``None`` or ``{"method": ..., "params":
        {"upscale": ...}}``) and ``"feature_extraction_configs"``.
    """
    # Previously both settings were hard-coded to "-", which made the
    # super-resolution branch unreachable; they are now overridable.
    sr_configs = None
    if upscale != "-":
        sr_configs = {
            "method": upscale_method_or_model,
            "params": {"upscale": upscale},
        }
    return {
        "sr_configs": sr_configs,
        "feature_extraction_configs": {
            "blink_detection": blink_detection,
            "upscale": upscale,
            "extraction_library": "mediapipe",
        },
    }
def setup_gradio(pupil_selection, tv_model):
    """Load the per-eye models and create empty per-eye result containers.

    Args:
        pupil_selection: One of ``"both"``, ``"left_pupil"`` or ``"right_pupil"``.
        tv_model: Model name; selects ``pre_trained_models/<tv_model>/<eye>.pt``
            under ``root_path`` and is passed as ``registered_model_name``.

    Returns:
        Tuple ``(selected_eyes, input_frames, output_frames, predicted_diameters,
        left_pupil_model, left_pupil_cam_extractor, right_pupil_model,
        right_pupil_cam_extractor)``. The three dicts map each selected eye
        (``"left_eye"``/``"right_eye"``) to an empty list; the model of an
        unselected eye is ``None``; both CAM extractors are always ``None`` here.

    Raises:
        ValueError: If ``pupil_selection`` is not one of the accepted values.
    """
    eyes_by_selection = {
        "both": ["left_eye", "right_eye"],
        "left_pupil": ["left_eye"],
        "right_pupil": ["right_eye"],
    }
    if pupil_selection not in eyes_by_selection:
        # Previously an unknown value left `selected_eyes` unbound and failed
        # later with a confusing UnboundLocalError.
        raise ValueError(
            f"Invalid pupil_selection {pupil_selection!r}; "
            f"expected one of {sorted(eyes_by_selection)}"
        )
    selected_eyes = eyes_by_selection[pupil_selection]

    models = {"left_eye": None, "right_eye": None}
    input_frames = {}
    output_frames = {}
    predicted_diameters = {}
    for eye_type in selected_eyes:
        model_configs = {
            "model_path": root_path + f"/pre_trained_models/{tv_model}/{eye_type}.pt",
            "registered_model_name": tv_model,
            "num_classes": 1,
        }
        models[eye_type] = load_model(model_configs)
        output_frames[eye_type] = []
        input_frames[eye_type] = []
        predicted_diameters[eye_type] = []

    # CAM extractors are not built here; process_frames_gradio creates them
    # lazily on the first detected eye.
    return (
        selected_eyes,
        input_frames,
        output_frames,
        predicted_diameters,
        models["left_eye"],
        None,
        models["right_eye"],
        None,
    )
def _cam_target_layer(model, tv_model):
    """Return the conv layer of ``model.resnet.layer4[-1]`` that CAM hooks for *tv_model*.

    Raises:
        ValueError: For architectures other than ResNet18/ResNet50 (a subclass
            of the bare ``Exception`` raised previously; same message).
    """
    if tv_model == "ResNet18":
        return model.resnet.layer4[-1].conv2
    if tv_model == "ResNet50":
        return model.resnet.layer4[-1].conv3
    raise ValueError(f"No target layer available for selected model: {tv_model}")


def _eye_tensor(eyes, eye_type, preprocess_function):
    """Preprocess the crop ``eyes[eye_type]`` into a batched tensor, or return None if absent."""
    crop = eyes[eye_type] if eye_type in eyes else None
    if crop is None:
        return None
    # unsqueeze(0) adds the batch dimension the model expects.
    return preprocess_function(to_pil_image(crop)).unsqueeze(0)


def _blank_eye_image():
    """Return the black 64x32 RGB placeholder shown when no eye crop is available."""
    return Image.new('RGB', (64, 32), 'black')


def process_frames_gradio(input_imgs, tv_model, pupil_selection, blink_detection=False):
    """
    Process frames without Streamlit dependencies.

    For every frame, ``EyeDentityDatasetCreation`` locates the face and eyes;
    each selected eye crop is resized to 32x64, passed through its model to
    predict the pupil diameter, and the model's CAM is overlaid on the crop.

    Args:
        input_imgs: Iterable of frames accepted by ``np.array`` (e.g. the RGB
            arrays returned by ``extract_frames``).
        tv_model: Model name; CAM is only supported for ``"ResNet18"`` and
            ``"ResNet50"``.
        pupil_selection: ``"both"``, ``"left_pupil"`` or ``"right_pupil"``.
        blink_detection: Enable blink detection; frames flagged as blinks are
            labelled ``"blink"`` instead of being run through the model.

    Returns:
        ``(input_frames, output_frames, predicted_diameters)``: dicts keyed by
        selected eye (``"left_eye"``/``"right_eye"``) holding exactly one entry
        per input frame -- the eye-crop image array, the CAM overlay array
        (all black on a blink, the placeholder when no eye was found), and a
        float diameter, ``"blink"`` or ``"no_eye_detected"``.
        ``({}, {}, {})`` is returned (and the error printed) if setup fails.

    Raises:
        ValueError: If an eye is detected and ``tv_model`` has no CAM target layer.
    """
    try:
        config_file = get_configs(blink_detection)
        (
            selected_eyes,
            input_frames,
            output_frames,
            predicted_diameters,
            left_pupil_model,
            left_pupil_cam_extractor,
            right_pupil_model,
            right_pupil_cam_extractor,
        ) = setup_gradio(pupil_selection, tv_model)
        ds_creation = EyeDentityDatasetCreation(
            feature_extraction_configs=config_file["feature_extraction_configs"],
            sr_configs=config_file["sr_configs"],
        )
    except Exception as e:
        print(f"Error in setup: {e}")
        # Return empty results if setup fails
        return {}, {}, {}

    preprocess_function = transforms.Compose(
        [
            transforms.Resize(
                [32, 64],
                interpolation=transforms.InterpolationMode.BICUBIC,
                antialias=True,
            ),
            transforms.ToTensor(),
        ]
    )
    models = {"left_eye": left_pupil_model, "right_eye": right_pupil_model}
    cam_extractors = {"left_eye": left_pupil_cam_extractor, "right_eye": right_pupil_cam_extractor}

    for idx, input_img in enumerate(input_imgs):
        try:
            ds_results = ds_creation(np.array(input_img))
        except Exception as e:
            # Best effort: a frame that cannot be processed is treated as "no face".
            print(f"Error in MediaPipe processing for frame {idx}: {e}")
            ds_results = None

        blinked = False
        eye_tensors = dict.fromkeys(selected_eyes)
        if ds_results is not None and "face" in ds_results:
            if blink_detection and "blinks" in ds_results:
                blinked = ds_results["blinks"]["blinked"]
            if "eyes" in ds_results:
                # Crop even on blinks so blink frames can display the actual eye
                # (previously the blink branch's crop cases were unreachable).
                for eye_type in selected_eyes:
                    eye_tensors[eye_type] = _eye_tensor(ds_results["eyes"], eye_type, preprocess_function)

        # Every branch below sets all three results, so each selected eye gets
        # exactly one entry per frame and the lists stay aligned with the input.
        for eye_type in selected_eyes:
            eye = eye_tensors[eye_type]
            if blinked:
                input_image_pil = to_pil_image(eye.squeeze(0)) if eye is not None else _blank_eye_image()
                width, height = input_image_pil.size
                input_img_np = np.array(input_image_pil)
                output_img_np = np.zeros((height, width, 3), dtype=np.uint8)
                predicted_diameter = "blink"
            elif eye is None:
                input_img_np = np.array(_blank_eye_image())
                output_img_np = input_img_np.copy()
                predicted_diameter = "no_eye_detected"
            else:
                model = models[eye_type]
                if cam_extractors[eye_type] is None:
                    # Built once per eye, on the first frame where that eye is found.
                    cam_extractors[eye_type] = CAM(
                        model,
                        target_layer=_cam_target_layer(model, tv_model),
                        fc_layer=model.resnet.fc,
                        input_shape=eye.shape,
                    )
                extractor = cam_extractors[eye_type]
                output = model(eye)
                predicted_diameter = output[0].item()
                act_maps = extractor(0, output)
                activation_map = act_maps[0] if len(act_maps) == 1 else extractor.fuse_cams(act_maps)
                input_image_pil = to_pil_image(eye.squeeze(0))
                overlay = overlay_mask(input_image_pil, to_pil_image(activation_map, mode="F"), alpha=0.5)
                input_img_np = np.array(input_image_pil)
                output_img_np = np.array(overlay)

            input_frames[eye_type].append(input_img_np)
            output_frames[eye_type].append(output_img_np)
            predicted_diameters[eye_type].append(predicted_diameter)

    return input_frames, output_frames, predicted_diameters