File size: 7,551 Bytes
3e978e1
ed88963
 
66bcb74
ed88963
 
 
3e978e1
ed88963
66bcb74
579a3a7
4747f61
3e978e1
66bcb74
579a3a7
ed88963
66bcb74
ed88963
 
66bcb74
 
ed88963
66bcb74
3e978e1
 
ed88963
3e978e1
579a3a7
3e978e1
66bcb74
3e978e1
 
 
 
 
 
 
 
 
 
 
66bcb74
3e978e1
ed88963
66bcb74
1bbb7db
66bcb74
1bbb7db
66bcb74
 
 
 
 
 
 
 
 
 
 
 
579a3a7
ed88963
579a3a7
1bbb7db
 
579a3a7
66bcb74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1bbb7db
66bcb74
 
 
1bbb7db
 
66bcb74
 
ed88963
 
66bcb74
 
 
3e978e1
 
66bcb74
1bbb7db
66bcb74
1bbb7db
61a9e69
3e978e1
1bbb7db
ed88963
579a3a7
66bcb74
 
 
 
3e978e1
 
 
66bcb74
1bbb7db
3e978e1
1bbb7db
66bcb74
 
3e978e1
66bcb74
3e978e1
 
66bcb74
 
3e978e1
579a3a7
ed88963
66bcb74
3e978e1
 
 
 
66bcb74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
579a3a7
66bcb74
 
ed88963
579a3a7
66bcb74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
579a3a7
66bcb74
 
1bbb7db
 
66bcb74
 
 
 
 
 
 
 
 
 
 
 
1bbb7db
66bcb74
579a3a7
66bcb74
 
 
ed88963
 
 
 
 
579a3a7
3e978e1
579a3a7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
# app_seedvr.py

import os
import sys
from pathlib import Path
from typing import Optional
import gradio as gr
import cv2

# --- SERVER LOGIC INTEGRATION ---
# The backend is mandatory for this app: if it cannot be imported, print a
# readable message and re-raise so the original ImportError and its traceback
# still surface and the process stops.
try:
    from api.seedvr_server import SeedVRServer
except ImportError as e:
    print(f"FATAL ERROR: Could not import SeedVRServer. Details: {e}")
    raise

# --- INITIALIZATION ---
# A single server instance is created once, at import time, and is used by
# run_inference_ui below. Whatever SeedVRServer's constructor does (presumably
# model/pipeline setup — TODO confirm) therefore runs before the UI is built.
server = SeedVRServer()

# --- HELPER FUNCTIONS ---

def _is_video(path: str) -> bool:
    """Checks if a file path corresponds to a video type."""
    if not path: return False
    import mimetypes
    mime, _ = mimetypes.guess_type(path)
    return (mime or "").startswith("video")

def _extract_first_frame(video_path: str) -> Optional[str]:
    """Extracts the first frame from a video and saves it as a JPG image."""
    if not video_path or not os.path.exists(video_path): return None
    try:
        vid_cap = cv2.VideoCapture(video_path)
        if not vid_cap.isOpened(): return None
        success, image = vid_cap.read()
        vid_cap.release()
        if not success: return None
        image_path = Path(video_path).with_suffix(".jpg")
        cv2.imwrite(str(image_path), image)
        return str(image_path)
    except Exception as e:
        print(f"Error extracting first frame: {e}")
        return None

def on_file_upload(file_obj):
    """Adjust the sp_size slider after the user uploads a file.

    Args:
        file_obj: The uploaded file as delivered by the ``gr.File`` component:
            a path string (this app configures ``type="filepath"``) or a
            file-like object exposing ``.name``; ``None`` when there is no file.

    Returns:
        ``1`` when there is no file. Otherwise a ``gr.update`` that sets the
        slider to 4 and keeps it editable for videos, or forces it to 1 and
        locks it for any other media (images).
    """
    if file_obj is None:
        return 1
    # A plain str path has no `.name` attribute (accessing it raised
    # AttributeError); tempfile-style upload objects do. Accept both.
    file_path = getattr(file_obj, "name", file_obj)
    if _is_video(file_path):
        return gr.update(value=4, interactive=True)
    return gr.update(value=1, interactive=False)

