File size: 9,975 Bytes
1ec204c
ed88963
ac2701f
1ec204c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac2701f
1ec204c
 
 
 
 
 
 
 
 
 
 
 
 
 
ac2701f
1ec204c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b879288
 
1ec204c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed88963
1ec204c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e978e1
1ec204c
ed88963
1ec204c
 
 
 
 
 
 
 
 
 
 
 
fc85414
1ec204c
 
 
 
 
 
 
 
 
 
 
 
fc85414
b879288
1ec204c
 
 
 
 
 
 
 
 
 
 
 
 
fc85414
1ec204c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66bcb74
1ec204c
fc85414
1ec204c
 
 
 
 
 
b879288
1ec204c
 
 
 
 
 
ed88963
 
1ec204c
 
 
 
d1eb45a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
# app_seedvr.py

import os
import sys
from pathlib import Path
from typing import Optional
import gradio as gr
import cv2

# --- INTEGRAÇÃO COM A LÓGICA DO SERVIDOR ---
try:
    # Importa a classe SeedVRServer que agora atua como nossa biblioteca de inferência.
    from api.seedvr_server import SeedVRServer
except ImportError as e:
    print(f"ERRO FATAL: Não foi possível importar o SeedVRServer. Detalhes: {e}")
    # A aplicação não pode rodar sem a lógica do servidor.
    raise

# --- INICIALIZAÇÃO ---
# Cria uma instância única e persistente do servidor.
# A inicialização (clonar repo, baixar modelos) acontece apenas uma vez, no início.
server = SeedVRServer()

# --- FUNÇÕES AUXILIARES ---

def _is_video(path: str) -> bool:
    """Verifica se um caminho de arquivo corresponde a um tipo de vídeo."""
    if not path: return False
    import mimetypes
    mime, _ = mimetypes.guess_type(path)
    return (mime or "").startswith("video")

def _extract_first_frame(video_path: str) -> Optional[str]:
    """Extrai o primeiro frame de um vídeo e o salva como uma imagem JPG."""
    if not video_path or not os.path.exists(video_path): return None
    try:
        vid_cap = cv2.VideoCapture(video_path)
        if not vid_cap.isOpened():
            print(f"Erro: Não foi possível abrir o vídeo em {video_path}")
            return None
        success, image = vid_cap.read()
        vid_cap.release()
        if not success:
            print(f"Erro: Não foi possível ler o primeiro frame de {video_path}")
            return None
        
        # Salva o frame no mesmo diretório do vídeo, com extensão .jpg
        image_path = Path(video_path).with_suffix(".jpg")
        cv2.imwrite(str(image_path), image)
        return str(image_path)
    except Exception as e:
        print(f"Erro ao extrair o primeiro frame: {e}")
        return None

def on_file_upload(file_obj):
    """
    Callback acionado quando o usuário faz o upload de um arquivo.
    Verifica se o arquivo é um vídeo e sugere um `sp_size` apropriado.
    """
    if file_obj is None:
        # Limpa os resultados e o log se o arquivo for removido
        return gr.update(value=1), None, None, None, gr.update(value=None, visible=False)
    
    if _is_video(file_obj.name):
        # Para vídeos, sugere um valor padrão para multi-GPU e torna o slider interativo
        return gr.update(value=8, interactive=True), None, None, None, gr.update(value=None, visible=False)
    else:
        # Para imagens, trava o valor em 1
        return gr.update(value=1, interactive=False), None, None, None, gr.update(value=None, visible=False)

# --- FUNÇÃO PRINCIPAL DE INFERÊNCIA DA UI ---

