"""Gradio application for PupilSense pupil-diameter estimation.

Wraps the ``process_frames_gradio`` pipeline (from ``gradio_utils``) with
three entry points: single-image processing, video processing, and a
unified API endpoint that also accepts base64 data URLs.
"""

import base64
import io
import os
import os.path as osp
import sys
import tempfile

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageOps

# Make sibling project modules importable when this file is run as a script.
root_path = osp.abspath(osp.join(__file__, osp.pardir))
sys.path.append(root_path)

from registry_utils import import_registered_modules
from gradio_utils import (
    is_image,
    is_video,
    extract_frames,
    resize_frame,
    CAM_METHODS,
    process_frames_gradio,
)

import_registered_modules()


def _error_image(width=400, height=200):
    """Return a blank white placeholder image used in error responses."""
    return Image.new('RGB', (width, height), 'white')


def _fig_to_image(fig):
    """Render *fig* to a PIL image (PNG, 150 dpi, tight bbox) and close it."""
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)
    return img


def process_image_gradio(image, pupil_selection, tv_model, blink_detection):
    """
    Process a single image and return results for Gradio interface.

    Args:
        image: PIL Image or numpy array
        pupil_selection: str - "left_pupil", "right_pupil", or "both"
        tv_model: str - "ResNet18" or "ResNet50"
        blink_detection: bool - whether to detect blinks

    Returns:
        tuple: (result_image, diameter_text) - a side-by-side input/CAM
        visualization and a human-readable diameter summary; on failure a
        blank placeholder image and an error message.
    """
    try:
        # Normalize input to a PIL Image.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Handle EXIF rotation (phone photos are often stored rotated).
        image = ImageOps.exif_transpose(image)

        # Cap resolution before running the model pipeline.
        image = resize_frame(image, max_width=640, max_height=480)

        # Process the image using the Gradio-compatible pipeline.
        input_frames, output_frames, predicted_diameters = process_frames_gradio(
            input_imgs=[image],
            tv_model=tv_model,
            pupil_selection=pupil_selection,
            blink_detection=blink_detection,
        )

        # Empty results mean face/eye detection failed.
        if not input_frames or not output_frames or not predicted_diameters:
            error_msg = (
                "Could not detect face/eyes in the image. "
                "Please try with a clearer image showing eyes."
            )
            return _error_image(), error_msg

        # Build one side-by-side comparison figure per detected eye.
        results = []
        diameter_results = []
        for eye_type in input_frames.keys():
            input_img = input_frames[eye_type][-1]
            output_img = output_frames[eye_type][-1]
            diameter = predicted_diameters[eye_type][0]

            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
            ax1.imshow(input_img)
            ax1.set_title(f"Input - {eye_type.replace('_', ' ').title()}")
            ax1.axis('off')
            ax2.imshow(output_img)
            ax2.set_title(f"CAM Overlay - {eye_type.replace('_', ' ').title()}")
            ax2.axis('off')
            plt.tight_layout()
            results.append(_fig_to_image(fig))

            # Diameter may be a string marker (e.g. a blink) or a float in mm.
            if isinstance(diameter, str):
                diameter_results.append(
                    f"{eye_type.replace('_', ' ').title()}: {diameter}"
                )
            else:
                diameter_results.append(
                    f"{eye_type.replace('_', ' ').title()}: {diameter:.2f} mm"
                )

        # Combine results horizontally if both eyes were processed.
        if len(results) == 1:
            final_image = results[0]
        else:
            total_width = sum(img.width for img in results)
            max_height = max(img.height for img in results)
            final_image = Image.new('RGB', (total_width, max_height), 'white')
            x_offset = 0
            for img in results:
                final_image.paste(img, (x_offset, 0))
                x_offset += img.width

        diameter_text = "\n".join(diameter_results)
        return final_image, diameter_text

    except Exception as e:
        error_msg = f"Error processing image: {str(e)}"
        return _error_image(), error_msg


