EuuIia commited on
Commit
98e3867
·
verified ·
1 Parent(s): 8e5d88b

Update api/seedvr_server.py

Browse files
Files changed (1) hide show
  1. api/seedvr_server.py +125 -231
api/seedvr_server.py CHANGED
@@ -1,248 +1,142 @@
1
- # app_seedvr.py
2
 
3
  import os
4
  import sys
 
 
 
 
5
  from pathlib import Path
6
- from typing import Optional
7
- import gradio as gr
8
- import cv2
9
 
10
- # --- INTEGRAÇÃO COM A LÓGICA DO SERVIDOR ---
11
- try:
12
- # Importa a classe SeedVRServer que agora atua como nossa biblioteca de inferência.
13
- from api.seedvr_server import SeedVRServer
14
- except ImportError as e:
15
- print(f"ERRO FATAL: Não foi possível importar o SeedVRServer. Detalhes: {e}")
16
- # A aplicação não pode rodar sem a lógica do servidor.
17
- raise
18
-
19
- # --- INICIALIZAÇÃO ---
20
- # Cria uma instância única e persistente do servidor.
21
- # A inicialização (clonar repo, baixar modelos) acontece apenas uma vez, no início.
22
- server = SeedVRServer()
23
-
24
- # --- FUNÇÕES AUXILIARES ---
25
-
26
- def _is_video(path: str) -> bool:
27
- """Verifica se um caminho de arquivo corresponde a um tipo de vídeo."""
28
- if not path: return False
29
- import mimetypes
30
- mime, _ = mimetypes.guess_type(path)
31
- return (mime or "").startswith("video")
32
-
33
- def _extract_first_frame(video_path: str) -> Optional[str]:
34
- """Extrai o primeiro frame de um vídeo e o salva como uma imagem JPG."""
35
- if not video_path or not os.path.exists(video_path): return None
36
- try:
37
- vid_cap = cv2.VideoCapture(video_path)
38
- if not vid_cap.isOpened():
39
- print(f"Erro: Não foi possível abrir o vídeo em {video_path}")
40
- return None
41
- success, image = vid_cap.read()
42
- vid_cap.release()
43
- if not success:
44
- print(f"Erro: Não foi possível ler o primeiro frame de {video_path}")
45
- return None
46
-
47
- # Salva o frame no mesmo diretório do vídeo, com extensão .jpg
48
- image_path = Path(video_path).with_suffix(".jpg")
49
- cv2.imwrite(str(image_path), image)
50
- return str(image_path)
51
- except Exception as e:
52
- print(f"Erro ao extrair o primeiro frame: {e}")
53
- return None
54
-
55
- def on_file_upload(file_obj):
56
- """
57
- Callback acionado quando o usuário faz o upload de um arquivo.
58
- Verifica se o arquivo é um vídeo e sugere um `sp_size` apropriado.
59
- """
60
- if file_obj is None:
61
- # Limpa os resultados e o log se o arquivo for removido
62
- return gr.update(value=1), None, None, None, gr.update(value=None, visible=False)
63
-
64
- if _is_video(file_obj.name):
65
- # Para vídeos, sugere um valor padrão para multi-GPU e torna o slider interativo
66
- return gr.update(value=4, interactive=True), None, None, None, gr.update(value=None, visible=False)
67
- else:
68
- # Para imagens, trava o valor em 1
69
- return gr.update(value=1, interactive=False), None, None, None, gr.update(value=None, visible=False)
70
-
71
- # --- FUNÇÃO PRINCIPAL DE INFERÊNCIA DA UI ---
72
 
73
- def run_inference_ui(
74
- input_file_path: Optional[str],
75
- resolution: str,
76
- sp_size: int,
77
- fps: float,
78
- progress=gr.Progress(track_tqdm=True)
79
- ):
80
- """
81
- A função de callback principal do Gradio. Usa geradores (`yield`)
82
- para permitir atualizações da UI em tempo real durante a tarefa de longa duração.
83
- """
84
- # 1. Estado Inicial e Validação
85
- # No início, desabilita o botão, limpa resultados anteriores e mostra a janela de log.
86
- yield (
87
- gr.update(interactive=False, value="Processing... 🚀"),
88
- gr.update(value=None, visible=False),
89
- gr.update(value=None, visible=False),
90
- gr.update(value=None, visible=False),
91
- gr.update(value="▶ Starting inference process...\n", visible=True)
92
- )
93
 
