Test2 / app_seedvr.py
EuuIia's picture
Update app_seedvr.py
d1eb45a verified
# app_seedvr.py
import os
import sys
from pathlib import Path
from typing import Optional
import gradio as gr
import cv2
# --- INTEGRAÇÃO COM A LÓGICA DO SERVIDOR ---
try:
# Importa a classe SeedVRServer que agora atua como nossa biblioteca de inferência.
from api.seedvr_server import SeedVRServer
except ImportError as e:
print(f"ERRO FATAL: Não foi possível importar o SeedVRServer. Detalhes: {e}")
# A aplicação não pode rodar sem a lógica do servidor.
raise
# --- INICIALIZAÇÃO ---
# Cria uma instância única e persistente do servidor.
# A inicialização (clonar repo, baixar modelos) acontece apenas uma vez, no início.
server = SeedVRServer()
# --- FUNÇÕES AUXILIARES ---
def _is_video(path: str) -> bool:
"""Verifica se um caminho de arquivo corresponde a um tipo de vídeo."""
if not path: return False
import mimetypes
mime, _ = mimetypes.guess_type(path)
return (mime or "").startswith("video")
def _extract_first_frame(video_path: str) -> Optional[str]:
"""Extrai o primeiro frame de um vídeo e o salva como uma imagem JPG."""
if not video_path or not os.path.exists(video_path): return None
try:
vid_cap = cv2.VideoCapture(video_path)
if not vid_cap.isOpened():
print(f"Erro: Não foi possível abrir o vídeo em {video_path}")
return None
success, image = vid_cap.read()
vid_cap.release()
if not success:
print(f"Erro: Não foi possível ler o primeiro frame de {video_path}")
return None
# Salva o frame no mesmo diretório do vídeo, com extensão .jpg
image_path = Path(video_path).with_suffix(".jpg")
cv2.imwrite(str(image_path), image)
return str(image_path)
except Exception as e:
print(f"Erro ao extrair o primeiro frame: {e}")
return None
def on_file_upload(file_obj):
"""
Callback acionado quando o usuário faz o upload de um arquivo.
Verifica se o arquivo é um vídeo e sugere um `sp_size` apropriado.
"""
if file_obj is None:
# Limpa os resultados e o log se o arquivo for removido
return gr.update(value=1), None, None, None, gr.update(value=None, visible=False)
if _is_video(file_obj.name):
# Para vídeos, sugere um valor padrão para multi-GPU e torna o slider interativo
return gr.update(value=8, interactive=True), None, None, None, gr.update(value=None, visible=False)
else:
# Para imagens, trava o valor em 1
return gr.update(value=1, interactive=False), None, None, None, gr.update(value=None, visible=False)
# --- FUNÇÃO PRINCIPAL DE INFERÊNCIA DA UI ---
def run_inference_ui(
input_file_path: Optional[str],
resolution: str,
sp_size: int,
fps: float,
progress=gr.Progress(track_tqdm=True)
):
"""
A função de callback principal do Gradio. Usa geradores (`yield`)
para permitir atualizações da UI em tempo real durante a tarefa de longa duração.
"""
# 1. Estado Inicial e Validação
# No início, desabilita o botão, limpa resultados anteriores e mostra a janela de log.
yield (
gr.update(interactive=False, value="Processing... 🚀"),
gr.update(value=None, visible=False),
gr.update(value=None, visible=False),
gr.update(value=None, visible=False),
gr.update(value="▶ Starting inference process...\n", visible=True)
)
if not input_file_path:
gr.Warning("Please upload a media file first.")
# Reabilita o botão e esconde os componentes de saída
yield (gr.update(interactive=True, value="Restore Media"), None, None, None, gr.update(visible=False))
return
log_buffer = ["▶ Starting inference process...\n"]
last_log_message = ""
was_input_video = _is_video(input_file_path)
try:
# Define um callback que será chamado pelo backend para atualizar o progresso e o log
def progress_callback_wrapper(step: float, desc: str):
""" Wrapper para formatar logs e atualizar o progresso. """
nonlocal last_log_message
# Só adiciona ao log se a mensagem for nova, para evitar poluição visual
if desc != last_log_message:
log_buffer.append(f"{desc}\n")
last_log_message = desc
# Atualiza o objeto de progresso do Gradio
progress(step, desc=desc)
# 2. Executa a Inferência
# Chama o método direto do servidor, passando o nosso callback.
video_result_path = server.run_inference_direct(
file_path=input_file_path,
seed=42, # Semente fixa conforme solicitado
res_h=int(resolution),
res_w=int(resolution), # Largura igual à altura
sp_size=int(sp_size),
fps=float(fps) if fps and fps > 0 else None,
progress=progress_callback_wrapper, # Passa nossa função de callback
)
progress(1.0, desc="Complete!")
log_buffer.append("✅ Inference complete! Processing final output...\n")
# 3. Processa e Exibe os Resultados
final_image, final_video = None, None
if was_input_video:
final_video = video_result_path
log_buffer.append("✅ Video result is ready.\n")
else: # Se a entrada foi uma imagem
final_image = _extract_first_frame(video_result_path)
final_video = video_result_path # Também disponibiliza o vídeo de 1 frame
log_buffer.append("✅ Image result extracted from video.\n")
# Yield final para mostrar os resultados e reabilitar o botão
yield (
gr.update(interactive=True, value="Restore Media"),
gr.update(value=final_image, visible=final_image is not None),
gr.update(value=final_video, visible=final_video is not None),
gr.update(value=video_result_path, visible=video_result_path is not None),
''.join(log_buffer)
)
except Exception as e:
error_message = f"❌ Inference failed: {e}"
gr.Error(error_message)
log_buffer.append(f"\n{error_message}")
import traceback
traceback.print_exc()
# Yield para estado de erro: reabilita o botão e mostra o log com o erro
yield (
gr.update(interactive=True, value="Restore Media"),
None, None, None,
gr.update(value=''.join(log_buffer), visible=True)
)
# --- LAYOUT DA INTERFACE GRÁFICA (GRADIO) ---
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Restoration") as demo:
# Cabeçalho
gr.Markdown(
"""
<div style='text-align: center; margin-bottom: 20px;'>
<h1>📸 SeedVR - Image & Video Restoration 🚀</h1>
<p>High-quality media upscaling powered by SeedVR-3B. Upload your file and see the magic.</p>
</div>
"""
)
with gr.Row():
# --- Coluna da Esquerda: Entradas e Controles ---
with gr.Column(scale=1):
gr.Markdown("### 1. Upload Media")
input_media = gr.File(label="Input File (Video or Image)", type="filepath", interactive=True)
gr.Markdown("### 2. Configure Settings")
with gr.Accordion("Generation Parameters", open=True):
resolution_select = gr.Dropdown(
label="Resolution",
choices=["480", "560", "720", "960", "1024", "2048"],
value="480",
info="Sets the output height and width to this value."
)
sp_size_slider = gr.Slider(
label="Frames per Batch (sp_size)",
minimum=1, maximum=16, step=4, value=8,
info="For multi-GPU videos. Automatically set to 1 for images."
)
fps_out = gr.Number(label="Output FPS (for Videos)", value=24, precision=0, info="Set to 0 to use the original FPS.")
run_button = gr.Button("Restore Media", variant="primary", icon="✨")
# --- Coluna da Direita: Resultados ---
with gr.Column(scale=2):
gr.Markdown("### 3. Results")
# Janela de Log
log_window = gr.Textbox(
label="Inference Log 📝",
lines=8, max_lines=15,
interactive=False, visible=False, autoscroll=True
)
# Componentes de saída (começam invisíveis)
output_image = gr.Image(label="Image Result", show_download_button=True, type="filepath", visible=False)
output_video = gr.Video(label="Video Result", visible=False)
output_download = gr.File(label="Download Full Result (Video)", visible=False)
# --- Rodapé ---
gr.Markdown(
"""
---
*Space and Docker were developed by Carlex.*
*Contact: Email: Carlex22@gmail.com | GitHub: [carlex22](https://github.com/carlex22)*
"""
)
# --- Lógica de Eventos da UI ---
# Ao fazer upload de um arquivo, ajusta o slider `sp_size` e limpa saídas antigas.
input_media.upload(
fn=on_file_upload,
inputs=[input_media],
outputs=[sp_size_slider, output_image, output_video, output_download, log_window]
)
# Ao clicar no botão, executa a função de inferência principal.
run_button.click(
fn=run_inference_ui,
inputs=[input_media, resolution_select, sp_size_slider, fps_out],
outputs=[run_button, output_image, output_video, output_download, log_window],
)
if __name__ == "__main__":
demo.launch(
server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")),
show_error=True
)