import sys
import os
import os.path as osp
import gradio as gr
import numpy as np
import tempfile
from PIL import Image, ImageOps
import cv2
import matplotlib.pyplot as plt
import io
import base64

root_path = osp.abspath(osp.join(__file__, osp.pardir))
sys.path.append(root_path)

from registry_utils import import_registered_modules
from gradio_utils import (
    is_image,
    is_video,
    extract_frames,
    resize_frame,
    CAM_METHODS,
    process_frames_gradio,
)

import_registered_modules()
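
# Note (assumption, not part of the original app): in a headless deployment such as a
# Hugging Face Space there is no display server, so pinning matplotlib to the
# non-interactive Agg backend before any figures are drawn avoids GUI-backend warnings.
# A minimal sketch:
#
#   import matplotlib
#   matplotlib.use("Agg")  # select the non-interactive backend before plotting
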

def process_image_gradio(image, pupil_selection, tv_model, blink_detection):
    """
    Process a single image and return results for the Gradio interface.

    Args:
        image: PIL Image or numpy array
        pupil_selection: str - "left_pupil", "right_pupil", or "both"
        tv_model: str - "ResNet18" or "ResNet50"
        blink_detection: bool - whether to detect blinks

    Returns:
        tuple: (result_image, diameter_text)
    """
    try:
        # Convert to PIL Image if needed
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Handle EXIF rotation
        image = ImageOps.exif_transpose(image)

        # Resize image
        image = resize_frame(image, max_width=640, max_height=480)

        # Process the image using the Gradio-compatible function
        input_frames, output_frames, predicted_diameters = process_frames_gradio(
            input_imgs=[image],
            tv_model=tv_model,
            pupil_selection=pupil_selection,
            blink_detection=blink_detection,
        )

        # Check if processing failed (empty results)
        if not input_frames or not output_frames or not predicted_diameters:
            error_msg = (
                "Could not detect face/eyes in the image. "
                "Please try with a clearer image showing the eyes."
            )
            error_img = Image.new('RGB', (400, 200), 'white')
            return error_img, error_msg

        # Create visualization
        results = []
        diameter_results = []
        for eye_type in input_frames.keys():
            input_img = input_frames[eye_type][-1]
            output_img = output_frames[eye_type][-1]
            diameter = predicted_diameters[eye_type][0]

            # Create side-by-side comparison
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
            ax1.imshow(input_img)
            ax1.set_title(f"Input - {eye_type.replace('_', ' ').title()}")
            ax1.axis('off')
            ax2.imshow(output_img)
            ax2.set_title(f"CAM Overlay - {eye_type.replace('_', ' ').title()}")
            ax2.axis('off')
            plt.tight_layout()

            # Convert plot to image
            buf = io.BytesIO()
            plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
            buf.seek(0)
            plot_img = Image.open(buf)
            plt.close()
            results.append(plot_img)

            # Format diameter result
            if isinstance(diameter, str):
                diameter_results.append(f"{eye_type.replace('_', ' ').title()}: {diameter}")
            else:
                diameter_results.append(f"{eye_type.replace('_', ' ').title()}: {diameter:.2f} mm")

        # Combine results if multiple eyes were processed
        if len(results) == 1:
            final_image = results[0]
        else:
            total_width = sum(img.width for img in results)
            max_height = max(img.height for img in results)
            final_image = Image.new('RGB', (total_width, max_height), 'white')
            x_offset = 0
            for img in results:
                final_image.paste(img, (x_offset, 0))
                x_offset += img.width

        diameter_text = "\n".join(diameter_results)
        return final_image, diameter_text

    except Exception as e:
        error_msg = f"Error processing image: {str(e)}"
        error_img = Image.new('RGB', (400, 200), 'white')
        return error_img, error_msg

def process_video_gradio(video_file, pupil_selection, tv_model, blink_detection):
    """
    Process a video file and return results for the Gradio interface.

    Args:
        video_file: file path or file object
        pupil_selection: str - "left_pupil", "right_pupil", or "both"
        tv_model: str - "ResNet18" or "ResNet50"
        blink_detection: bool - whether to detect blinks

    Returns:
        tuple: (plot_img, csv_data, summary_text)
    """
    try:
        # Handle video file
        if hasattr(video_file, 'name'):
            video_path = video_file.name
        else:
            video_path = video_file

        # Extract frames
        video_frames = extract_frames(video_path)
        if not video_frames:
            return None, "No frames extracted from video", "Error: Could not process video"

        # Resize frames
        resized_frames = []
        for frame in video_frames:
            if isinstance(frame, np.ndarray):
                frame = Image.fromarray(frame)
            input_img = resize_frame(frame, max_width=640, max_height=480)
            resized_frames.append(input_img)

        # Process video frames using the Gradio-compatible function
        input_frames, output_frames, predicted_diameters = process_frames_gradio(
            input_imgs=resized_frames,
            tv_model=tv_model,
            pupil_selection=pupil_selection,
            blink_detection=blink_detection,
        )

        # Check if processing failed (empty results)
        if not input_frames or not output_frames or not predicted_diameters:
            error_msg = "Could not process video. MediaPipe may have issues in this environment."
            error_img = Image.new('RGB', (400, 200), 'white')
            return error_img, "", error_msg

        # Create results visualization
        fig, axes = plt.subplots(len(predicted_diameters), 1, figsize=(12, 6 * len(predicted_diameters)))
        if len(predicted_diameters) == 1:
            axes = [axes]

        summary_stats = []
        for idx, (eye_type, diameters) in enumerate(predicted_diameters.items()):
            # Filter out non-numeric values (like "blink")
            numeric_diameters = [d for d in diameters if isinstance(d, (int, float))]
            frame_numbers = list(range(len(diameters)))

            # Plot diameter over time; non-numeric entries (e.g. "blink") become gaps
            plot_values = [d if isinstance(d, (int, float)) else np.nan for d in diameters]
            axes[idx].plot(frame_numbers, plot_values, marker='o', markersize=2)
            axes[idx].set_title(f"Pupil Diameter Over Time - {eye_type.replace('_', ' ').title()}")
            axes[idx].set_xlabel("Frame Number")
            axes[idx].set_ylabel("Diameter (mm)")
            axes[idx].grid(True, alpha=0.3)

            # Calculate statistics
            if numeric_diameters:
                mean_diameter = np.mean(numeric_diameters)
                std_diameter = np.std(numeric_diameters)
                min_diameter = np.min(numeric_diameters)
                max_diameter = np.max(numeric_diameters)
                summary_stats.append(f"{eye_type.replace('_', ' ').title()}:")
                summary_stats.append(f"  Mean: {mean_diameter:.2f} mm")
                summary_stats.append(f"  Std: {std_diameter:.2f} mm")
                summary_stats.append(f"  Min: {min_diameter:.2f} mm")
                summary_stats.append(f"  Max: {max_diameter:.2f} mm")
                summary_stats.append("")

        plt.tight_layout()

        # Convert plot to image
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
        buf.seek(0)
        plot_img = Image.open(buf)
        plt.close()

        # Create summary text
        summary_text = f"Processed {len(video_frames)} frames\n\n" + "\n".join(summary_stats)

        # Create CSV data for download
        csv_data = "Frame,Eye_Type,Diameter_mm\n"
        for eye_type, diameters in predicted_diameters.items():
            for frame_idx, diameter in enumerate(diameters):
                csv_data += f"{frame_idx},{eye_type},{diameter}\n"

        return plot_img, csv_data, summary_text

    except Exception as e:
        error_msg = f"Error processing video: {str(e)}"
        error_img = Image.new('RGB', (400, 200), 'white')
        return error_img, "", error_msg

def process_base64_media(data_url, pupil_selection, tv_model, blink_detection):
    """
    Process base64-encoded media (images or videos).

    Args:
        data_url: str - Base64 data URL (e.g., "data:video/mp4;base64,...")
        pupil_selection: str - "left_pupil", "right_pupil", or "both"
        tv_model: str - "ResNet18" or "ResNet50"
        blink_detection: bool - whether to detect blinks

    Returns:
        tuple: (result_image, result_text)
    """
    try:
        # Parse the data URL
        if not data_url.startswith('data:'):
            raise ValueError("Invalid data URL format")

        # Extract MIME type and base64 data
        header, base64_data = data_url.split(',', 1)
        mime_type = header.split(';')[0].replace('data:', '')

        # Decode base64 data
        file_data = base64.b64decode(base64_data)

        # Determine file type from MIME type
        if mime_type.startswith('video/'):
            # Handle video
            file_extension = mime_type.split('/')[-1]
            if file_extension == 'quicktime':
                file_extension = 'mov'
            elif file_extension == 'x-msvideo':
                file_extension = 'avi'

            # Check that it is a supported video format
            if not is_video(file_extension):
                error_img = Image.new('RGB', (400, 200), 'white')
                return error_img, (
                    f"Unsupported video format: {file_extension}. "
                    "Supported formats: mp4, avi, mov, mkv, webm, flv, wmv."
                )

            # Write the decoded bytes to a temporary file for video processing
            with tempfile.NamedTemporaryFile(delete=False, suffix=f'.{file_extension}') as temp_file:
                temp_file.write(file_data)
                temp_file_path = temp_file.name

            try:
                # Process the video using the temporary file
                plot_img, csv_data, summary_text = process_video_gradio(
                    temp_file_path, pupil_selection, tv_model, blink_detection
                )
                combined_output = f"{summary_text}\n\n--- CSV Data ---\n{csv_data}"
                return plot_img, combined_output
            finally:
                # Clean up the temporary file
                if os.path.exists(temp_file_path):
                    os.unlink(temp_file_path)

        elif mime_type.startswith('image/'):
            # Handle image
            file_extension = mime_type.split('/')[-1]
            if file_extension == 'jpeg':
                file_extension = 'jpg'

            # Check that it is a supported image format
            if not is_image(file_extension):
                error_img = Image.new('RGB', (400, 200), 'white')
                return error_img, (
                    f"Unsupported image format: {file_extension}. "
                    "Supported formats: png, jpg, jpeg, bmp, tiff, webp."
                )

            # Decode into a PIL Image and reuse the image pipeline
            image = Image.open(io.BytesIO(file_data))
            return process_image_gradio(image, pupil_selection, tv_model, blink_detection)

        else:
            # Unknown MIME type
            error_img = Image.new('RGB', (400, 200), 'white')
            return error_img, f"Unsupported media type: {mime_type}. Please provide a video or image file."

    except Exception as e:
        error_img = Image.new('RGB', (400, 200), 'white')
        return error_img, f"Error processing base64 media: {str(e)}"
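
# Example (sketch): building a data URL that process_base64_media accepts from a local
# file. The file name below is hypothetical, and mimetypes is only used here to derive
# the "data:<mime>;base64," prefix that the parser above expects.
#
#   import base64, mimetypes
#   path = "sample_eye_image.jpg"  # hypothetical test file
#   mime, _ = mimetypes.guess_type(path)
#   with open(path, "rb") as f:
#       data_url = f"data:{mime};base64," + base64.b64encode(f.read()).decode("ascii")
#   result_img, result_text = process_base64_media(data_url, "both", "ResNet18", True)
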

def process_media_unified_with_fallback(text_input, file_input, pupil_selection, tv_model, blink_detection):
    """
    Wrapper function that handles both text input (base64) and file input.

    Args:
        text_input: str - Base64 data URL or empty string
        file_input: File object or None
        pupil_selection: str - "left_pupil", "right_pupil", or "both"
        tv_model: str - "ResNet18" or "ResNet50"
        blink_detection: bool - whether to detect blinks

    Returns:
        tuple: (result_image, result_text)
    """
    # Prioritize text input (base64) if provided
    if text_input and text_input.strip():
        return process_media_unified(text_input.strip(), pupil_selection, tv_model, blink_detection)

    # Fall back to file input if no text input was given
    if file_input is not None:
        return process_media_unified(file_input, pupil_selection, tv_model, blink_detection)

    # No input provided
    error_img = Image.new('RGB', (400, 200), 'white')
    return error_img, "No input provided. Please provide either a base64 data URL or upload a file."

def process_media_unified(media_input, pupil_selection, tv_model, blink_detection):
    """
    Unified processing function that handles both images and videos.

    Args:
        media_input: a PIL Image, a video file path/object, or a base64 data URL
        pupil_selection: str - "left_pupil", "right_pupil", or "both"
        tv_model: str - "ResNet18" or "ResNet50"
        blink_detection: bool - whether to detect blinks

    Returns:
        tuple: (result_image, result_text)
    """
    try:
        # Handle None or empty input
        if media_input is None or (isinstance(media_input, str) and not media_input.strip()):
            error_img = Image.new('RGB', (400, 200), 'white')
            return error_img, "No media input provided. Please upload an image or video or provide a base64 data URL."

        # Check if the input is a base64 data URL (string input)
        if isinstance(media_input, str) and media_input.strip().startswith('data:'):
            return process_base64_media(media_input.strip(), pupil_selection, tv_model, blink_detection)

        # Check if the input is a file object (image or video upload)
        if hasattr(media_input, 'name'):
            file_path = media_input.name
            # Extract the file extension from the path (without the dot)
            file_extension = os.path.splitext(file_path)[1][1:]

            if is_video(file_extension):
                plot_img, csv_data, summary_text = process_video_gradio(
                    media_input, pupil_selection, tv_model, blink_detection
                )
                combined_output = f"{summary_text}\n\n--- CSV Data ---\n{csv_data}"
                return plot_img, combined_output
            elif is_image(file_extension):
                # Load the file as a PIL Image and reuse the image pipeline
                image = Image.open(file_path)
                return process_image_gradio(image, pupil_selection, tv_model, blink_detection)
            else:
                # Unknown file type
                error_img = Image.new('RGB', (400, 200), 'white')
                return error_img, (
                    f"Unsupported file type: {file_extension}. "
                    "Supported video formats: mp4, avi, mov, mkv, webm, flv, wmv. "
                    "Supported image formats: png, jpg, jpeg, bmp, tiff, webp."
                )

        # Otherwise treat the input as a PIL Image (or array) and process it directly
        return process_image_gradio(media_input, pupil_selection, tv_model, blink_detection)

    except Exception as e:
        error_msg = f"Error processing media: {str(e)}"
        error_img = Image.new('RGB', (400, 200), 'white')
        return error_img, error_msg

def create_gradio_interface():
    """Create and configure the Gradio interface with proper API support."""
    print("🔧 Creating Gradio interface...")
    try:
        # Create a unified interface that can handle both images and videos
        with gr.Blocks(title="👁️ PupilSense 👁️🕵️♂️") as demo:
            gr.Markdown("# 👁️ PupilSense - Pupil Diameter Analysis")
            gr.Markdown("Upload an image or video to estimate pupil diameter using deep learning models.")

            with gr.Tab("Image Processing"):
                with gr.Row():
                    with gr.Column():
                        image_input = gr.Image(type="pil", label="Upload Image")
                        image_pupil_selection = gr.Dropdown(
                            ["left_pupil", "right_pupil", "both"],
                            value="both",
                            label="Pupil Selection",
                        )
                        image_model = gr.Dropdown(
                            ["ResNet18", "ResNet50"],
                            value="ResNet18",
                            label="Model",
                        )
                        image_blink_detection = gr.Checkbox(value=True, label="Detect Blinks")
                        image_submit = gr.Button("Process Image", variant="primary")
                    with gr.Column():
                        image_output = gr.Image(label="Results")
                        image_text_output = gr.Textbox(label="Pupil Diameter Results", lines=5)

                image_submit.click(
                    fn=process_image_simple,
                    inputs=[image_input, image_pupil_selection, image_model, image_blink_detection],
                    outputs=[image_output, image_text_output],
                )

            with gr.Tab("Video Processing"):
                with gr.Row():
                    with gr.Column():
                        video_input = gr.Video(label="Upload Video")
                        video_pupil_selection = gr.Dropdown(
                            ["left_pupil", "right_pupil", "both"],
                            value="both",
                            label="Pupil Selection",
                        )
                        video_model = gr.Dropdown(
                            ["ResNet18", "ResNet50"],
                            value="ResNet18",
                            label="Model",
                        )
                        video_blink_detection = gr.Checkbox(value=True, label="Detect Blinks")
                        video_submit = gr.Button("Process Video", variant="primary")
                    with gr.Column():
                        video_output = gr.Image(label="Diameter Analysis")
                        video_text_output = gr.Textbox(label="Summary Statistics", lines=10)

                video_submit.click(
                    fn=process_video_simple,
                    inputs=[video_input, video_pupil_selection, video_model, video_blink_detection],
                    outputs=[video_output, video_text_output],
                )

            # Unified API endpoint that can handle both images and videos
            with gr.Tab("API Testing"):
                gr.Markdown("### API Endpoint for External Access")
                gr.Markdown("This endpoint can process both images and videos programmatically.")
                gr.Markdown("**Input Format**: Accepts base64 data URLs (e.g., `data:video/mp4;base64,...`) or file uploads.")
                with gr.Row():
                    with gr.Column():
                        api_media_input = gr.Textbox(
                            label="Media Input",
                            placeholder="Paste base64 data URL (data:video/mp4;base64,...) or upload a file below",
                            lines=3,
                        )
                        api_file_input = gr.File(label="Or Upload Image/Video File (optional)", visible=True)
                        api_pupil_selection = gr.Dropdown(
                            ["left_pupil", "right_pupil", "both"],
                            value="both",
                            label="Pupil Selection",
                        )
                        api_model = gr.Dropdown(
                            ["ResNet18", "ResNet50"],
                            value="ResNet18",
                            label="Model",
                        )
                        api_blink_detection = gr.Checkbox(value=True, label="Detect Blinks")
                        api_submit = gr.Button("Process Media", variant="primary")
                    with gr.Column():
                        api_output = gr.Image(label="Results")
                        api_text_output = gr.Textbox(label="Analysis Results", lines=10)

                # Route the API tab through the fallback wrapper so that both the
                # base64 textbox and the optional file upload are handled.
                api_submit.click(
                    fn=process_media_unified_with_fallback,
                    inputs=[api_media_input, api_file_input, api_pupil_selection, api_model, api_blink_detection],
                    outputs=[api_output, api_text_output],
                )

        print("✅ Gradio interface created successfully")
        return demo

    except Exception as e:
        print(f"❌ Error creating Gradio interface: {e}")
        import traceback
        traceback.print_exc()
        raise

def process_image_simple(image, pupil_selection, tv_model, blink_detection):
    """Simplified image processing function for the Gradio interface."""
    result_image, result_text = process_image_gradio(image, pupil_selection, tv_model, blink_detection)
    return result_image, result_text


def process_video_simple(video_file, pupil_selection, tv_model, blink_detection):
    """Simplified video processing function for the Gradio interface."""
    plot_img, csv_data, summary_text = process_video_gradio(video_file, pupil_selection, tv_model, blink_detection)
    # Combine the summary and CSV data into a single text output
    combined_output = f"{summary_text}\n\n--- CSV Data ---\n{csv_data}"
    return plot_img, combined_output


if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.launch()
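
# --- Example client usage (sketch) --------------------------------------------------
# A minimal sketch of calling the "Process Media" endpoint from outside the app with
# gradio_client, assuming the Space is publicly reachable. The Space URL, file name,
# and api_name below are placeholders -- the actual endpoint name is auto-generated by
# Gradio and can be looked up on the app's "Use via API" page (/view_api).
#
#   from gradio_client import Client
#   import base64
#
#   client = Client("https://<user>-<space>.hf.space")  # hypothetical Space URL
#   with open("eye_video.mp4", "rb") as f:               # hypothetical video file
#       data_url = "data:video/mp4;base64," + base64.b64encode(f.read()).decode("ascii")
#   result_image, result_text = client.predict(
#       data_url, None, "both", "ResNet18", True,        # None for the optional file input
#       api_name="/process_media_unified_with_fallback",  # assumed name; verify via /view_api
#   )
#   print(result_text)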