94
- if not input_file_path:
95
- gr.Warning("Please upload a media file first.")
96
- # Reabilita o botão e esconde os componentes de saída
97
- yield (gr.update(interactive=True, value="Restore Media"), None, None, None, gr.update(visible=False))
98
- return
99
-
100
- log_buffer = ["▶ Starting inference process...\n"]
101
- last_log_message = ""
102
- was_input_video = _is_video(input_file_path)
103
-
104
- try:
105
- # Define um callback que será chamado pelo backend para atualizar o progresso e o log
106
- def progress_callback_wrapper(step: float, desc: str):
107
- """ Wrapper para formatar logs e atualizar o progresso. """
108
- nonlocal last_log_message
109
- # Só adiciona ao log se a mensagem for nova, para evitar poluição visual
110
- if desc != last_log_message:
111
- log_buffer.append(f"⏳ {desc}\n")
112
- last_log_message = desc
113
- # Atualiza o objeto de progresso do Gradio
114
- progress(step, desc=desc)
115
 
116
- # 2. Executa a Inferência
117
- # Chama o método direto do servidor, passando o nosso callback.
118
- video_result_path = server.run_inference_direct(
119
- file_path=input_file_path,
120
- seed=42, # Semente fixa conforme solicitado
121
- res_h=int(resolution),
122
- res_w=int(resolution), # Largura igual à altura
123
- sp_size=int(sp_size),
124
- fps=float(fps) if fps and fps > 0 else None,
125
- progress=progress_callback_wrapper, # Passa nossa função de callback
126
- )
127
-
128
- progress(1.0, desc="Complete!")
129
- log_buffer.append("✅ Inference complete! Processing final output...\n")
130
 
131
- # 3. Processa e Exibe os Resultados
132
- final_image, final_video = None, None
133
- if was_input_video:
134
- final_video = video_result_path
135
- log_buffer.append("✅ Video result is ready.\n")
136
- else: # Se a entrada foi uma imagem
137
- final_image = _extract_first_frame(video_result_path)
138
- final_video = video_result_path # Também disponibiliza o vídeo de 1 frame
139
- log_buffer.append("✅ Image result extracted from video.\n")
140
 
