EuuIia commited on
Commit
162353e
·
verified ·
1 Parent(s): 9ad5b00

Update app_seedvr.py

Browse files
Files changed (1) hide show
  1. app_seedvr.py +60 -21
app_seedvr.py CHANGED
@@ -6,19 +6,22 @@ from pathlib import Path
6
  from typing import Optional
7
  import gradio as gr
8
  import cv2
9
- import multiprocessing as mp # <--- LINHA ADICIONADA AQUI
10
 
11
  # --- INTEGRAÇÃO COM A LÓGICA DO SERVIDOR ---
12
  try:
 
13
  from api.seedvr_server import SeedVRServer
14
  except ImportError as e:
15
  print(f"ERRO FATAL: Não foi possível importar o SeedVRServer. Detalhes: {e}")
 
16
  raise
17
 
18
  # --- INICIALIZAÇÃO ---
 
 
19
  server = SeedVRServer()
20
 
21
- # --- FUNÇÕES AUXILIARES DA UI ---
22
 
23
  def _is_video(path: str) -> bool:
24
  """Verifica se um caminho de arquivo corresponde a um tipo de vídeo."""
@@ -40,6 +43,8 @@ def _extract_first_frame(video_path: str) -> Optional[str]:
40
  if not success:
41
  print(f"Erro: Não foi possível ler o primeiro frame de {video_path}")
42
  return None
 
 
43
  image_path = Path(video_path).with_suffix(".jpg")
44
  cv2.imwrite(str(image_path), image)
45
  return str(image_path)
@@ -48,13 +53,19 @@ def _extract_first_frame(video_path: str) -> Optional[str]:
48
  return None
49
 
50
  def on_file_upload(file_obj):
51
- """Callback acionado quando o usuário faz o upload de um arquivo."""
 
 
 
52
  if file_obj is None:
 
53
  return gr.update(value=1), None, None, None, gr.update(value=None, visible=False)
54
 
55
  if _is_video(file_obj.name):
56
- return gr.update(value=4, interactive=True), None, None, None, gr.update(value=None, visible=False)
 
57
  else:
 
58
  return gr.update(value=1, interactive=False), None, None, None, gr.update(value=None, visible=False)
59
 
60
  # --- FUNÇÃO PRINCIPAL DE INFERÊNCIA DA UI ---
@@ -67,8 +78,11 @@ def run_inference_ui(
67
  progress=gr.Progress(track_tqdm=True)
68
  ):
69
  """
70
- A função de callback principal do Gradio.
 
71
  """
 
 