# --- CORE INFERENCE FUNCTION ---

def run_inference_ui(
    input_file_path: Optional[str],
    resolution: str,
    sp_size: int,
    fps: float,
    progress=gr.Progress(track_tqdm=True)
):
    """
    Main Gradio callback for the "Restore Media" button.

    Implemented as a generator so the UI refreshes while inference runs. Every
    ``yield`` produces a 5-tuple in the same order as the click handler's
    outputs: ``(run_button, output_image, output_video, output_download,
    log_window)``.

    Args:
        input_file_path: Path of the uploaded media file, or ``None``/empty
            when nothing has been uploaded.
        resolution: Target size from the dropdown (a numeric string). It is
            converted with ``int()`` and used for both output height and width.
        sp_size: Sequence-parallelism size forwarded to the backend.
        fps: Requested output FPS. Falsy or non-positive values are sent to
            the backend as ``None`` (the UI describes 0 as "use the original FPS").
        progress: Gradio progress tracker. It is updated locally and also
            forwarded to the backend.

    Yields:
        Tuples of Gradio updates / raw values for the five output components.
        Exceptions raised during inference are caught and reported in the log
        window and as a warning toast, and the button is re-enabled.
    """
    # 1. Initial state: lock the button, hide previous results, show the log.
    yield (
        gr.update(interactive=False, value="Processing... 🚀"),
        gr.update(value=None, visible=False),
        gr.update(value=None, visible=False),
        gr.update(value=None, visible=False),
        gr.update(value="Waiting for logs...", visible=True)
    )

    # Validation: nothing to process, so restore the button and hide the log.
    if not input_file_path:
        gr.Warning("Please upload a media file first.")
        yield (
            gr.update(interactive=True, value="Restore Media"),
            None, None, None, gr.update(visible=False)
        )
        return

    log_buffer = ["▶ Starting inference process...\n"]
    yield gr.update(), None, None, None, ''.join(log_buffer)

    def progress_callback(step: float, desc: str):
        """Append a progress line to the log buffer and update the Gradio progress bar."""
        log_buffer.append(f"⏳ [{int(step*100)}%] {desc}\n")
        # `step` is a 0..1 fraction, matching how Gradio's Progress is called.
        progress(step, desc=desc)

    was_input_video = _is_video(input_file_path)

    try:
        # 2. Execute inference; this call returns only when the backend is done.
        progress_callback(0.1, "Calling backend engine...")
        yield gr.update(), None, None, None, ''.join(log_buffer)

        video_result_path = server.run_inference_direct(
            file_path=input_file_path,
            seed=42,
            res_h=int(resolution),
            res_w=int(resolution),
            sp_size=int(sp_size),
            fps=float(fps) if fps and fps > 0 else None,
            progress=progress,
        )

        progress_callback(1.0, "Inference complete! Processing final output...")
        yield gr.update(), None, None, None, ''.join(log_buffer)

        # 3. Process and display results.
        final_image, final_video = None, None
        if was_input_video:
            final_video = video_result_path
            log_buffer.append("✅ Video result is ready.\n")
        else:
            # The result path is treated as a video even for image input: show
            # its first frame as the image result and keep the video as well.
            final_image = _extract_first_frame(video_result_path)
            final_video = video_result_path
            log_buffer.append("✅ Image result extracted from video.\n")

        yield (
            gr.update(interactive=True, value="Restore Media"),
            gr.update(value=final_image, visible=final_image is not None),
            gr.update(value=final_video, visible=final_video is not None),
            gr.update(value=video_result_path, visible=video_result_path is not None),
            ''.join(log_buffer)
        )

    except Exception as e:
        # Top-level UI boundary: report any failure instead of crashing the app.
        error_message = f"❌ Inference failed: {e}"
        # gr.Error only has an effect when *raised*, and raising here would end
        # the generator before the button is re-enabled; show a toast instead.
        gr.Warning(error_message)
        print(error_message)
        import traceback
        traceback.print_exc()

        yield (
            gr.update(interactive=True, value="Restore Media"),
            None, None, None,
            gr.update(value=f"{''.join(log_buffer)}\n{error_message}", visible=True)
        )