def process_video_gradio(video_file, pupil_selection, tv_model, blink_detection):
    """
    Process a video file and return results for Gradio interface.

    Args:
        video_file: file path or file object
        pupil_selection: str - "left_pupil", "right_pupil", or "both"
        tv_model: str - "ResNet18" or "ResNet50"
        blink_detection: bool - whether to detect blinks

    Returns:
        tuple: (results_plot, csv_data, summary_text) - a diameter-over-time
        plot image, CSV text of per-frame diameters, and summary statistics.
        On failure: (placeholder_image, "", error_message).
    """
    try:
        # Accept either a file-like object (with .name) or a plain path.
        if hasattr(video_file, 'name'):
            video_path = video_file.name
        else:
            video_path = video_file

        video_frames = extract_frames(video_path)
        if not video_frames:
            # BUGFIX: keep the 3-tuple slot order (image, csv, summary) so the
            # message does not end up embedded in the CSV output downstream.
            return _error_image(), "", "No frames extracted from video"

        # Cap resolution of every frame before inference.
        resized_frames = []
        for frame in video_frames:
            if isinstance(frame, np.ndarray):
                frame = Image.fromarray(frame)
            resized_frames.append(
                resize_frame(frame, max_width=640, max_height=480)
            )

        input_frames, output_frames, predicted_diameters = process_frames_gradio(
            input_imgs=resized_frames,
            tv_model=tv_model,
            pupil_selection=pupil_selection,
            blink_detection=blink_detection,
        )

        # Empty results mean the pipeline (e.g. MediaPipe) failed entirely.
        if not input_frames or not output_frames or not predicted_diameters:
            error_msg = (
                "Could not process video. "
                "MediaPipe may have issues in this environment."
            )
            return _error_image(), "", error_msg

        # One time-series subplot per eye.
        fig, axes = plt.subplots(
            len(predicted_diameters), 1,
            figsize=(12, 6 * len(predicted_diameters)),
        )
        if len(predicted_diameters) == 1:
            axes = [axes]

        summary_stats = []
        for idx, (eye_type, diameters) in enumerate(predicted_diameters.items()):
            # Diameters may contain string markers (e.g. "blink"); keep only
            # numeric values for statistics.
            numeric_diameters = [
                d for d in diameters if isinstance(d, (int, float))
            ]
            frame_numbers = list(range(len(diameters)))

            # BUGFIX: plot NaN for non-numeric entries - passing the raw list
            # (which may mix floats and the string "blink") makes
            # matplotlib raise instead of drawing a gap.
            plot_values = [
                d if isinstance(d, (int, float)) else np.nan for d in diameters
            ]
            axes[idx].plot(frame_numbers, plot_values, marker='o', markersize=2)
            axes[idx].set_title(
                f"Pupil Diameter Over Time - {eye_type.replace('_', ' ').title()}"
            )
            axes[idx].set_xlabel("Frame Number")
            axes[idx].set_ylabel("Diameter (mm)")
            axes[idx].grid(True, alpha=0.3)

            if numeric_diameters:
                mean_diameter = np.mean(numeric_diameters)
                std_diameter = np.std(numeric_diameters)
                min_diameter = np.min(numeric_diameters)
                max_diameter = np.max(numeric_diameters)
                summary_stats.append(f"{eye_type.replace('_', ' ').title()}:")
                summary_stats.append(f"  Mean: {mean_diameter:.2f} mm")
                summary_stats.append(f"  Std: {std_diameter:.2f} mm")
                summary_stats.append(f"  Min: {min_diameter:.2f} mm")
                summary_stats.append(f"  Max: {max_diameter:.2f} mm")
                summary_stats.append("")

        plt.tight_layout()
        plot_img = _fig_to_image(fig)

        summary_text = (
            f"Processed {len(video_frames)} frames\n\n" + "\n".join(summary_stats)
        )

        # Per-frame CSV export (includes string markers like "blink").
        csv_data = "Frame,Eye_Type,Diameter_mm\n"
        for eye_type, diameters in predicted_diameters.items():
            for frame_idx, diameter in enumerate(diameters):
                csv_data += f"{frame_idx},{eye_type},{diameter}\n"

        return plot_img, csv_data, summary_text

    except Exception as e:
        error_msg = f"Error processing video: {str(e)}"
        return _error_image(), "", error_msg