141
- # Yield final para mostrar os resultados e reabilitar o botão
142
- yield (
143
- gr.update(interactive=True, value="Restore Media"),
144
- gr.update(value=final_image, visible=final_image is not None),
145
- gr.update(value=final_video, visible=final_video is not None),
146
- gr.update(value=video_result_path, visible=video_result_path is not None),
147
- ''.join(log_buffer)
148
- )
149
-
150
- except Exception as e:
151
- error_message = f"❌ Inference failed: {e}"
152
- gr.Error(error_message)
153
- log_buffer.append(f"\n{error_message}")
154
- import traceback
155
- traceback.print_exc()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
- # Yield para estado de erro: reabilita o botão e mostra o log com o erro
158
- yield (
159
- gr.update(interactive=True, value="Restore Media"),
160
- None, None, None,
161
- gr.update(value=''.join(log_buffer), visible=True)
162
- )
163
-
164
- # --- LAYOUT DA INTERFACE GRÁFICA (GRADIO) ---
165
-
166
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Restoration") as demo:
167
- # Cabeçalho
168
- gr.Markdown(
169
  """
170
- <div style='text-align: center; margin-bottom: 20px;'>
171
- <h1>📸 SeedVR - Image & Video Restoration 🚀</h1>
172
- <p>High-quality media upscaling powered by SeedVR-3B. Upload your file and see the magic.</p>
173
- </div>
174
  """
175
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
- with gr.Row():
178
- # --- Coluna da Esquerda: Entradas e Controles ---
179
- with gr.Column(scale=1):
180
- gr.Markdown("### 1. Upload Media")
181
- input_media = gr.File(label="Input File (Video or Image)", type="filepath", interactive=True)
182
 
183
- gr.Markdown("### 2. Configure Settings")
184
- with gr.Accordion("Generation Parameters", open=True):
185
- resolution_select = gr.Dropdown(
186
- label="Resolution",
187
- choices=["480", "560", "720", "960", "1024"],
188
- value="480",
189
- info="Sets the output height and width to this value."
190
- )
191
-
192
- sp_size_slider = gr.Slider(
193
- label="Frames per Batch (sp_size)",
194
- minimum=1, maximum=16, step=1, value=4,
195
- info="For multi-GPU videos. Automatically set to 1 for images."
196
- )
197
-
198
- fps_out = gr.Number(label="Output FPS (for Videos)", value=24, precision=0, info="Set to 0 to use the original FPS.")
199
-
200
- run_button = gr.Button("Restore Media", variant="primary", icon="✨")
201
-
202
- # --- Coluna da Direita: Resultados ---
203
- with gr.Column(scale=2):
204
- gr.Markdown("### 3. Results")
205
 
206
- # Janela de Log
207
- log_window = gr.Textbox(
208
- label="Inference Log 📝",
209
- lines=8, max_lines=15,
210
- interactive=False, visible=False, autoscroll=True
211
- )
212
-
213
- # Componentes de saída (começam invisíveis)
214
- output_image = gr.Image(label="Image Result", show_download_button=True, type="filepath", visible=False)
215
- output_video = gr.Video(label="Video Result", visible=False)
216
- output_download = gr.File(label="Download Full Result (Video)", visible=False)
217
 
218
- # --- Rodapé ---
219
- gr.Markdown(
220
- """
221
- ---
222
- *Space and Docker were developed by Carlex.*
223
- *Contact: Email: Carlex22@gmail.com | GitHub: [carlex22](https://github.com/carlex22)*
224
- """
225
- )
226
-
227
- # --- Lógica de Eventos da UI ---
228
-
229
- # Ao fazer upload de um arquivo, ajusta o slider `sp_size` e limpa saídas antigas.
230
- input_media.upload(
231
- fn=on_file_upload,
232
- inputs=[input_media],
233
- outputs=[sp_size_slider, output_image, output_video, output_download, log_window]
234
- )
235
-
236
- # Ao clicar no botão, executa a função de inferência principal.
237
- run_button.click(
238
- fn=run_inference_ui,
239
- inputs=[input_media, resolution_select, sp_size_slider, fps_out],
240
- outputs=[run_button, output_image, output_video, output_download, log_window],
241
- )
242
-
243
- if __name__ == "__main__":
244
- demo.launch(
245
- server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
246
- server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")),
247
- show_error=True
248
- )
 
1
+ # api/seedvr_server.py
2
 
3
  import os
4
  import sys
5
+ import shutil
6
+ import mimetypes
7
+ import time
8
+ import subprocess # Necessário para clonar o repositório na configuração inicial
9
  from pathlib import Path
10
+ from typing import Optional, Callable
11
+ from types import SimpleNamespace
 
12
 
13
+ from huggingface_hub import hf_hub_download
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
+ # Adiciona dinamicamente o caminho do repositório clonado ao sys.path.
16
+ # Isso é crucial para que a importação do 'inference_cli' funcione.
17
+ SEEDVR_REPO_PATH = Path(os.getenv("SEEDVR_ROOT", "/data/SeedVR"))
18
+ if str(SEEDVR_REPO_PATH) not in sys.path:
19
+ # Insere no início da lista para garantir prioridade de importação.
20
+ sys.path.insert(0, str(SEEDVR_REPO_PATH))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ # Tenta importar as funções necessárias APÓS a modificação do path.
23
+ # Se falhar, a aplicação não pode continuar.
24
+ try:
25
+ from inference_cli import run_inference_logic, save_frames_to_video
26
+ except ImportError as e:
27
+ print(f"ERRO FATAL: Não foi possível importar de 'inference_cli.py'.")
28
+ print(f"Verifique se o repositório em '{SEEDVR_REPO_PATH}' está correto e completo.")
29
+ raise e
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
+ class SeedVRServer:
32
+ def __init__(self, **kwargs):
33
+ """
34
+ Inicializa o servidor, define os caminhos e prepara o ambiente.
35
+ """
36
+ self.SEEDVR_ROOT = SEEDVR_REPO_PATH
37
+ self.CKPTS_ROOT = Path("/data/seedvr_models_fp16")
38
+ self.OUTPUT_ROOT = Path(os.getenv("OUTPUT_ROOT", "/app/outputs"))
39
+ self.INPUT_ROOT = Path(os.getenv("INPUT_ROOT", "/app/inputs"))
40
+ self.HF_HOME_CACHE = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
41
+ self.REPO_URL = os.getenv("SEEDVR_GIT_URL", "https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler")
42
+ self.NUM_GPUS_TOTAL = int(os.getenv("NUM_GPUS", "4"))
 
 
43
 
44
+ print("🚀 SeedVRServer (Modo de Chamada Direta) inicializando...")
45
+ for p in [self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME_CACHE]:
46
+ p.mkdir(parents=True, exist_ok=True)
 
 
 
 
 
 
47
 
48
+ self.setup_dependencies()
49
+ print("✅ SeedVRServer (Modo de Chamada Direta) pronto.")
50
+
51
+ def setup_dependencies(self):
52
+ """ Garante que o repositório e os modelos estão presentes. """
53
+ self._ensure_repo()
54
+ self._ensure_model()
55
+
56
+ def _ensure_repo(self) -> None:
57
+ """ Clona o repositório do SeedVR se ele não existir. """
58
+ if not (self.SEEDVR_ROOT / ".git").exists():
59
+ print(f"[SeedVRServer] Clonando repositório para {self.SEEDVR_ROOT}...")
60
+ # Usamos subprocess.run aqui porque é uma tarefa de inicialização única.
61
+ subprocess.run(["git", "clone", "--depth", "1", self.REPO_URL, str(self.SEEDVR_ROOT)], check=True)
62
+ else:
63
+ print("[SeedVRServer] Repositório SeedVR já existe.")
64
+
65
+ def _ensure_model(self) -> None:
66
+ """ Baixa os checkpoints do Hugging Face se não existirem localmente. """
67
+ print(f"[SeedVRServer] Verificando checkpoints (FP16) em {self.CKPTS_ROOT}...")
68
+ model_files = {
69
+ "seedvr2_ema_3b_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
70
+ "ema_vae_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
71
+ "pos_emb.pt": "ByteDance-Seed/SeedVR2-3B",
72
+ "neg_emb.pt": "ByteDance-Seed/SeedVR2-3B"
73
+ }
74
+ for filename, repo_id in model_files.items():
75
+ if not (self.CKPTS_ROOT / filename).exists():
76
+ print(f"Baixando {filename} de {repo_id}...")
77
+ hf_hub_download(
78
+ repo_id=repo_id, filename=filename, local_dir=str(self.CKPTS_ROOT),
79
+ cache_dir=str(self.HF_HOME_CACHE), token=os.getenv("HF_TOKEN")
80
+ )
81
+ print("[SeedVRServer] Checkpoints (FP16) estão no local correto.")
82
 
83
+ def run_inference_direct(
84
+ self,
85
+ file_path: str, *,
86
+ seed: int, res_h: int, res_w: int, sp_size: int,
87
+ fps: Optional[float] = None, progress: Optional[Callable] = None
88
+ ) -> str:
 
 
 
 
 
 
89
  """
90
+ Executa a inferência diretamente no mesmo processo e retorna o caminho do arquivo de saída.
 
 
 
91
  """
92
+ # Cria um diretório de saída único para salvar o resultado.
93
+ out_dir = self.OUTPUT_ROOT / f"run_{int(time.time())}_{Path(file_path).stem}"
94
+ out_dir.mkdir(parents=True, exist_ok=True)
95
+ output_filepath = out_dir / f"result_{Path(file_path).stem}.mp4"
96
+
97
+ # Simula o objeto 'args' que a função de lógica do inference_cli espera.
98
+ # Usamos SimpleNamespace para criar um objeto simples com atributos.
99
+ args = SimpleNamespace(
100
+ video_path=file_path,
101
+ output=str(output_filepath),
102
+ model_dir=str(self.CKPTS_ROOT),
103
+ seed=seed,
104
+ resolution=res_h, # O script do SeedVR usa a altura (lado menor) como referência.
105
+ batch_size=sp_size,
106
+ model="seedvr2_ema_3b_fp16.safetensors",
107
+ preserve_vram=True,
108
+ debug=True, # Mantém o debug ativo para logs detalhados.
109
+ cuda_device=",".join(map(str, range(self.NUM_GPUS_TOTAL))),
110
+ skip_first_frames=0,
111
+ load_cap=0,
112
+ output_format='video' # Garante que sempre gere vídeo
113
+ )
114
 
115
+ try:
116
+ # Informa a UI que o processo começou.
117
+ if progress:
118
+ progress(0.01, "Initializing...")
 
119
 
120
+ # Chama a função importada do script original, passando o callback de progresso.
121
+ # Este callback será chamado de dentro da lógica de multi-processamento.
122
+ result_tensor, original_fps, _, _ = run_inference_logic(args, progress_callback=progress)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
+ # Informa a UI que a inferência terminou e o salvamento vai começar.
125
+ if progress:
126
+ progress(0.95, "Saving the final video...")
 
 
 
 
 
 
 
 
127
 
128
+ # Define o FPS final: usa o valor da UI ou o original do vídeo de entrada.
129
+ final_fps = fps if fps and fps > 0 else original_fps
130
+ save_frames_to_video(result_tensor, str(output_filepath), final_fps, args.debug)
131
+
132
+ print(f"✅ Video saved successfully to: {output_filepath}")
133
+
134
+ # Retorna o caminho do arquivo gerado para a UI.
135
+ return str(output_filepath)
136
+
137
+ except Exception as e:
138
+ print(f"❌ Error during direct inference execution: {e}")
139
+ import traceback
140
+ traceback.print_exc()
141
+ # Propaga o erro para a UI do Gradio, que o exibirá de forma amigável.
142
+ raise