# pupillometry / gradio_app.py
# txarst's picture
# gradio base64 fix02
# 229daef
import sys
import os
import os.path as osp
import gradio as gr
import numpy as np
import tempfile
from PIL import Image, ImageOps
import cv2
import matplotlib.pyplot as plt
import io
import base64
root_path = osp.abspath(osp.join(__file__, osp.pardir))
sys.path.append(root_path)
from registry_utils import import_registered_modules
from gradio_utils import (
is_image,
is_video,
extract_frames,
resize_frame,
CAM_METHODS,
process_frames_gradio,
)
import_registered_modules()
def process_image_gradio(image, pupil_selection, tv_model, blink_detection):
    """
    Process a single image and return results for the Gradio interface.

    Args:
        image: PIL Image or numpy array
        pupil_selection: str - "left_pupil", "right_pupil", or "both"
        tv_model: str - "ResNet18" or "ResNet50"
        blink_detection: bool - whether to detect blinks

    Returns:
        tuple: (result_image, diameter_text) - a side-by-side visualization
        image and a newline-separated diameter summary, or an error
        placeholder image and an error message string.
    """
    try:
        # Normalize the input to a PIL Image.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # Apply EXIF orientation so phone photos are not shown rotated.
        image = ImageOps.exif_transpose(image)
        # Bound the working resolution before running the model.
        image = resize_frame(image, max_width=640, max_height=480)

        # Run the pupil-diameter pipeline on the single frame.
        input_frames, output_frames, predicted_diameters = process_frames_gradio(
            input_imgs=[image],
            tv_model=tv_model,
            pupil_selection=pupil_selection,
            blink_detection=blink_detection,
        )

        # Empty results mean face/eye detection failed.
        if not input_frames or not output_frames or not predicted_diameters:
            error_msg = "Could not detect face/eyes in the image. Please try with a clearer image showing eyes."
            error_img = Image.new('RGB', (400, 200), 'white')
            return error_img, error_msg

        results = []
        diameter_results = []
        for eye_type in input_frames.keys():
            input_img = input_frames[eye_type][-1]
            output_img = output_frames[eye_type][-1]
            diameter = predicted_diameters[eye_type][0]

            # Side-by-side comparison: input crop vs. CAM overlay.
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
            ax1.imshow(input_img)
            ax1.set_title(f"Input - {eye_type.replace('_', ' ').title()}")
            ax1.axis('off')
            ax2.imshow(output_img)
            ax2.set_title(f"CAM Overlay - {eye_type.replace('_', ' ').title()}")
            ax2.axis('off')
            plt.tight_layout()

            # Render the figure into a PIL image; close the specific figure
            # (not just the implicit current one) so repeated requests do not
            # leak matplotlib state.
            buf = io.BytesIO()
            fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
            plt.close(fig)
            buf.seek(0)
            plot_img = Image.open(buf)
            results.append(plot_img)

            # Non-numeric diameters (e.g. "blink") are reported verbatim.
            if isinstance(diameter, str):
                diameter_results.append(f"{eye_type.replace('_', ' ').title()}: {diameter}")
            else:
                diameter_results.append(f"{eye_type.replace('_', ' ').title()}: {diameter:.2f} mm")

        # Stitch per-eye panels horizontally when both eyes were processed.
        if len(results) == 1:
            final_image = results[0]
        else:
            total_width = sum(img.width for img in results)
            max_height = max(img.height for img in results)
            final_image = Image.new('RGB', (total_width, max_height), 'white')
            x_offset = 0
            for img in results:
                final_image.paste(img, (x_offset, 0))
                x_offset += img.width

        diameter_text = "\n".join(diameter_results)
        return final_image, diameter_text
    except Exception as e:
        # Surface the error in the UI instead of crashing the app.
        error_msg = f"Error processing image: {str(e)}"
        error_img = Image.new('RGB', (400, 200), 'white')
        return error_img, error_msg