def process_base64_media(data_url, pupil_selection, tv_model, blink_detection):
    """
    Process base64 encoded media (images or videos).

    Args:
        data_url: str - Base64 data URL (e.g., "data:video/mp4;base64,...")
        pupil_selection: str - "left_pupil", "right_pupil", or "both"
        tv_model: str - "ResNet18" or "ResNet50"
        blink_detection: bool - whether to detect blinks

    Returns:
        tuple: (result_image, result_text)
    """
    try:
        # Parse the data URL: "data:<mime>;base64,<payload>".
        if not data_url.startswith('data:'):
            raise ValueError("Invalid data URL format")

        header, base64_data = data_url.split(',', 1)
        mime_type = header.split(';')[0].replace('data:', '')
        file_data = base64.b64decode(base64_data)

        if mime_type.startswith('video/'):
            # Map MIME subtypes to the file extensions is_video() expects.
            file_extension = mime_type.split('/')[-1]
            if file_extension == 'quicktime':
                file_extension = 'mov'
            elif file_extension == 'x-msvideo':
                file_extension = 'avi'

            if not is_video(file_extension):
                return _error_image(), (
                    f"Unsupported video format: {file_extension}. "
                    "Supported formats: mp4, avi, mov, mkv, webm, flv, wmv."
                )

            # The video pipeline needs a real file on disk; write a temp copy.
            with tempfile.NamedTemporaryFile(
                delete=False, suffix=f'.{file_extension}'
            ) as temp_file:
                temp_file.write(file_data)
                temp_file_path = temp_file.name

            try:
                plot_img, csv_data, summary_text = process_video_gradio(
                    temp_file_path, pupil_selection, tv_model, blink_detection
                )
                combined_output = (
                    f"{summary_text}\n\n--- CSV Data ---\n{csv_data}"
                )
                return plot_img, combined_output
            finally:
                # Always remove the temp file, even if processing raised.
                if os.path.exists(temp_file_path):
                    os.unlink(temp_file_path)

        elif mime_type.startswith('image/'):
            file_extension = mime_type.split('/')[-1]
            if file_extension == 'jpeg':
                file_extension = 'jpg'

            if not is_image(file_extension):
                return _error_image(), (
                    f"Unsupported image format: {file_extension}. "
                    "Supported formats: png, jpg, jpeg, bmp, tiff, webp."
                )

            image = Image.open(io.BytesIO(file_data))
            return process_image_gradio(
                image, pupil_selection, tv_model, blink_detection
            )

        else:
            return _error_image(), (
                f"Unsupported media type: {mime_type}. "
                "Please provide a video or image file."
            )

    except Exception as e:
        return _error_image(), f"Error processing base64 media: {str(e)}"


def process_media_unified_with_fallback(
    text_input, file_input, pupil_selection, tv_model, blink_detection
):
    """
    Wrapper function that handles both text input (base64) and file input.

    Args:
        text_input: str - Base64 data URL or empty string
        file_input: File object or None
        pupil_selection: str - "left_pupil", "right_pupil", or "both"
        tv_model: str - "ResNet18" or "ResNet50"
        blink_detection: bool - whether to detect blinks

    Returns:
        tuple: (result_image, result_text)
    """
    # Prioritize text input (base64) if provided.
    if text_input and text_input.strip():
        return process_media_unified(
            text_input.strip(), pupil_selection, tv_model, blink_detection
        )

    # Fall back to the uploaded file.
    if file_input is not None:
        return process_media_unified(
            file_input, pupil_selection, tv_model, blink_detection
        )

    return _error_image(), (
        "No input provided. Please provide either a base64 data URL "
        "or upload a file."
    )


