EuuIia commited on
Commit
296cade
·
verified ·
1 Parent(s): edd8fe7

Update app_seedvr.py

Browse files
Files changed (1) hide show
  1. app_seedvr.py +13 -54
app_seedvr.py CHANGED
@@ -6,19 +6,16 @@ from pathlib import Path
6
  from typing import Optional
7
  import gradio as gr
8
  import cv2
 
9
 
10
  # --- INTEGRAÇÃO COM A LÓGICA DO SERVIDOR ---
11
  try:
12
- # Importa a classe SeedVRServer que agora contém toda a lógica de inferência.
13
  from api.seedvr_server import SeedVRServer
14
  except ImportError as e:
15
  print(f"ERRO FATAL: Não foi possível importar o SeedVRServer. Detalhes: {e}")
16
- # A aplicação não pode rodar sem a lógica do servidor.
17
  raise
18
 
19
  # --- INICIALIZAÇÃO ---
20
- # Cria uma instância única e persistente do servidor.
21
- # A inicialização pesada (clonar repo, baixar modelos) acontece apenas uma vez, aqui.
22
  server = SeedVRServer()
23
 
24
  # --- FUNÇÕES AUXILIARES DA UI ---
@@ -43,8 +40,6 @@ def _extract_first_frame(video_path: str) -> Optional[str]:
43
  if not success:
44
  print(f"Erro: Não foi possível ler o primeiro frame de {video_path}")
45
  return None
46
-
47
- # Salva o frame no mesmo diretório do vídeo, com extensão .jpg
48
  image_path = Path(video_path).with_suffix(".jpg")
49
  cv2.imwrite(str(image_path), image)
50
  return str(image_path)
@@ -53,19 +48,13 @@ def _extract_first_frame(video_path: str) -> Optional[str]:
53
  return None
54
 
55
  def on_file_upload(file_obj):
56
- """
57
- Callback acionado quando o usuário faz o upload de um arquivo.
58
- Limpa saídas antigas e ajusta o slider de `sp_size` com base no tipo de arquivo.
59
- """
60
  if file_obj is None:
61
- # Limpa os resultados e o log se o arquivo for removido
62
  return gr.update(value=1), None, None, None, gr.update(value=None, visible=False)
63
 
64
  if _is_video(file_obj.name):
65
- # Para vídeos, sugere um valor padrão para multi-GPU e torna o slider interativo
66
  return gr.update(value=4, interactive=True), None, None, None, gr.update(value=None, visible=False)
67
  else:
68
- # Para imagens, trava o valor em 1, pois não há paralelismo de dados
69
  return gr.update(value=1, interactive=False), None, None, None, gr.update(value=None, visible=False)
70
 
71
  # --- FUNÇÃO PRINCIPAL DE INFERÊNCIA DA UI ---
@@ -78,11 +67,8 @@ def run_inference_ui(
78
  progress=gr.Progress(track_tqdm=True)
79
  ):
80
  """
81
- A função de callback principal do Gradio. Usa geradores (`yield`)
82
- para permitir atualizações da UI em tempo real durante a tarefa de longa duração.
83
  """