def process_video_gradio(video_file, pupil_selection, tv_model, blink_detection):
    """
    Process a video file and return results for the Gradio interface.

    Args:
        video_file: file path or file object
        pupil_selection: str - "left_pupil", "right_pupil", or "both"
        tv_model: str - "ResNet18" or "ResNet50"
        blink_detection: bool - whether to detect blinks

    Returns:
        tuple: (results_plot, csv_data, summary_text) - a diameter-over-time
        plot image, CSV text of per-frame diameters, and a statistics summary.
    """
    try:
        # Accept either a file-like object (with .name) or a plain path.
        if hasattr(video_file, 'name'):
            video_path = video_file.name
        else:
            video_path = video_file

        video_frames = extract_frames(video_path)
        if not video_frames:
            return None, "No frames extracted from video", "Error: Could not process video"

        # Normalize every frame to a bounded-size PIL Image.
        resized_frames = []
        for frame in video_frames:
            if isinstance(frame, np.ndarray):
                frame = Image.fromarray(frame)
            input_img = resize_frame(frame, max_width=640, max_height=480)
            resized_frames.append(input_img)

        # Run the pupil-diameter pipeline over all frames.
        input_frames, output_frames, predicted_diameters = process_frames_gradio(
            input_imgs=resized_frames,
            tv_model=tv_model,
            pupil_selection=pupil_selection,
            blink_detection=blink_detection,
        )

        # Empty results mean the pipeline failed outright.
        if not input_frames or not output_frames or not predicted_diameters:
            error_msg = "Could not process video. MediaPipe may have issues in this environment."
            error_img = Image.new('RGB', (400, 200), 'white')
            return error_img, "", error_msg

        # One subplot per selected eye.
        fig, axes = plt.subplots(len(predicted_diameters), 1, figsize=(12, 6 * len(predicted_diameters)))
        if len(predicted_diameters) == 1:
            axes = [axes]

        summary_stats = []
        for idx, (eye_type, diameters) in enumerate(predicted_diameters.items()):
            # Entries may be strings (e.g. "blink"); keep only numeric samples
            # paired with their frame indices so the plot remains a numeric
            # time series instead of a mixed-type/categorical axis.
            numeric_points = [(i, d) for i, d in enumerate(diameters)
                              if isinstance(d, (int, float))]
            numeric_frames = [i for i, _ in numeric_points]
            numeric_diameters = [d for _, d in numeric_points]

            axes[idx].plot(numeric_frames, numeric_diameters, marker='o', markersize=2)
            axes[idx].set_title(f"Pupil Diameter Over Time - {eye_type.replace('_', ' ').title()}")
            axes[idx].set_xlabel("Frame Number")
            axes[idx].set_ylabel("Diameter (mm)")
            axes[idx].grid(True, alpha=0.3)

            # Summary statistics over the numeric samples only.
            if numeric_diameters:
                mean_diameter = np.mean(numeric_diameters)
                std_diameter = np.std(numeric_diameters)
                min_diameter = np.min(numeric_diameters)
                max_diameter = np.max(numeric_diameters)
                summary_stats.append(f"{eye_type.replace('_', ' ').title()}:")
                summary_stats.append(f" Mean: {mean_diameter:.2f} mm")
                summary_stats.append(f" Std: {std_diameter:.2f} mm")
                summary_stats.append(f" Min: {min_diameter:.2f} mm")
                summary_stats.append(f" Max: {max_diameter:.2f} mm")
                summary_stats.append("")

        plt.tight_layout()

        # Render the figure and close it explicitly to avoid leaking
        # matplotlib state across requests.
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
        plt.close(fig)
        buf.seek(0)
        plot_img = Image.open(buf)

        summary_text = f"Processed {len(video_frames)} frames\n\n" + "\n".join(summary_stats)

        # Build CSV text for download (join instead of quadratic +=).
        csv_lines = ["Frame,Eye_Type,Diameter_mm"]
        for eye_type, diameters in predicted_diameters.items():
            for frame_idx, diameter in enumerate(diameters):
                csv_lines.append(f"{frame_idx},{eye_type},{diameter}")
        csv_data = "\n".join(csv_lines) + "\n"

        return plot_img, csv_data, summary_text
    except Exception as e:
        error_msg = f"Error processing video: {str(e)}"
        error_img = Image.new('RGB', (400, 200), 'white')
        return error_img, "", error_msg