def process_media_unified(media_input, pupil_selection, tv_model, blink_detection):
    """
    Unified processing function that handles both images and videos.

    Args:
        media_input: Either an image (PIL), video file path, or base64 data URL
        pupil_selection: str - "left_pupil", "right_pupil", or "both"
        tv_model: str - "ResNet18" or "ResNet50"
        blink_detection: bool - whether to detect blinks

    Returns:
        tuple: (result_image, result_text)
    """
    try:
        # Reject missing/blank input up front.
        if media_input is None or (
            isinstance(media_input, str) and not media_input.strip()
        ):
            return _error_image(), (
                "No media input provided. Please upload an image or video "
                "or provide a base64 data URL."
            )

        # Base64 data URL (string input).
        if isinstance(media_input, str) and media_input.strip().startswith('data:'):
            return process_base64_media(
                media_input.strip(), pupil_selection, tv_model, blink_detection
            )

        # File object - dispatch on extension.
        if hasattr(media_input, 'name'):
            file_path = media_input.name
            file_extension = os.path.splitext(file_path)[1][1:]  # drop the dot

            if is_video(file_extension):
                plot_img, csv_data, summary_text = process_video_gradio(
                    media_input, pupil_selection, tv_model, blink_detection
                )
                combined_output = (
                    f"{summary_text}\n\n--- CSV Data ---\n{csv_data}"
                )
                return plot_img, combined_output
            elif is_image(file_extension):
                image = Image.open(file_path)
                return process_image_gradio(
                    image, pupil_selection, tv_model, blink_detection
                )
            else:
                return _error_image(), (
                    f"Unsupported file type: {file_extension}. "
                    "Supported video formats: mp4, avi, mov, mkv, webm, flv, wmv. "
                    "Supported image formats: png, jpg, jpeg, bmp, tiff, webp."
                )
        else:
            # Assume a PIL Image (the gr.Image component hands us one).
            if media_input is not None:
                return process_image_gradio(
                    media_input, pupil_selection, tv_model, blink_detection
                )
            return _error_image(), "Invalid media input format."

    except Exception as e:
        return _error_image(), f"Error processing media: {str(e)}"


