File size: 2,901 Bytes
8c986fb
acbd068
92494bc
acbd068
 
 
 
 
 
6818fcc
acbd068
 
 
 
92494bc
 
 
acbd068
92494bc
 
 
acbd068
8c986fb
acbd068
 
 
 
8c986fb
acbd068
 
50a6e5b
8c986fb
acbd068
 
 
8c986fb
50a6e5b
 
 
 
8c986fb
50a6e5b
acbd068
8c986fb
acbd068
 
 
 
8c986fb
acbd068
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c986fb
acbd068
 
 
 
 
8c986fb
acbd068
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c986fb
92494bc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from tibbtech import neuroai
import tempfile
import os
import imageio
from matplotlib import cm
import imageio.v3 as iio 

# Image processing
def process_image(img_path: str):
    """Run the WTA model on one image file and render its response with inferno.

    Args:
        img_path: Path of the image on disk; read with ``plt.imread``.

    Returns:
        ``(img_path, rendered)`` — the untouched input path (shown in the
        "Original" gallery) and a ``PIL.Image`` holding the inferno-coloured
        RGB rendering of the first slice of the model output.
    """
    pixels = plt.imread(img_path)

    # Keep only the first three channels (e.g. drop a PNG alpha channel)
    # before handing the array to the model.
    has_extra_channels = pixels.ndim == 3 and pixels.shape[2] > 3
    if has_extra_channels:
        pixels = pixels[:, :, :3]

    # NOTE(review): output assumed to be (1, H, W) per the original note —
    # TODO confirm against neuroai.wta.
    response = neuroai.wta(pixels).numpy()
    first_slice = response[0]

    # The colormap yields RGBA floats; drop alpha and scale to 8-bit.
    rgba = plt.cm.inferno(first_slice)
    rgb_bytes = (rgba[:, :, :3] * 255).astype(np.uint8)
    rendered = Image.fromarray(rgb_bytes)

    return img_path, rendered

# Video processing with imageio (more browser-compatible)
def process_video(video_path: str, fps: int = 30):
    """Run the WTA model over a video and encode an inferno-coloured MP4.

    Args:
        video_path: Path of the source video; passed to ``neuroai.run_wta``.
        fps: Frame rate of the encoded output. Defaults to 30, the value that
            was previously hard-coded.

    Returns:
        ``(video_path, out_path)`` — the original path (for the "Original"
        gallery) and the path of the newly written MP4. The output file is
        not deleted here; the caller owns it.

    Raises:
        ValueError: if the model output is not 3-dimensional.
    """
    video_tensor = neuroai.run_wta(video_path)
    # NOTE(review): expected layout is (T, H, W) per the original note —
    # TODO confirm against neuroai.run_wta.
    if len(video_tensor.shape) != 3:
        raise ValueError(
            f"expected a 3-D (T, H, W) result, got shape {tuple(video_tensor.shape)}"
        )

    # Colormaps map [0, 1] onto the palette; clip so out-of-range values saturate.
    frames = np.clip(video_tensor, 0, 1)
    # plt.cm.inferno works on every Matplotlib version; cm.get_cmap was
    # removed in Matplotlib 3.9.
    cmap = plt.cm.inferno

    # mkstemp reserves the name atomically (tempfile.mktemp is race-prone).
    # Close our handle: ffmpeg reopens and overwrites the path itself.
    fd, out_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)

    try:
        # H.264 + yuv420p keeps the result playable in web browsers.
        with imageio.get_writer(
            out_path, fps=fps, codec="libx264", ffmpeg_params=["-pix_fmt", "yuv420p"]
        ) as writer:
            for frame in frames:
                rgb = (cmap(frame)[..., :3] * 255).astype(np.uint8)
                writer.append_data(rgb)
    except Exception:
        # Don't leave a half-written temp file behind on failure.
        os.remove(out_path)
        raise

    return video_path, out_path

# Unified processing
def process_files(files):
    """Dispatch uploaded files to the image or video pipeline.

    Args:
        files: The value delivered by the ``gr.File`` input. This is a single
            upload or a list of uploads. Each item is either a filesystem path
            (``str`` / ``os.PathLike``, as Gradio 4+ returns) or a tempfile-like
            object exposing ``.name`` (as Gradio 3 returns). It is ``None``
            when the input is cleared.

    Returns:
        A pair of ``gr.update`` objects for the original and processed
        galleries. Each gallery is hidden when it has nothing to show. Files
        whose extension is not recognised are skipped.
    """
    # A cleared input fires `change` with None; treat it as "no files".
    if files is None:
        files = []
    elif not isinstance(files, list):
        files = [files]

    image_exts = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"}
    video_exts = {".mp4", ".avi", ".mov", ".mkv"}

    originals = []
    processed = []

    for file in files:
        if file is None:
            continue
        # Support both plain paths and tempfile wrappers. Path.name would be
        # the basename only, so PathLike is converted with os.fspath instead.
        if isinstance(file, (str, os.PathLike)):
            path = os.fspath(file)
        else:
            path = file.name

        ext = os.path.splitext(path)[1].lower()
        if ext in image_exts:
            orig, proc = process_image(path)
        elif ext in video_exts:
            orig, proc = process_video(path)
        else:
            continue
        originals.append(orig)
        processed.append(proc)

    return (
        gr.update(visible=bool(originals), value=originals),
        gr.update(visible=bool(processed), value=processed),
    )

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Upload Image(s) or Video(s) for Processing")

    file_input = gr.File(file_types=["image", "video"], label="Upload Images/Videos")

    with gr.Row():
        # Both galleries start hidden; process_files toggles their visibility.
        original_gallery = gr.Gallery(label="Original", visible=False)
        processed_gallery = gr.Gallery(label="Processed", visible=False)

    # Re-run processing whenever the uploaded file changes.
    file_input.change(
        fn=process_files,
        inputs=file_input,
        outputs=[original_gallery, processed_gallery],
    )

    # NOTE(review): relative paths — assumes the sample files ship next to
    # the app and it is launched from that directory; TODO confirm.
    example_files = [
        ["sample1.jpg"],       # image
        ["sample2.jpg"],       # image
        ["video_sample.mp4"],  # video
    ]

    gr.Examples(
        examples=example_files,
        inputs=file_input,
        fn=process_files,
        outputs=[original_gallery, processed_gallery],
        label="Try our sample images or video",
    )

# Launch only when run as a script, so importing this module (e.g. by
# Gradio's reload mode or tests) does not start a server.
if __name__ == "__main__":
    demo.launch(debug=True)