def process_base64_media(data_url, pupil_selection, tv_model, blink_detection):
    """
    Process base64 encoded media (images or videos).

    Args:
        data_url: str - Base64 data URL (e.g., "data:video/mp4;base64,...")
        pupil_selection: str - "left_pupil", "right_pupil", or "both"
        tv_model: str - "ResNet18" or "ResNet50"
        blink_detection: bool - whether to detect blinks

    Returns:
        tuple: (result_image, result_text)
    """
    # NOTE: base64, tempfile, os, io and PIL.Image are all imported at module
    # level; the previous function-local re-imports were redundant.
    try:
        if not data_url.startswith('data:'):
            raise ValueError("Invalid data URL format")

        # Split "data:<mime>[;base64],<payload>" into header and payload.
        header, base64_data = data_url.split(',', 1)
        mime_type = header.split(';')[0].replace('data:', '')
        file_data = base64.b64decode(base64_data)

        if mime_type.startswith('video/'):
            # Map MIME subtypes to the file extensions the pipeline expects.
            file_extension = mime_type.split('/')[-1]
            if file_extension == 'quicktime':
                file_extension = 'mov'
            elif file_extension == 'x-msvideo':
                file_extension = 'avi'

            if not is_video(file_extension):
                error_img = Image.new('RGB', (400, 200), 'white')
                return error_img, f"Unsupported video format: {file_extension}. Supported formats: mp4, avi, mov, mkv, webm, flv, wmv."

            # Video processing requires an on-disk file, so write the decoded
            # bytes to a temporary file and clean it up afterwards.
            with tempfile.NamedTemporaryFile(delete=False, suffix=f'.{file_extension}') as temp_file:
                temp_file.write(file_data)
                temp_file_path = temp_file.name
            try:
                plot_img, csv_data, summary_text = process_video_gradio(temp_file_path, pupil_selection, tv_model, blink_detection)
                combined_output = f"{summary_text}\n\n--- CSV Data ---\n{csv_data}"
                return plot_img, combined_output
            finally:
                if os.path.exists(temp_file_path):
                    os.unlink(temp_file_path)

        elif mime_type.startswith('image/'):
            file_extension = mime_type.split('/')[-1]
            if file_extension == 'jpeg':
                file_extension = 'jpg'

            if not is_image(file_extension):
                error_img = Image.new('RGB', (400, 200), 'white')
                return error_img, f"Unsupported image format: {file_extension}. Supported formats: png, jpg, jpeg, bmp, tiff, webp."

            # Decode the image in memory and reuse the single-image pipeline.
            image = Image.open(io.BytesIO(file_data))
            return process_image_gradio(image, pupil_selection, tv_model, blink_detection)

        else:
            # Unknown MIME type.
            error_img = Image.new('RGB', (400, 200), 'white')
            return error_img, f"Unsupported media type: {mime_type}. Please provide a video or image file."
    except Exception as e:
        error_img = Image.new('RGB', (400, 200), 'white')
        return error_img, f"Error processing base64 media: {str(e)}"
def process_media_unified_with_fallback(text_input, file_input, pupil_selection, tv_model, blink_detection):
    """
    Wrapper function that handles both text input (base64) and file input.

    Args:
        text_input: str - Base64 data URL or empty string
        file_input: File object or None
        pupil_selection: str - "left_pupil", "right_pupil", or "both"
        tv_model: str - "ResNet18" or "ResNet50"
        blink_detection: bool - whether to detect blinks

    Returns:
        tuple: (result_image, result_text)
    """
    # Prioritize text input (base64) if provided.
    if text_input and text_input.strip():
        return process_media_unified(text_input.strip(), pupil_selection, tv_model, blink_detection)
    # Fallback to file input if no text input.
    if file_input is not None:
        return process_media_unified(file_input, pupil_selection, tv_model, blink_detection)
    # No input provided at all: return a placeholder and an explanation.
    # (PIL.Image is already imported at module level; the former local
    # re-import was redundant.)
    error_img = Image.new('RGB', (400, 200), 'white')
    return error_img, "No input provided. Please provide either a base64 data URL or upload a file."