# --- GRADIO UI LAYOUT ---
# Components are declared inside the Blocks context; their creation order
# within each Row/Column determines where they appear on the page, so keep
# statement order intact when editing this section.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Restoration") as demo:
    # Header
    gr.Markdown(
        """
        <div style='text-align: center; margin-bottom: 20px;'>
            <h1>📸 SeedVR - Image & Video Restoration 🚀</h1>
            <p>High-quality media upscaling powered by SeedVR-3B. Upload your file and see the magic.</p>
        </div>
        """
    )
    with gr.Row():
        # Left column: input file, generation settings and the run button.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Upload Media")
            input_media = gr.File(label="Input File (Video or Image)", type="filepath")
            gr.Markdown("### 2. Configure Settings")
            with gr.Accordion("Generation Parameters", open=True):
                # String choices; run_inference_ui converts with int() and uses
                # the same value for both res_h and res_w.
                resolution_select = gr.Dropdown(
                    label="Resolution (Short Edge)",
                    choices=["480", "560", "720", "960", "1024"],
                    value="480",
                    info="The output height and width will be set to this value."
                )
                # Overwritten by on_file_upload: 4 and editable for videos,
                # 1 and locked for other media.
                sp_size_slider = gr.Slider(
                    label="Sequence Parallelism (sp_size)",
                    minimum=1, maximum=16, step=1, value=4,
                    info="For multi-GPU videos. This will be set to 1 for images."
                )
                # 0 (or empty) is passed to the backend as fps=None by run_inference_ui.
                fps_out = gr.Number(label="Output FPS (for Videos)", value=24, precision=0, info="Set to 0 to use the original FPS.")
            run_button = gr.Button("Restore Media", variant="primary", icon="✨")
        # Right column: log and result components. All start hidden and are
        # shown by the updates yielded from run_inference_ui.
        with gr.Column(scale=2):
            gr.Markdown("### 3. Results")
            log_window = gr.Textbox(
                label="Inference Log 📝", lines=8, max_lines=15,
                interactive=False, visible=False, autoscroll=True,
            )
            output_image = gr.Image(label="Image Result", show_download_button=True, type="filepath", visible=False)
            output_video = gr.Video(label="Video Result", visible=False)
            output_download = gr.File(label="Download Full Result (Video)", visible=False)
    # Footer / credits
    gr.Markdown(
        """
        ---
        *Space and Docker were developed by Carlex.*
        *Contact: Email: Carlex22@gmail.com | GitHub: [carlex22](https://github.com/carlex22)*
        """
    )
    
    # Re-tune the sp_size slider whenever a new file is uploaded.
    input_media.upload(fn=on_file_upload, inputs=[input_media], outputs=[sp_size_slider])
    
    # `inputs` map positionally onto run_inference_ui's parameters; `outputs`
    # must match the order of the 5-tuples that run_inference_ui yields:
    # (button, image, video, download file, log).
    run_button.click(
        fn=run_inference_ui,
        inputs=[input_media, resolution_select, sp_size_slider, fps_out],
        outputs=[run_button, output_image, output_video, output_download, log_window],
    )

if __name__ == "__main__":
    # Host and port are configurable through the standard Gradio env vars,
    # defaulting to all interfaces on port 7860.
    bind_host = os.getenv("GRADIO_SERVER_NAME", "0.0.0.0")
    bind_port = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
    demo.launch(server_name=bind_host, server_port=bind_port, show_error=True)