84
- # 1. Estado Inicial e Validação
85
- # No início, desabilita o botão, limpa resultados anteriores e mostra a janela de log.
86
  yield (
87
  gr.update(interactive=False, value="Processing... 🚀"),
88
  gr.update(value=None, visible=False),
@@ -101,23 +87,18 @@ def run_inference_ui(
101
  was_input_video = _is_video(input_file_path)
102
 
103
  try:
104
- # Define um callback que será chamado pelo backend para atualizar o progresso e o log
105
  def progress_callback_wrapper(step: float, desc: str):
106
  nonlocal last_log_message
107
- # Só adiciona ao log se a mensagem for nova, para evitar poluição visual
108
  if desc != last_log_message:
109
  log_buffer.append(f"{desc}\n")
110
  last_log_message = desc
111
- # Atualiza o objeto de progresso do Gradio
112
  progress(step, desc=desc)
113
 
114
- # 2. Executa a Inferência
115
- # Chama o método do servidor, passando o nosso callback.
116
  video_result_path = server.run_inference(
117
  file_path=input_file_path,
118
- seed=42, # Semente fixa conforme solicitado
119
  res_h=int(resolution),
120
- res_w=int(resolution), # Largura igual à altura
121
  sp_size=int(sp_size),
122
  fps=float(fps) if fps and fps > 0 else None,
123
  progress=progress_callback_wrapper,
@@ -126,17 +107,15 @@ def run_inference_ui(
126
  progress(1.0, desc="Complete!")
127
  log_buffer.append("✅ Inference complete! Processing final output...\n")
128
 
129
- # 3. Processa e Exibe os Resultados
130
  final_image, final_video = None, None
131
  if was_input_video:
132
  final_video = video_result_path
133
  log_buffer.append("✅ Video result is ready.\n")
134
- else: # Se a entrada foi uma imagem
135
  final_image = _extract_first_frame(video_result_path)
136
- final_video = video_result_path # Também disponibiliza o vídeo de 1 frame
137
  log_buffer.append("✅ Image result extracted from video.\n")
138
 
139
- # Yield final para mostrar os resultados e reabilitar o botão
140
  yield (
141
  gr.update(interactive=True, value="Restore Media"),
142
  gr.update(value=final_image, visible=final_image is not None),
@@ -152,7 +131,6 @@ def run_inference_ui(
152
  import traceback
153
  traceback.print_exc()
154
 
155
- # Yield para estado de erro: reabilita o botão e mostra o log com o erro
156
  yield (
157
  gr.update(interactive=True, value="Restore Media"),
158
  None, None, None,
@@ -160,9 +138,8 @@ def run_inference_ui(
160
  )
161
 
162
  # --- LAYOUT DA INTERFACE GRÁFICA (GRADIO) ---
163
-
164
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Restoration") as demo:
165
- # Cabeçalho
166
  gr.Markdown(
167
  """
168
  <div style='text-align: center; margin-bottom: 20px;'>
@@ -171,13 +148,10 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Res
171
  </div>
172
  """
173
  )
174
-
175
  with gr.Row():
176
- # --- Coluna da Esquerda: Entradas e Controles ---
177
  with gr.Column(scale=1):
178
  gr.Markdown("### 1. Upload Media")
179
  input_media = gr.File(label="Input File (Video or Image)", type="filepath", interactive=True)
180
-
181
  gr.Markdown("### 2. Configure Settings")
182
  with gr.Accordion("Generation Parameters", open=True):
183
  resolution_select = gr.Dropdown(
@@ -186,34 +160,22 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Res
186
  value="480",
187
  info="Sets the output height and width to this value."
188
  )
189
-
190
  sp_size_slider = gr.Slider(
191
  label="Frames per Batch (sp_size)",
192
  minimum=1, maximum=16, step=1, value=4,
193
  info="For multi-GPU videos. Automatically set to 1 for images."
194
  )
195
-
196
  fps_out = gr.Number(label="Output FPS (for Videos)", value=24, precision=0, info="Set to 0 to use the original FPS.")
197
-
198
  run_button = gr.Button("Restore Media", variant="primary", icon="✨")
199
-
200
- # --- Coluna da Direita: Resultados ---
201
  with gr.Column(scale=2):
202
  gr.Markdown("### 3. Results")
203
-
204
- # Janela de Log
205
  log_window = gr.Textbox(
206
- label="Inference Log 📝",
207
- lines=8, max_lines=15,
208
- interactive=False, visible=False, autoscroll=True
209
  )
210
-
211
- # Componentes de saída (começam invisíveis)
212
  output_image = gr.Image(label="Image Result", show_download_button=True, type="filepath", visible=False)
213
  output_video = gr.Video(label="Video Result", visible=False)
214
  output_download = gr.File(label="Download Full Result (Video)", visible=False)
215
-
216
- # --- Rodapé ---
217
  gr.Markdown(
218
  """
219
  ---
@@ -222,16 +184,12 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Res
222
  """
223
  )
224
 
225
- # --- Lógica de Eventos da UI ---
226
-
227
- # Ao fazer upload de um arquivo, ajusta o slider `sp_size` e limpa saídas antigas.
228
  input_media.upload(
229
  fn=on_file_upload,
230
  inputs=[input_media],
231
  outputs=[sp_size_slider, output_image, output_video, output_download, log_window]
232
  )
233
 
234
- # Ao clicar no botão, executa a função de inferência principal.
235
  run_button.click(
236
  fn=run_inference_ui,
237
  inputs=[input_media, resolution_select, sp_size_slider, fps_out],
@@ -239,8 +197,9 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Res
239
  )
240
 
241
  if __name__ == "__main__":
242
- # Garante que o start_method do multiprocessing seja 'spawn'
243
- # É uma boa prática definir isso no ponto de entrada principal.
 
244
  mp.set_start_method('spawn', force=True)
245
 
246
  demo.launch(
 
6
  from typing import Optional
7
  import gradio as gr
8
  import cv2
9
+ import multiprocessing as mp # <--- LINHA ADICIONADA AQUI
10
 
11
  # --- INTEGRAÇÃO COM A LÓGICA DO SERVIDOR ---
12
  try:
 
13
  from api.seedvr_server import SeedVRServer
14
  except ImportError as e:
15
  print(f"ERRO FATAL: Não foi possível importar o SeedVRServer. Detalhes: {e}")
 
16
  raise
17
 
18
  # --- INICIALIZAÇÃO ---
 
 
19
  server = SeedVRServer()
20
 
21
  # --- FUNÇÕES AUXILIARES DA UI ---
 
40
  if not success:
41
  print(f"Erro: Não foi possível ler o primeiro frame de {video_path}")
42
  return None
 
 
43
  image_path = Path(video_path).with_suffix(".jpg")
44
  cv2.imwrite(str(image_path), image)
45
  return str(image_path)
 
48
  return None
49
 
50
  def on_file_upload(file_obj):
51
+ """Callback acionado quando o usuário faz o upload de um arquivo."""
 
 
 
52
  if file_obj is None:
 
53
  return gr.update(value=1), None, None, None, gr.update(value=None, visible=False)
54
 
55
  if _is_video(file_obj.name):
 
56
  return gr.update(value=4, interactive=True), None, None, None, gr.update(value=None, visible=False)
57
  else:
 
58
  return gr.update(value=1, interactive=False), None, None, None, gr.update(value=None, visible=False)
59
 
60
  # --- FUNÇÃO PRINCIPAL DE INFERÊNCIA DA UI ---
 
67
  progress=gr.Progress(track_tqdm=True)
68
  ):
69
  """
70
+ A função de callback principal do Gradio.
 
71
  """
 
 
72
  yield (
73
  gr.update(interactive=False, value="Processing... 🚀"),
74
  gr.update(value=None, visible=False),
 
87
  was_input_video = _is_video(input_file_path)
88
 
89
  try:
 
90
  def progress_callback_wrapper(step: float, desc: str):
91
  nonlocal last_log_message
 
92
  if desc != last_log_message:
93
  log_buffer.append(f"{desc}\n")
94
  last_log_message = desc
 
95
  progress(step, desc=desc)
96
 
 
 
97
  video_result_path = server.run_inference(
98
  file_path=input_file_path,
99
+ seed=42,
100
  res_h=int(resolution),
101
+ res_w=int(resolution),
102
  sp_size=int(sp_size),
103
  fps=float(fps) if fps and fps > 0 else None,
104
  progress=progress_callback_wrapper,
 
107
  progress(1.0, desc="Complete!")
108
  log_buffer.append("✅ Inference complete! Processing final output...\n")
109
 
 
110
  final_image, final_video = None, None
111
  if was_input_video:
112
  final_video = video_result_path
113
  log_buffer.append("✅ Video result is ready.\n")
114
+ else:
115
  final_image = _extract_first_frame(video_result_path)
116
+ final_video = video_result_path
117
  log_buffer.append("✅ Image result extracted from video.\n")
118
 
 
119
  yield (
120
  gr.update(interactive=True, value="Restore Media"),
121
  gr.update(value=final_image, visible=final_image is not None),
 
131
  import traceback
132
  traceback.print_exc()
133
 
 
134
  yield (
135
  gr.update(interactive=True, value="Restore Media"),
136
  None, None, None,
 
138
  )
139
 
140
  # --- LAYOUT DA INTERFACE GRÁFICA (GRADIO) ---
 
141
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Restoration") as demo:
142
+ # ... (O layout da UI permanece exatamente o mesmo)
143
  gr.Markdown(
144
  """
145
  <div style='text-align: center; margin-bottom: 20px;'>
 
148
  </div>
149
  """
150
  )
 
151
  with gr.Row():
 
152
  with gr.Column(scale=1):
153
  gr.Markdown("### 1. Upload Media")
154
  input_media = gr.File(label="Input File (Video or Image)", type="filepath", interactive=True)
 
155
  gr.Markdown("### 2. Configure Settings")
156
  with gr.Accordion("Generation Parameters", open=True):
157
  resolution_select = gr.Dropdown(
 
160
  value="480",
161
  info="Sets the output height and width to this value."
162
  )
 
163
  sp_size_slider = gr.Slider(
164
  label="Frames per Batch (sp_size)",
165
  minimum=1, maximum=16, step=1, value=4,
166
  info="For multi-GPU videos. Automatically set to 1 for images."
167
  )
 
168
  fps_out = gr.Number(label="Output FPS (for Videos)", value=24, precision=0, info="Set to 0 to use the original FPS.")
 
169
  run_button = gr.Button("Restore Media", variant="primary", icon="✨")
 
 
170
  with gr.Column(scale=2):
171
  gr.Markdown("### 3. Results")
 
 
172
  log_window = gr.Textbox(
173
+ label="Inference Log 📝", lines=8, max_lines=15,
174
+ interactive=False, visible=False, autoscroll=True,
 
175
  )
 
 
176
  output_image = gr.Image(label="Image Result", show_download_button=True, type="filepath", visible=False)
177
  output_video = gr.Video(label="Video Result", visible=False)
178
  output_download = gr.File(label="Download Full Result (Video)", visible=False)
 
 
179
  gr.Markdown(
180
  """
181
  ---
 
184
  """
185
  )
186
 
 
 
 
187
  input_media.upload(
188
  fn=on_file_upload,
189
  inputs=[input_media],
190
  outputs=[sp_size_slider, output_image, output_video, output_download, log_window]
191
  )
192
 
 
193
  run_button.click(
194
  fn=run_inference_ui,
195
  inputs=[input_media, resolution_select, sp_size_slider, fps_out],
 
197
  )
198
 
199
  if __name__ == "__main__":
200
+ # Garante que o start_method do multiprocessing seja 'spawn', que é mais seguro
201
+ # e evita problemas de estado compartilhado entre processos.
202
+ # É uma boa prática definir isso no ponto de entrada principal da aplicação.
203
  mp.set_start_method('spawn', force=True)
204
 
205
  demo.launch(