def run_inference_ui(
    input_file_path: Optional[str],
    resolution: str,
    sp_size: int,
    fps: float,
    progress=gr.Progress(track_tqdm=True)
):
    """
    A função de callback principal do Gradio. Usa geradores (`yield`)
    para permitir atualizações da UI em tempo real durante a tarefa de longa duração.
    """
    # 1. Estado Inicial e Validação
    # No início, desabilita o botão, limpa resultados anteriores e mostra a janela de log.
    yield (
        gr.update(interactive=False, value="Processing... 🚀"),
        gr.update(value=None, visible=False),
        gr.update(value=None, visible=False),
        gr.update(value=None, visible=False),
        gr.update(value="▶ Starting inference process...\n", visible=True)
    )

    if not input_file_path:
        gr.Warning("Please upload a media file first.")
        # Reabilita o botão e esconde os componentes de saída
        yield (gr.update(interactive=True, value="Restore Media"), None, None, None, gr.update(visible=False))
        return

    log_buffer = ["▶ Starting inference process...\n"]
    last_log_message = ""
    was_input_video = _is_video(input_file_path)

    try:
        # Define um callback que será chamado pelo backend para atualizar o progresso e o log
        def progress_callback_wrapper(step: float, desc: str):
            """ Wrapper para formatar logs e atualizar o progresso. """
            nonlocal last_log_message
            # Só adiciona ao log se a mensagem for nova, para evitar poluição visual
            if desc != last_log_message:
                log_buffer.append(f"{desc}\n")
                last_log_message = desc
            # Atualiza o objeto de progresso do Gradio
            progress(step, desc=desc)

        # 2. Executa a Inferência
        # Chama o método direto do servidor, passando o nosso callback.
        video_result_path = server.run_inference_direct(
            file_path=input_file_path,
            seed=42, # Semente fixa conforme solicitado
            res_h=int(resolution),
            res_w=int(resolution), # Largura igual à altura
            sp_size=int(sp_size),
            fps=float(fps) if fps and fps > 0 else None,
            progress=progress_callback_wrapper, # Passa nossa função de callback
        )
        
        progress(1.0, desc="Complete!")
        log_buffer.append("✅ Inference complete! Processing final output...\n")
        
        # 3. Processa e Exibe os Resultados
        final_image, final_video = None, None
        if was_input_video:
            final_video = video_result_path
            log_buffer.append("✅ Video result is ready.\n")
        else: # Se a entrada foi uma imagem
            final_image = _extract_first_frame(video_result_path)
            final_video = video_result_path # Também disponibiliza o vídeo de 1 frame
            log_buffer.append("✅ Image result extracted from video.\n")
        
        # Yield final para mostrar os resultados e reabilitar o botão
        yield (
            gr.update(interactive=True, value="Restore Media"),
            gr.update(value=final_image, visible=final_image is not None),
            gr.update(value=final_video, visible=final_video is not None),
            gr.update(value=video_result_path, visible=video_result_path is not None),
            ''.join(log_buffer)
        )

    except Exception as e:
        error_message = f"❌ Inference failed: {e}"
        gr.Error(error_message)
        log_buffer.append(f"\n{error_message}")
        import traceback
        traceback.print_exc()
        
        # Yield para estado de erro: reabilita o botão e mostra o log com o erro
        yield (
            gr.update(interactive=True, value="Restore Media"),
            None, None, None,
            gr.update(value=''.join(log_buffer), visible=True)
        )

# --- LAYOUT DA INTERFACE GRÁFICA (GRADIO) ---

with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Restoration") as demo:
    # Cabeçalho
    gr.Markdown(
        """
        <div style='text-align: center; margin-bottom: 20px;'>
            <h1>📸 SeedVR - Image & Video Restoration 🚀</h1>
            <p>High-quality media upscaling powered by SeedVR-3B. Upload your file and see the magic.</p>
        </div>
        """
    )

    with gr.Row():
        # --- Coluna da Esquerda: Entradas e Controles ---
        with gr.Column(scale=1):
            gr.Markdown("### 1. Upload Media")
            input_media = gr.File(label="Input File (Video or Image)", type="filepath", interactive=True)
            
            gr.Markdown("### 2. Configure Settings")
            with gr.Accordion("Generation Parameters", open=True):
                resolution_select = gr.Dropdown(
                    label="Resolution",
                    choices=["480", "560", "720", "960", "1024", "2048"],
                    value="480",
                    info="Sets the output height and width to this value."
                )
                
                sp_size_slider = gr.Slider(
                    label="Frames per Batch (sp_size)",
                    minimum=1, maximum=16, step=4, value=8,
                    info="For multi-GPU videos. Automatically set to 1 for images."
                )

                fps_out = gr.Number(label="Output FPS (for Videos)", value=24, precision=0, info="Set to 0 to use the original FPS.")

            run_button = gr.Button("Restore Media", variant="primary", icon="✨")

        # --- Coluna da Direita: Resultados ---
        with gr.Column(scale=2):
            gr.Markdown("### 3. Results")
            
            # Janela de Log
            log_window = gr.Textbox(
                label="Inference Log 📝",
                lines=8, max_lines=15,
                interactive=False, visible=False, autoscroll=True
            )

            # Componentes de saída (começam invisíveis)
            output_image = gr.Image(label="Image Result", show_download_button=True, type="filepath", visible=False)
            output_video = gr.Video(label="Video Result", visible=False)
            output_download = gr.File(label="Download Full Result (Video)", visible=False)

    # --- Rodapé ---
    gr.Markdown(
        """
        ---
        *Space and Docker were developed by Carlex.*
        *Contact: Email: Carlex22@gmail.com | GitHub: [carlex22](https://github.com/carlex22)*
        """
    )
    
    # --- Lógica de Eventos da UI ---
    
    # Ao fazer upload de um arquivo, ajusta o slider `sp_size` e limpa saídas antigas.
    input_media.upload(
        fn=on_file_upload,
        inputs=[input_media],
        outputs=[sp_size_slider, output_image, output_video, output_download, log_window]
    )
    
    # Ao clicar no botão, executa a função de inferência principal.
    run_button.click(
        fn=run_inference_ui,
        inputs=[input_media, resolution_select, sp_size_slider, fps_out],
        outputs=[run_button, output_image, output_video, output_download, log_window],
    )

if __name__ == "__main__":
    demo.launch(
        server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")),
        show_error=True
    )