Test2 / app_seedvr.py
EuuIia's picture
Update app_seedvr.py
296cade verified
raw
history blame
8.11 kB
# app_seedvr.py
import os
import sys
from pathlib import Path
from typing import Optional
import gradio as gr
import cv2
import multiprocessing as mp # <--- LINHA ADICIONADA AQUI
# --- SERVER LOGIC INTEGRATION ---
# The UI cannot work without the SeedVR backend: report the import failure
# and re-raise it so startup aborts.
try:
    from api.seedvr_server import SeedVRServer
except ImportError as import_error:
    print(f"ERRO FATAL: Não foi possível importar o SeedVRServer. Detalhes: {import_error}")
    raise

# --- INITIALIZATION ---
# Single module-level server instance, created at import time and used by the
# inference callback defined below.
server = SeedVRServer()
# --- FUNÇÕES AUXILIARES DA UI ---
def _is_video(path: str) -> bool:
"""Verifica se um caminho de arquivo corresponde a um tipo de vídeo."""
if not path: return False
import mimetypes
mime, _ = mimetypes.guess_type(path)
return (mime or "").startswith("video")
def _extract_first_frame(video_path: str) -> Optional[str]:
"""Extrai o primeiro frame de um vídeo e o salva como uma imagem JPG."""
if not video_path or not os.path.exists(video_path): return None
try:
vid_cap = cv2.VideoCapture(video_path)
if not vid_cap.isOpened():
print(f"Erro: Não foi possível abrir o vídeo em {video_path}")
return None
success, image = vid_cap.read()
vid_cap.release()
if not success:
print(f"Erro: Não foi possível ler o primeiro frame de {video_path}")
return None
image_path = Path(video_path).with_suffix(".jpg")
cv2.imwrite(str(image_path), image)
return str(image_path)
except Exception as e:
print(f"Erro ao extrair o primeiro frame: {e}")
return None
def on_file_upload(file_obj):
    """Callback fired when the user uploads a file; resets the result widgets.

    ``gr.File(type="filepath")`` delivers the upload as a plain ``str`` path,
    which has no ``.name`` attribute. Path-like values and file-like objects
    exposing ``.name`` are both accepted.

    Returns:
        A 5-tuple of updates for (sp_size slider, image result, video result,
        download result, log window):

        - no file: the slider value is reset to 1.
        - video: the slider is set to 4 and made interactive.
        - anything else: the slider is locked at 1.

        The result widgets are always cleared and the log window hidden.
    """
    hidden_log = gr.update(value=None, visible=False)
    if file_obj is None:
        return gr.update(value=1), None, None, None, hidden_log

    if isinstance(file_obj, (str, os.PathLike)):
        file_path = os.fspath(file_obj)
    else:
        # Legacy/tempfile-style upload objects expose the path as `.name`.
        file_path = getattr(file_obj, "name", None)

    if _is_video(file_path):
        return gr.update(value=4, interactive=True), None, None, None, hidden_log
    return gr.update(value=1, interactive=False), None, None, None, hidden_log
# --- MAIN UI INFERENCE FUNCTION ---
def run_inference_ui(
    input_file_path: Optional[str],
    resolution: str,
    sp_size: int,
    fps: float,
    progress=gr.Progress(track_tqdm=True)
):
    """Main Gradio callback: restore the uploaded media with the SeedVR server.

    This is a generator. Every ``yield`` is a 5-tuple of updates for
    (run button, image result, video result, download result, log window).

    Args:
        input_file_path: Path of the uploaded file, or ``None``/empty when
            nothing has been uploaded.
        resolution: Numeric string used for both the output height and width.
        sp_size: Forwarded to ``server.run_inference`` as ``sp_size``.
        fps: Output FPS. ``None``, 0 or negative values are forwarded as
            ``None``.
        progress: Gradio progress tracker (supplied by Gradio).
    """
    if not input_file_path:
        # Validate before switching the UI into its "processing" state so the
        # button and log window do not flicker when there is nothing to do.
        gr.Warning("Please upload a media file first.")
        yield (
            gr.update(interactive=True, value="Restore Media"),
            gr.update(value=None, visible=False),
            gr.update(value=None, visible=False),
            gr.update(value=None, visible=False),
            gr.update(visible=False),
        )
        return

    start_message = "▶ Starting inference process...\n"
    yield (
        gr.update(interactive=False, value="Processing... 🚀"),
        gr.update(value=None, visible=False),
        gr.update(value=None, visible=False),
        gr.update(value=None, visible=False),
        gr.update(value=start_message, visible=True),
    )

    log_buffer = [start_message]
    last_log_message = ""
    was_input_video = _is_video(input_file_path)

    def progress_callback_wrapper(step: float, desc: str):
        """Forward server progress to Gradio and log each new description.

        Consecutive repeats of the same description are logged only once.
        """
        nonlocal last_log_message
        if desc != last_log_message:
            log_buffer.append(f"{desc}\n")
            last_log_message = desc
        progress(step, desc=desc)

    try:
        video_result_path = server.run_inference(
            file_path=input_file_path,
            seed=42,
            res_h=int(resolution),
            res_w=int(resolution),
            sp_size=int(sp_size),
            fps=float(fps) if fps and fps > 0 else None,
            progress=progress_callback_wrapper,
        )
        progress(1.0, desc="Complete!")
        log_buffer.append("✅ Inference complete! Processing final output...\n")

        # The server result is handled as a video file; for image inputs its
        # first frame is additionally extracted for the image viewer.
        final_image, final_video = None, None
        if was_input_video:
            final_video = video_result_path
            log_buffer.append("✅ Video result is ready.\n")
        else:
            final_image = _extract_first_frame(video_result_path)
            final_video = video_result_path
            log_buffer.append("✅ Image result extracted from video.\n")

        yield (
            gr.update(interactive=True, value="Restore Media"),
            gr.update(value=final_image, visible=final_image is not None),
            gr.update(value=final_video, visible=final_video is not None),
            gr.update(value=video_result_path, visible=video_result_path is not None),
            gr.update(value=''.join(log_buffer), visible=True),
        )
    except Exception as e:
        error_message = f"❌ Inference failed: {e}"
        # `gr.Error` has no effect unless raised, and raising it instead of
        # yielding would leave the run button disabled. Show a toast instead,
        # matching the empty-input branch above.
        gr.Warning(error_message)
        log_buffer.append(f"\n{error_message}")
        import traceback
        traceback.print_exc()
        yield (
            gr.update(interactive=True, value="Restore Media"),
            None, None, None,
            gr.update(value=''.join(log_buffer), visible=True),
        )
