# Winner_Take_All / app.py
# Uploaded by TibbtechUser — commit 50a6e5b ("Upload app.py", verified).
import gradio as gr
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from tibbtech import neuroai
import tempfile
import os
import imageio
from matplotlib import cm
import imageio.v3 as iio
# Image processing
def process_image(img_path: str):
    """Run the winner-take-all model on one image and colour-map the result.

    Args:
        img_path: Path to an image file readable by ``matplotlib.pyplot.imread``.

    Returns:
        ``(img_path, inferno_img)`` -- the unchanged input path (shown in the
        "Original" gallery) and a PIL RGB image of the WTA response rendered
        with the ``inferno`` colormap.
    """
    img_np = plt.imread(img_path)
    # NOTE(review): plt.imread returns floats in [0, 1] for PNG but uint8 in
    # [0, 255] for JPEG and other Pillow-decoded formats, so neuroai.wta sees
    # differently scaled input depending on file type -- confirm which range
    # it expects before normalising here.
    if img_np.ndim == 3 and img_np.shape[2] > 3:
        # Keep only the first three channels (e.g. drop alpha from RGBA).
        img_np = img_np[:, :, :3]

    # Assumed (1, H, W) per the original note -- TODO confirm.
    response = neuroai.wta(img_np).numpy()

    # Colormaps interpret integer arrays as raw lookup-table indices, so force
    # a float dtype and clamp to [0, 1] (as process_video does) before mapping.
    level = np.clip(np.asarray(response[0], dtype=np.float64), 0.0, 1.0)
    rgb = (plt.cm.inferno(level)[..., :3] * 255).astype(np.uint8)
    return img_path, Image.fromarray(rgb)
# Video processing with imageio (more browser-compatible)
def process_video(video_path: str, fps: float = 30):
    """Run the winner-take-all model on a video and encode an inferno-coloured MP4.

    Args:
        video_path: Path to the input video, passed straight to ``neuroai.run_wta``.
        fps: Frame rate written into the output file (default 30, as before).

    Returns:
        ``(video_path, out_path)`` -- the unchanged input path and the path of
        a new temporary ``.mp4`` (libx264, yuv420p). The file is not deleted
        here; the caller (e.g. the gallery) owns it.

    Raises:
        ValueError: if the model output is not 3-D ``(T, H, W)`` or has no frames.
    """
    frames = np.asarray(neuroai.run_wta(video_path))  # expected (T, H, W)
    if frames.ndim != 3:
        raise ValueError(
            f"expected a (T, H, W) array from run_wta, got shape {frames.shape}"
        )
    if frames.shape[0] == 0:
        raise ValueError("run_wta returned no frames")

    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; the module
    # attribute works on old and new versions alike.
    cmap = cm.inferno

    # mkstemp creates the file atomically (tempfile.mktemp is deprecated and
    # race-prone); close our handle so the ffmpeg writer can reopen the path.
    fd, out_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    try:
        with imageio.get_writer(
            out_path,
            fps=fps,
            codec="libx264",
            ffmpeg_params=["-pix_fmt", "yuv420p"],
        ) as writer:
            for frame in frames:
                # Colormaps read integer arrays as LUT indices, so map floats
                # clamped to [0, 1]; done per frame to avoid a full-video copy.
                level = np.clip(frame.astype(np.float64), 0.0, 1.0)
                writer.append_data((cmap(level)[..., :3] * 255).astype(np.uint8))
    except BaseException:
        # Don't leave a half-written video behind if encoding fails.
        if os.path.exists(out_path):
            os.remove(out_path)
        raise
    return video_path, out_path
# Unified processing
def process_files(files):
    """Dispatch uploaded files to the image or video pipeline and fill both galleries.

    Args:
        files: What the ``gr.File`` component delivers -- ``None`` when the
            upload is cleared, a single file, or a list of files. Each file may
            be a path string or an object exposing its path as ``.name``
            (older Gradio tempfile wrappers).

    Returns:
        A pair of ``gr.update`` objects for the "Original" and "Processed"
        galleries; each gallery is visible only if it received at least one
        item. Files with unsupported extensions are skipped.
    """
    image_exts = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
    video_exts = {".mp4", ".avi", ".mov", ".mkv"}

    if files is None:
        # The change event also fires (with None) when the upload is cleared.
        files = []
    elif not isinstance(files, (list, tuple)):
        files = [files]

    originals = []
    processed = []
    for file in files:
        # Newer Gradio passes plain path strings; older versions pass
        # tempfile wrappers whose path is in `.name`.
        path = getattr(file, "name", file)
        ext = os.path.splitext(path)[1].lower()
        if ext in image_exts:
            orig, proc = process_image(path)
        elif ext in video_exts:
            orig, proc = process_video(path)
        else:
            continue
        originals.append(orig)
        processed.append(proc)

    return (
        gr.update(visible=bool(originals), value=originals),
        gr.update(visible=bool(processed), value=processed),
    )
# Gradio interface
with gr.Blocks() as demo:
    # Page title plus the single upload widget that drives both galleries.
    gr.Markdown("## Upload Image(s) or Video(s) for Processing")
    file_input = gr.File(
        file_types=["image", "video"],
        label="Upload Images/Videos",
    )

    # Side-by-side result galleries, hidden until process_files populates them.
    with gr.Row():
        original_gallery = gr.Gallery(label="Original", visible=False)
        processed_gallery = gr.Gallery(label="Processed", visible=False)

    gallery_outputs = [original_gallery, processed_gallery]

    # Re-run the pipeline every time the uploaded selection changes.
    file_input.change(fn=process_files, inputs=file_input, outputs=gallery_outputs)

    # One row per example; each row supplies the value for file_input.
    example_files = [
        [sample]
        for sample in ("sample1.jpg", "sample2.jpg", "video_sample.mp4")
    ]
    gr.Examples(
        examples=example_files,
        inputs=file_input,
        fn=process_files,
        outputs=gallery_outputs,
        label="Try our sample images or video",
    )

demo.launch(debug=True)