Test2 / app_seedvr.py
EuuIia's picture
Update app_seedvr.py
3e978e1 verified
raw
history blame
5.6 kB
# app_seedvr.py
import os
from pathlib import Path
from typing import Optional
import gradio as gr
import cv2
try:
# Importa a classe de servidor que agora é uma biblioteca local
from api.seedvr_server import SeedVRServer
except ImportError as e:
print(f"ERRO FATAL: Não foi possível importar o SeedVRServer. Detalhes: {e}")
# Se a importação falhar, a aplicação não pode continuar.
raise
# Cria uma instância única do servidor. A inicialização (clonar repo, baixar modelos) acontece aqui.
server = SeedVRServer()
def _is_video(path: str) -> bool:
"""Verifica se um caminho de arquivo corresponde a um tipo de vídeo."""
if not path: return False
import mimetypes
mime, _ = mimetypes.guess_type(path)
return (mime or "").startswith("video")
def _extract_first_frame(video_path: str) -> Optional[str]:
"""Extrai o primeiro frame de um vídeo e o salva como uma imagem JPG."""
if not video_path or not os.path.exists(video_path): return None
try:
vid_cap = cv2.VideoCapture(video_path)
if not vid_cap.isOpened(): return None
success, image = vid_cap.read()
vid_cap.release()
if not success: return None
# Salva o frame no mesmo diretório do vídeo, com extensão .jpg
image_path = Path(video_path).with_suffix(".jpg")
cv2.imwrite(str(image_path), image)
return str(image_path)
except Exception as e:
print(f"Erro ao extrair o primeiro frame: {e}")
return None
def ui_infer(
input_path: Optional[str],
seed: int, res_h: int, res_w: int,
sp_size: int, fps: float,
progress=gr.Progress(track_tqdm=True)
):
"""
Função de callback principal do Gradio. Agora chama a lógica de inferência diretamente.
"""
if not input_path:
gr.Warning("Por favor, faça o upload de um arquivo.")
return None, None, None
was_input_video = _is_video(input_path)
try:
# Desabilita o botão enquanto processa
yield gr.update(interactive=False, value="Processando..."), None, None, None
# Chama o método direto do servidor, passando o objeto de progresso do Gradio
video_result_path = server.run_inference_direct(
file_path=input_path,
seed=int(seed),
res_h=int(res_h),
res_w=int(res_w),
sp_size=int(sp_size),
fps=float(fps) if fps and fps > 0 else None,
progress=progress,
)
progress(1.0, desc="Concluído!")
final_image, final_video = None, None
if was_input_video:
final_video = video_result_path
else: # Se a entrada foi uma imagem
final_image = _extract_first_frame(video_result_path)
final_video = video_result_path
# Retorna o resultado e reabilita o botão
yield (
gr.update(interactive=True, value="Restaurar Mídia"),
gr.update(value=final_image, visible=final_image is not None),
gr.update(value=final_video, visible=final_video is not None),
gr.update(value=video_result_path, visible=video_result_path is not None)
)
except Exception as e:
error_message = f"A inferência falhou: {e}"
gr.Error(error_message)
print(error_message)
import traceback
traceback.print_exc()
# Limpa os resultados e reabilita o botão em caso de erro
yield gr.update(interactive=True, value="Restaurar Mídia"), None, None, None
# --- Construção da Interface Gráfica ---
with gr.Blocks(title="SeedVR (Aduc-SDR)", theme=gr.themes.Soft()) as demo:
gr.HTML("""
<div style='text-align:center; margin-bottom: 20px;'>
<h1>SeedVR - Restauração de Imagem e Vídeo</h1>
<p>Implementação com backend Aduc-SDR</p>
</div>
""")
with gr.Row():
with gr.Column(scale=1):
inp = gr.File(label="Arquivo de Entrada (Vídeo .mp4 ou Imagem)", type="filepath")
with gr.Accordion("Parâmetros de Geração", open=True):
with gr.Row():
seed = gr.Number(label="Seed", value=42, precision=0)
fps_out = gr.Number(label="FPS de Saída (para Vídeos)", value=24, precision=0, info="0 para usar o FPS original.")
with gr.Row():
res_h = gr.Number(label="Altura (Height)", value=720, precision=0)
res_w = gr.Number(label="Largura (Width)", value=1280, precision=0)
sp_size = gr.Slider(label="Paralelismo de Sequência (sp_size)", minimum=1, maximum=160, step=4, value=4, info="Para vídeos em multi-GPU. Use 1 para imagens.")
run_button = gr.Button("Restaurar Mídia", variant="primary")
with gr.Column(scale=2):
gr.Markdown("### Resultado")
out_image = gr.Image(label="Resultado (Imagem)", show_download_button=True, type="filepath", visible=True)
out_video = gr.Video(label="Resultado (Vídeo)")
out_download = gr.File(label="Baixar Resultado (Vídeo)")
# A função click agora é um gerador.
run_button.click(
fn=ui_infer,
inputs=[inp, seed, res_h, res_w, sp_size, fps_out],
outputs=[run_button, out_image, out_video, out_download],
)
if __name__ == "__main__":
demo.launch(
server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")),
show_error=True
)