# --- GRAPHICAL INTERFACE LAYOUT (GRADIO) ---
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Restoration") as demo:
    # Components render in the order they are declared, so the statement order
    # inside this block defines the page layout. Header banner first.
    gr.Markdown(
        """
<div style='text-align: center; margin-bottom: 20px;'>
<h1>📸 SeedVR - Image & Video Restoration 🚀</h1>
<p>High-quality media upscaling powered by SeedVR-3B. Upload your file and see the magic.</p>
</div>
"""
    )
    with gr.Row():
        # Left column: input upload and generation settings.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Upload Media")
            # type="filepath" asks Gradio to hand the uploaded file's path to callbacks.
            input_media = gr.File(label="Input File (Video or Image)", type="filepath", interactive=True)
            gr.Markdown("### 2. Configure Settings")
            with gr.Accordion("Generation Parameters", open=True):
                # The chosen value is used for both the output height and width.
                resolution_select = gr.Dropdown(
                    label="Resolution",
                    choices=["480", "560", "720", "960", "1024"],
                    value="480",
                    info="Sets the output height and width to this value."
                )
                # on_file_upload resets this slider after each upload:
                # 4 (interactive) for videos, 1 (locked) for other files.
                sp_size_slider = gr.Slider(
                    label="Frames per Batch (sp_size)",
                    minimum=1, maximum=16, step=1, value=4,
                    info="For multi-GPU videos. Automatically set to 1 for images."
                )
                # 0 means "keep the original FPS": run_inference_ui forwards values <= 0 as None.
                fps_out = gr.Number(label="Output FPS (for Videos)", value=24, precision=0, info="Set to 0 to use the original FPS.")
            run_button = gr.Button("Restore Media", variant="primary", icon="✨")
        # Right column: inference log and result viewers (all start hidden).
        with gr.Column(scale=2):
            gr.Markdown("### 3. Results")
            log_window = gr.Textbox(
                label="Inference Log 📝", lines=8, max_lines=15,
                interactive=False, visible=False, autoscroll=True,
            )
            output_image = gr.Image(label="Image Result", show_download_button=True, type="filepath", visible=False)
            output_video = gr.Video(label="Video Result", visible=False)
            output_download = gr.File(label="Download Full Result (Video)", visible=False)
    # Footer with authorship / contact information.
    gr.Markdown(
        """
---
*Space and Docker were developed by Carlex.*
*Contact: Email: Carlex22@gmail.com | GitHub: [carlex22](https://github.com/carlex22)*
"""
    )
    # Event wiring: each `outputs` list must match, position by position, the
    # tuples returned/yielded by its callback.
    # NOTE(review): `.upload` fires only on uploads, so clearing the file does
    # not reset these widgets; consider also wiring `.clear` — confirm intent.
    input_media.upload(
        fn=on_file_upload,
        inputs=[input_media],
        outputs=[sp_size_slider, output_image, output_video, output_download, log_window]
    )
    run_button.click(
        fn=run_inference_ui,
        inputs=[input_media, resolution_select, sp_size_slider, fps_out],
        outputs=[run_button, output_image, output_video, output_download, log_window],
    )
if __name__ == "__main__":
    # Use the 'spawn' start method so child processes begin from a fresh
    # interpreter instead of forking (and sharing) this process's state.
    # `force=True` overrides any start method that was already set.
    # NOTE(review): `server` is constructed at import time, before this line
    # runs; if SeedVRServer starts worker processes in its constructor, they
    # would not use 'spawn' — confirm.
    mp.set_start_method('spawn', force=True)

    # Bind address and port are overridable through environment variables.
    bind_host = os.getenv("GRADIO_SERVER_NAME", "0.0.0.0")
    bind_port = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
    demo.launch(server_name=bind_host, server_port=bind_port, show_error=True)