72
  yield (
73
  gr.update(interactive=False, value="Processing... 🚀"),
74
  gr.update(value=None, visible=False),
@@ -79,6 +93,7 @@ def run_inference_ui(
79
 
80
  if not input_file_path:
81
  gr.Warning("Please upload a media file first.")
 
82
  yield (gr.update(interactive=True, value="Restore Media"), None, None, None, gr.update(visible=False))
83
  return
84
 
@@ -87,35 +102,43 @@ def run_inference_ui(
87
  was_input_video = _is_video(input_file_path)
88
 
89
  try:
 
90
  def progress_callback_wrapper(step: float, desc: str):
 
91
  nonlocal last_log_message
 
92
  if desc != last_log_message:
93
  log_buffer.append(f"{desc}\n")
94
  last_log_message = desc
 
95
  progress(step, desc=desc)
96
 
97
- video_result_path = server.run_inference(
 
 
98
  file_path=input_file_path,
99
- seed=42,
100
  res_h=int(resolution),
101
- res_w=int(resolution),
102
  sp_size=int(sp_size),
103
  fps=float(fps) if fps and fps > 0 else None,
104
- progress=progress_callback_wrapper,
105
  )
106
 
107
  progress(1.0, desc="Complete!")
108
  log_buffer.append("✅ Inference complete! Processing final output...\n")
109
 
 
110
  final_image, final_video = None, None
111
  if was_input_video:
112
  final_video = video_result_path
113
  log_buffer.append("✅ Video result is ready.\n")
114
- else:
115
  final_image = _extract_first_frame(video_result_path)
116
- final_video = video_result_path
117
  log_buffer.append("✅ Image result extracted from video.\n")
118
 
 
119
  yield (
120
  gr.update(interactive=True, value="Restore Media"),
121
  gr.update(value=final_image, visible=final_image is not None),
@@ -131,6 +154,7 @@ def run_inference_ui(
131
  import traceback
132
  traceback.print_exc()
133
 
 
134
  yield (
135
  gr.update(interactive=True, value="Restore Media"),
136
  None, None, None,
@@ -138,8 +162,9 @@ def run_inference_ui(
138
  )
139
 
140
  # --- LAYOUT DA INTERFACE GRÁFICA (GRADIO) ---
 
141
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Restoration") as demo:
142
- # ... (O layout da UI permanece exatamente o mesmo)
143
  gr.Markdown(
144
  """
145
  <div style='text-align: center; margin-bottom: 20px;'>
@@ -148,34 +173,49 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Res
148
  </div>
149
  """
150
  )
 
151
  with gr.Row():
 
152
  with gr.Column(scale=1):
153
  gr.Markdown("### 1. Upload Media")
154
  input_media = gr.File(label="Input File (Video or Image)", type="filepath", interactive=True)
 
155
  gr.Markdown("### 2. Configure Settings")
156
  with gr.Accordion("Generation Parameters", open=True):
157
  resolution_select = gr.Dropdown(
158
  label="Resolution",
159
- choices=["480", "560", "720", "960", "1024"],
160
  value="480",
161
  info="Sets the output height and width to this value."
162
  )
 
163
  sp_size_slider = gr.Slider(
164
  label="Frames per Batch (sp_size)",
165
- minimum=1, maximum=16, step=1, value=4,
166
  info="For multi-GPU videos. Automatically set to 1 for images."
167
  )
 
168
  fps_out = gr.Number(label="Output FPS (for Videos)", value=24, precision=0, info="Set to 0 to use the original FPS.")
 
169
  run_button = gr.Button("Restore Media", variant="primary", icon="✨")
 
 
170
  with gr.Column(scale=2):
171
  gr.Markdown("### 3. Results")
 
 
172
  log_window = gr.Textbox(
173
- label="Inference Log 📝", lines=8, max_lines=15,
174
- interactive=False, visible=False, autoscroll=True,
 
175
  )
 
 
176
  output_image = gr.Image(label="Image Result", show_download_button=True, type="filepath", visible=False)
177
  output_video = gr.Video(label="Video Result", visible=False)
178
  output_download = gr.File(label="Download Full Result (Video)", visible=False)
 
 
179
  gr.Markdown(
180
  """
181
  ---
@@ -184,12 +224,16 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Res
184
  """
185
  )
186
 
 
 
 
187
  input_media.upload(
188
  fn=on_file_upload,
189
  inputs=[input_media],
190
  outputs=[sp_size_slider, output_image, output_video, output_download, log_window]
191
  )
192
 
 
193
  run_button.click(
194
  fn=run_inference_ui,
195
  inputs=[input_media, resolution_select, sp_size_slider, fps_out],
@@ -197,11 +241,6 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Res
197
  )
198
 
199
  if __name__ == "__main__":
200
- # Garante que o start_method do multiprocessing seja 'spawn', que é mais seguro
201
- # e evita problemas de estado compartilhado entre processos.
202
- # É uma boa prática definir isso no ponto de entrada principal da aplicação.
203
- mp.set_start_method('spawn', force=True)
204
-
205
  demo.launch(
206
  server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
207
  server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")),
 
6
  from typing import Optional
7
  import gradio as gr
8
  import cv2
 
9
 
10
  # --- INTEGRAÇÃO COM A LÓGICA DO SERVIDOR ---
11
  try:
12
+ # Importa a classe SeedVRServer que agora atua como nossa biblioteca de inferência.
13
  from api.seedvr_server import SeedVRServer
14
  except ImportError as e:
15
  print(f"ERRO FATAL: Não foi possível importar o SeedVRServer. Detalhes: {e}")
16
+ # A aplicação não pode rodar sem a lógica do servidor.
17
  raise
18
 
19
  # --- INICIALIZAÇÃO ---
20
+ # Cria uma instância única e persistente do servidor.
21
+ # A inicialização (clonar repo, baixar modelos) acontece apenas uma vez, no início.
22
  server = SeedVRServer()
23
 
24
+ # --- FUNÇÕES AUXILIARES ---
25
 
26
  def _is_video(path: str) -> bool:
27
  """Verifica se um caminho de arquivo corresponde a um tipo de vídeo."""
 
43
  if not success:
44
  print(f"Erro: Não foi possível ler o primeiro frame de {video_path}")
45
  return None
46
+
47
+ # Salva o frame no mesmo diretório do vídeo, com extensão .jpg
48
  image_path = Path(video_path).with_suffix(".jpg")
49
  cv2.imwrite(str(image_path), image)
50
  return str(image_path)
 
53
  return None
54
 
55
  def on_file_upload(file_obj):
56
+ """
57
+ Callback acionado quando o usuário faz o upload de um arquivo.
58
+ Verifica se o arquivo é um vídeo e sugere um `sp_size` apropriado.
59
+ """
60
  if file_obj is None:
61
+ # Limpa os resultados e o log se o arquivo for removido
62
  return gr.update(value=1), None, None, None, gr.update(value=None, visible=False)
63
 
64
  if _is_video(file_obj.name):
65
+ # Para vídeos, sugere um valor padrão para multi-GPU e torna o slider interativo
66
+ return gr.update(value=8, interactive=True), None, None, None, gr.update(value=None, visible=False)
67
  else:
68
+ # Para imagens, trava o valor em 1
69
  return gr.update(value=1, interactive=False), None, None, None, gr.update(value=None, visible=False)
70
 
71
  # --- FUNÇÃO PRINCIPAL DE INFERÊNCIA DA UI ---
 
78
  progress=gr.Progress(track_tqdm=True)
79
  ):
80
  """
81
+ A função de callback principal do Gradio. Usa geradores (`yield`)
82
+ para permitir atualizações da UI em tempo real durante a tarefa de longa duração.
83
  """
84
+ # 1. Estado Inicial e Validação
85
+ # No início, desabilita o botão, limpa resultados anteriores e mostra a janela de log.
86
  yield (
87
  gr.update(interactive=False, value="Processing... 🚀"),
88
  gr.update(value=None, visible=False),
 
93
 
94
  if not input_file_path:
95
  gr.Warning("Please upload a media file first.")
96
+ # Reabilita o botão e esconde os componentes de saída
97
  yield (gr.update(interactive=True, value="Restore Media"), None, None, None, gr.update(visible=False))
98
  return
99
 
 
102
  was_input_video = _is_video(input_file_path)
103
 
104
  try:
105
+ # Define um callback que será chamado pelo backend para atualizar o progresso e o log
106
  def progress_callback_wrapper(step: float, desc: str):
107
+ """ Wrapper para formatar logs e atualizar o progresso. """
108
  nonlocal last_log_message
109
+ # Só adiciona ao log se a mensagem for nova, para evitar poluição visual
110
  if desc != last_log_message:
111
  log_buffer.append(f"{desc}\n")
112
  last_log_message = desc
113
+ # Atualiza o objeto de progresso do Gradio
114
  progress(step, desc=desc)
115
 
116
+ # 2. Executa a Inferência
117
+ # Chama o método direto do servidor, passando o nosso callback.
118
+ video_result_path = server.run_inference_direct(
119
  file_path=input_file_path,
120
+ seed=42, # Semente fixa conforme solicitado
121
  res_h=int(resolution),
122
+ res_w=int(resolution), # Largura igual à altura
123
  sp_size=int(sp_size),
124
  fps=float(fps) if fps and fps > 0 else None,
125
+ progress=progress_callback_wrapper, # Passa nossa função de callback
126
  )
127
 
128
  progress(1.0, desc="Complete!")
129
  log_buffer.append("✅ Inference complete! Processing final output...\n")
130
 
131
+ # 3. Processa e Exibe os Resultados
132
  final_image, final_video = None, None
133
  if was_input_video:
134
  final_video = video_result_path
135
  log_buffer.append("✅ Video result is ready.\n")
136
+ else: # Se a entrada foi uma imagem
137
  final_image = _extract_first_frame(video_result_path)
138
+ final_video = video_result_path # Também disponibiliza o vídeo de 1 frame
139
  log_buffer.append("✅ Image result extracted from video.\n")
140
 
141
+ # Yield final para mostrar os resultados e reabilitar o botão
142
  yield (
143
  gr.update(interactive=True, value="Restore Media"),
144
  gr.update(value=final_image, visible=final_image is not None),
 
154
  import traceback
155
  traceback.print_exc()
156
 
157
+ # Yield para estado de erro: reabilita o botão e mostra o log com o erro
158
  yield (
159
  gr.update(interactive=True, value="Restore Media"),
160
  None, None, None,
 
162
  )
163
 
164
  # --- LAYOUT DA INTERFACE GRÁFICA (GRADIO) ---
165
+
166
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Restoration") as demo:
167
+ # Cabeçalho
168
  gr.Markdown(
169
  """
170
  <div style='text-align: center; margin-bottom: 20px;'>
 
173
  </div>
174
  """
175
  )
176
+
177
  with gr.Row():
178
+ # --- Coluna da Esquerda: Entradas e Controles ---
179
  with gr.Column(scale=1):
180
  gr.Markdown("### 1. Upload Media")
181
  input_media = gr.File(label="Input File (Video or Image)", type="filepath", interactive=True)
182
+
183
  gr.Markdown("### 2. Configure Settings")
184
  with gr.Accordion("Generation Parameters", open=True):
185
  resolution_select = gr.Dropdown(
186
  label="Resolution",
187
+ choices=["480", "560", "720", "960", "1024", "2048"],
188
  value="480",
189
  info="Sets the output height and width to this value."
190
  )
191
+
192
  sp_size_slider = gr.Slider(
193
  label="Frames per Batch (sp_size)",
194
+ minimum=1, maximum=16, step=4, value=8,
195
  info="For multi-GPU videos. Automatically set to 1 for images."
196
  )
197
+
198
  fps_out = gr.Number(label="Output FPS (for Videos)", value=24, precision=0, info="Set to 0 to use the original FPS.")
199
+
200
  run_button = gr.Button("Restore Media", variant="primary", icon="✨")
201
+
202
+ # --- Coluna da Direita: Resultados ---
203
  with gr.Column(scale=2):
204
  gr.Markdown("### 3. Results")
205
+
206
+ # Janela de Log
207
  log_window = gr.Textbox(
208
+ label="Inference Log 📝",
209
+ lines=8, max_lines=15,
210
+ interactive=False, visible=False, autoscroll=True
211
  )
212
+
213
+ # Componentes de saída (começam invisíveis)
214
  output_image = gr.Image(label="Image Result", show_download_button=True, type="filepath", visible=False)
215
  output_video = gr.Video(label="Video Result", visible=False)
216
  output_download = gr.File(label="Download Full Result (Video)", visible=False)
217
+
218
+ # --- Rodapé ---
219
  gr.Markdown(
220
  """
221
  ---
 
224
  """
225
  )
226
 
227
+ # --- Lógica de Eventos da UI ---
228
+
229
+ # Ao fazer upload de um arquivo, ajusta o slider `sp_size` e limpa saídas antigas.
230
  input_media.upload(
231
  fn=on_file_upload,
232
  inputs=[input_media],
233
  outputs=[sp_size_slider, output_image, output_video, output_download, log_window]
234
  )
235
 
236
+ # Ao clicar no botão, executa a função de inferência principal.
237
  run_button.click(
238
  fn=run_inference_ui,
239
  inputs=[input_media, resolution_select, sp_size_slider, fps_out],
 
241
  )
242
 
243
  if __name__ == "__main__":
 
 
 
 
 
244
  demo.launch(
245
  server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
246
  server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")),