def process_media_unified(media_input, pupil_selection, tv_model, blink_detection):
    """
    Unified processing function that handles both images and videos.

    Args:
        media_input: Either an image (PIL), video file path, or base64 data URL
        pupil_selection: str - "left_pupil", "right_pupil", or "both"
        tv_model: str - "ResNet18" or "ResNet50"
        blink_detection: bool - whether to detect blinks

    Returns:
        tuple: (result_image, result_text)
    """
    try:
        # Reject missing or blank input up front.
        if media_input is None or (isinstance(media_input, str) and not media_input.strip()):
            error_img = Image.new('RGB', (400, 200), 'white')
            return error_img, "No media input provided. Please upload an image or video or provide a base64 data URL."

        # Base64 data URLs go through the dedicated decoder.
        if isinstance(media_input, str) and media_input.strip().startswith('data:'):
            return process_base64_media(media_input.strip(), pupil_selection, tv_model, blink_detection)

        # File-like objects expose .name; route by file extension.
        if hasattr(media_input, 'name'):
            file_path = media_input.name
            file_extension = os.path.splitext(file_path)[1][1:]  # strip the leading dot
            if is_video(file_extension):
                plot_img, csv_data, summary_text = process_video_gradio(media_input, pupil_selection, tv_model, blink_detection)
                combined_output = f"{summary_text}\n\n--- CSV Data ---\n{csv_data}"
                return plot_img, combined_output
            elif is_image(file_extension):
                image = Image.open(file_path)
                return process_image_gradio(image, pupil_selection, tv_model, blink_detection)
            else:
                # Unknown file type.
                error_img = Image.new('RGB', (400, 200), 'white')
                return error_img, f"Unsupported file type: {file_extension}. Supported video formats: mp4, avi, mov, mkv, webm, flv, wmv. Supported image formats: png, jpg, jpeg, bmp, tiff, webp."

        # Anything else (e.g. a PIL Image from gr.Image) is treated as an
        # image. The None case is impossible here - it was rejected at the
        # top - so the previous unreachable None fallback has been removed.
        return process_image_gradio(media_input, pupil_selection, tv_model, blink_detection)
    except Exception as e:
        error_msg = f"Error processing media: {str(e)}"
        error_img = Image.new('RGB', (400, 200), 'white')
        return error_img, error_msg