def create_gradio_interface():
    """Create and configure the Gradio interface with proper API support."""
    print("🔧 Creating Gradio interface...")
    try:
        with gr.Blocks(title="👁️ PupilSense 👁️🕵️‍♂️") as demo:
            gr.Markdown("# 👁️ PupilSense - Pupil Diameter Analysis")
            gr.Markdown(
                "Upload an image or video to estimate pupil diameter "
                "using deep learning models."
            )

            with gr.Tab("Image Processing"):
                with gr.Row():
                    with gr.Column():
                        image_input = gr.Image(type="pil", label="Upload Image")
                        image_pupil_selection = gr.Dropdown(
                            ["left_pupil", "right_pupil", "both"],
                            value="both",
                            label="Pupil Selection",
                        )
                        image_model = gr.Dropdown(
                            ["ResNet18", "ResNet50"],
                            value="ResNet18",
                            label="Model",
                        )
                        image_blink_detection = gr.Checkbox(
                            value=True, label="Detect Blinks"
                        )
                        image_submit = gr.Button(
                            "Process Image", variant="primary"
                        )
                    with gr.Column():
                        image_output = gr.Image(label="Results")
                        image_text_output = gr.Textbox(
                            label="Pupil Diameter Results", lines=5
                        )
                image_submit.click(
                    fn=process_image_simple,
                    inputs=[
                        image_input,
                        image_pupil_selection,
                        image_model,
                        image_blink_detection,
                    ],
                    outputs=[image_output, image_text_output],
                )

            with gr.Tab("Video Processing"):
                with gr.Row():
                    with gr.Column():
                        video_input = gr.Video(label="Upload Video")
                        video_pupil_selection = gr.Dropdown(
                            ["left_pupil", "right_pupil", "both"],
                            value="both",
                            label="Pupil Selection",
                        )
                        video_model = gr.Dropdown(
                            ["ResNet18", "ResNet50"],
                            value="ResNet18",
                            label="Model",
                        )
                        video_blink_detection = gr.Checkbox(
                            value=True, label="Detect Blinks"
                        )
                        video_submit = gr.Button(
                            "Process Video", variant="primary"
                        )
                    with gr.Column():
                        video_output = gr.Image(label="Diameter Analysis")
                        video_text_output = gr.Textbox(
                            label="Summary Statistics", lines=10
                        )
                video_submit.click(
                    fn=process_video_simple,
                    inputs=[
                        video_input,
                        video_pupil_selection,
                        video_model,
                        video_blink_detection,
                    ],
                    outputs=[video_output, video_text_output],
                )

            # Unified API endpoint handling both images and videos.
            with gr.Tab("API Testing"):
                gr.Markdown("### API Endpoint for External Access")
                gr.Markdown(
                    "This endpoint can process both images and videos "
                    "programmatically."
                )
                gr.Markdown(
                    "**Input Format**: Accepts base64 data URLs "
                    "(e.g., `data:video/mp4;base64,...`) or file uploads."
                )
                with gr.Row():
                    with gr.Column():
                        api_media_input = gr.Textbox(
                            label="Media Input",
                            placeholder=(
                                "Paste base64 data URL "
                                "(data:video/mp4;base64,...) "
                                "or upload file below"
                            ),
                            lines=3,
                        )
                        api_file_input = gr.File(
                            label="Or Upload Image/Video File (optional)",
                            visible=True,
                        )
                        api_pupil_selection = gr.Dropdown(
                            ["left_pupil", "right_pupil", "both"],
                            value="both",
                            label="Pupil Selection",
                        )
                        api_model = gr.Dropdown(
                            ["ResNet18", "ResNet50"],
                            value="ResNet18",
                            label="Model",
                        )
                        api_blink_detection = gr.Checkbox(
                            value=True, label="Detect Blinks"
                        )
                        api_submit = gr.Button(
                            "Process Media", variant="primary"
                        )
                    with gr.Column():
                        api_output = gr.Image(label="Results")
                        api_text_output = gr.Textbox(
                            label="Analysis Results", lines=10
                        )
                # BUGFIX: route through the fallback wrapper and pass the
                # file input - previously the uploaded file was ignored and
                # process_media_unified_with_fallback was dead code.
                api_submit.click(
                    fn=process_media_unified_with_fallback,
                    inputs=[
                        api_media_input,
                        api_file_input,
                        api_pupil_selection,
                        api_model,
                        api_blink_detection,
                    ],
                    outputs=[api_output, api_text_output],
                )

        print("✅ Gradio interface created successfully")
        return demo

    except Exception as e:
        print(f"❌ Error creating Gradio interface: {e}")
        import traceback
        traceback.print_exc()
        raise e


def process_image_simple(image, pupil_selection, tv_model, blink_detection):
    """Simplified image processing function for gr.Interface."""
    result_image, result_text = process_image_gradio(
        image, pupil_selection, tv_model, blink_detection
    )
    return result_image, result_text


def process_video_simple(video_file, pupil_selection, tv_model, blink_detection):
    """Simplified video processing function for gr.Interface."""
    plot_img, csv_data, summary_text = process_video_gradio(
        video_file, pupil_selection, tv_model, blink_detection
    )
    # Combine summary and CSV data for single text output.
    combined_output = f"{summary_text}\n\n--- CSV Data ---\n{csv_data}"
    return plot_img, combined_output


if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.launch()