def create_gradio_interface():
    """Create and configure the Gradio interface with proper API support.

    Returns:
        gr.Blocks: the assembled demo application.

    Raises:
        Exception: re-raised after logging if interface construction fails.
    """
    print("🔧 Creating Gradio interface...")
    try:
        # Three tabs: image processing, video processing, and an API-testing
        # tab that accepts base64 data URLs or file uploads.
        with gr.Blocks(title="👁️ PupilSense 👁️🕵️‍♂️") as demo:
            gr.Markdown("# 👁️ PupilSense - Pupil Diameter Analysis")
            gr.Markdown("Upload an image or video to estimate pupil diameter using deep learning models.")
            with gr.Tab("Image Processing"):
                with gr.Row():
                    with gr.Column():
                        image_input = gr.Image(type="pil", label="Upload Image")
                        image_pupil_selection = gr.Dropdown(
                            ["left_pupil", "right_pupil", "both"],
                            value="both",
                            label="Pupil Selection"
                        )
                        image_model = gr.Dropdown(
                            ["ResNet18", "ResNet50"],
                            value="ResNet18",
                            label="Model"
                        )
                        image_blink_detection = gr.Checkbox(value=True, label="Detect Blinks")
                        image_submit = gr.Button("Process Image", variant="primary")
                    with gr.Column():
                        image_output = gr.Image(label="Results")
                        image_text_output = gr.Textbox(label="Pupil Diameter Results", lines=5)
                image_submit.click(
                    fn=process_image_simple,
                    inputs=[image_input, image_pupil_selection, image_model, image_blink_detection],
                    outputs=[image_output, image_text_output]
                )
            with gr.Tab("Video Processing"):
                with gr.Row():
                    with gr.Column():
                        video_input = gr.Video(label="Upload Video")
                        video_pupil_selection = gr.Dropdown(
                            ["left_pupil", "right_pupil", "both"],
                            value="both",
                            label="Pupil Selection"
                        )
                        video_model = gr.Dropdown(
                            ["ResNet18", "ResNet50"],
                            value="ResNet18",
                            label="Model"
                        )
                        video_blink_detection = gr.Checkbox(value=True, label="Detect Blinks")
                        video_submit = gr.Button("Process Video", variant="primary")
                    with gr.Column():
                        video_output = gr.Image(label="Diameter Analysis")
                        video_text_output = gr.Textbox(label="Summary Statistics", lines=10)
                video_submit.click(
                    fn=process_video_simple,
                    inputs=[video_input, video_pupil_selection, video_model, video_blink_detection],
                    outputs=[video_output, video_text_output]
                )
            # Unified API endpoint for both images and videos.
            with gr.Tab("API Testing"):
                gr.Markdown("### API Endpoint for External Access")
                gr.Markdown("This endpoint can process both images and videos programmatically.")
                gr.Markdown("**Input Format**: Accepts base64 data URLs (e.g., `data:video/mp4;base64,...`) or file uploads.")
                with gr.Row():
                    with gr.Column():
                        api_media_input = gr.Textbox(
                            label="Media Input",
                            placeholder="Paste base64 data URL (data:video/mp4;base64,...) or upload file below",
                            lines=3
                        )
                        api_file_input = gr.File(label="Or Upload Image/Video File (optional)", visible=True)
                        api_pupil_selection = gr.Dropdown(
                            ["left_pupil", "right_pupil", "both"],
                            value="both",
                            label="Pupil Selection"
                        )
                        api_model = gr.Dropdown(
                            ["ResNet18", "ResNet50"],
                            value="ResNet18",
                            label="Model"
                        )
                        api_blink_detection = gr.Checkbox(value=True, label="Detect Blinks")
                        api_submit = gr.Button("Process Media", variant="primary")
                    with gr.Column():
                        api_output = gr.Image(label="Results")
                        api_text_output = gr.Textbox(label="Analysis Results", lines=10)
                # FIX: the file-upload widget was previously created but never
                # wired into the handler, so uploads on this tab were ignored.
                # Route through the fallback wrapper, which prefers the base64
                # textbox and falls back to the uploaded file. api_name keeps
                # the original auto-generated endpoint name stable for
                # existing API clients.
                api_submit.click(
                    fn=process_media_unified_with_fallback,
                    inputs=[api_media_input, api_file_input, api_pupil_selection, api_model, api_blink_detection],
                    outputs=[api_output, api_text_output],
                    api_name="process_media_unified"
                )
        print("✅ Gradio interface created successfully")
        return demo
    except Exception as e:
        print(f"❌ Error creating Gradio interface: {e}")
        import traceback
        traceback.print_exc()
        raise e
def process_image_simple(image, pupil_selection, tv_model, blink_detection):
    """Thin adapter so simple event wiring can reuse the image pipeline."""
    # Delegate directly; the pipeline already returns (image, text).
    return process_image_gradio(image, pupil_selection, tv_model, blink_detection)
def process_video_simple(video_file, pupil_selection, tv_model, blink_detection):
    """Run the video pipeline and merge summary plus CSV into one text output."""
    plot_img, csv_rows, summary = process_video_gradio(video_file, pupil_selection, tv_model, blink_detection)
    # A single textbox shows both the statistics summary and the raw CSV rows.
    return plot_img, f"{summary}\n\n--- CSV Data ---\n{csv_rows}"
if __name__ == "__main__":
    # Build the UI and start the local Gradio server.
    app = create_gradio_interface()
    app.launch()