EuuIia commited on
Commit
edd8fe7
·
verified ·
1 Parent(s): 2991b2f

Update app_seedvr.py

Browse files
Files changed (1) hide show
  1. app_seedvr.py +15 -13
app_seedvr.py CHANGED
@@ -9,7 +9,7 @@ import cv2
9
 
10
  # --- INTEGRAÇÃO COM A LÓGICA DO SERVIDOR ---
11
  try:
12
- # Importa a classe SeedVRServer que agora atua como nossa biblioteca de inferência.
13
  from api.seedvr_server import SeedVRServer
14
  except ImportError as e:
15
  print(f"ERRO FATAL: Não foi possível importar o SeedVRServer. Detalhes: {e}")
@@ -18,10 +18,10 @@ except ImportError as e:
18
 
19
  # --- INICIALIZAÇÃO ---
20
  # Cria uma instância única e persistente do servidor.
21
- # A inicialização (clonar repo, baixar modelos) acontece apenas uma vez, no início.
22
  server = SeedVRServer()
23
 
24
- # --- FUNÇÕES AUXILIARES ---
25
 
26
  def _is_video(path: str) -> bool:
27
  """Verifica se um caminho de arquivo corresponde a um tipo de vídeo."""
@@ -55,7 +55,7 @@ def _extract_first_frame(video_path: str) -> Optional[str]:
55
  def on_file_upload(file_obj):
56
  """
57
  Callback acionado quando o usuário faz o upload de um arquivo.
58
- Verifica se o arquivo é um vídeo e sugere um `sp_size` apropriado.
59
  """
60
  if file_obj is None:
61
  # Limpa os resultados e o log se o arquivo for removido
@@ -63,9 +63,9 @@ def on_file_upload(file_obj):
63
 
64
  if _is_video(file_obj.name):
65
  # Para vídeos, sugere um valor padrão para multi-GPU e torna o slider interativo
66
- return gr.update(value=8, interactive=True), None, None, None, gr.update(value=None, visible=False)
67
  else:
68
- # Para imagens, trava o valor em 1
69
  return gr.update(value=1, interactive=False), None, None, None, gr.update(value=None, visible=False)
70
 
71
  # --- FUNÇÃO PRINCIPAL DE INFERÊNCIA DA UI ---
@@ -93,7 +93,6 @@ def run_inference_ui(
93
 
94
  if not input_file_path:
95
  gr.Warning("Please upload a media file first.")
96
- # Reabilita o botão e esconde os componentes de saída
97
  yield (gr.update(interactive=True, value="Restore Media"), None, None, None, gr.update(visible=False))
98
  return
99
 
@@ -104,7 +103,6 @@ def run_inference_ui(
104
  try:
105
  # Define um callback que será chamado pelo backend para atualizar o progresso e o log
106
  def progress_callback_wrapper(step: float, desc: str):
107
- """ Wrapper para formatar logs e atualizar o progresso. """
108
  nonlocal last_log_message
109
  # Só adiciona ao log se a mensagem for nova, para evitar poluição visual
110
  if desc != last_log_message:
@@ -114,15 +112,15 @@ def run_inference_ui(
114
  progress(step, desc=desc)
115
 
116
  # 2. Executa a Inferência
117
- # Chama o método direto do servidor, passando o nosso callback.
118
- video_result_path = server.run_inference_direct(
119
  file_path=input_file_path,
120
  seed=42, # Semente fixa conforme solicitado
121
  res_h=int(resolution),
122
  res_w=int(resolution), # Largura igual à altura
123
  sp_size=int(sp_size),
124
  fps=float(fps) if fps and fps > 0 else None,
125
- progress=progress_callback_wrapper, # Passa nossa função de callback
126
  )
127
 
128
  progress(1.0, desc="Complete!")
@@ -184,14 +182,14 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Res
184
  with gr.Accordion("Generation Parameters", open=True):
185
  resolution_select = gr.Dropdown(
186
  label="Resolution",
187
- choices=["480", "560", "720", "960", "1024", "2048"],
188
  value="480",
189
  info="Sets the output height and width to this value."
190
  )
191
 
192
  sp_size_slider = gr.Slider(
193
  label="Frames per Batch (sp_size)",
194
- minimum=1, maximum=16, step=4, value=8,
195
  info="For multi-GPU videos. Automatically set to 1 for images."
196
  )
197
 
@@ -241,6 +239,10 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Res
241
  )
242
 
243
  if __name__ == "__main__":
 
 
 
 
244
  demo.launch(
245
  server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
246
  server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")),
 
9
 
10
  # --- INTEGRAÇÃO COM A LÓGICA DO SERVIDOR ---
11
  try:
12
+ # Importa a classe SeedVRServer que agora contém toda a lógica de inferência.
13
  from api.seedvr_server import SeedVRServer
14
  except ImportError as e:
15
  print(f"ERRO FATAL: Não foi possível importar o SeedVRServer. Detalhes: {e}")
 
18
 
19
  # --- INICIALIZAÇÃO ---
20
  # Cria uma instância única e persistente do servidor.
21
+ # A inicialização pesada (clonar repo, baixar modelos) acontece apenas uma vez, aqui.
22
  server = SeedVRServer()
23
 
24
+ # --- FUNÇÕES AUXILIARES DA UI ---
25
 
26
  def _is_video(path: str) -> bool:
27
  """Verifica se um caminho de arquivo corresponde a um tipo de vídeo."""
 
55
  def on_file_upload(file_obj):
56
  """
57
  Callback acionado quando o usuário faz o upload de um arquivo.
58
+ Limpa saídas antigas e ajusta o slider de `sp_size` com base no tipo de arquivo.
59
  """
60
  if file_obj is None:
61
  # Limpa os resultados e o log se o arquivo for removido
 
63
 
64
  if _is_video(file_obj.name):
65
  # Para vídeos, sugere um valor padrão para multi-GPU e torna o slider interativo
66
+ return gr.update(value=4, interactive=True), None, None, None, gr.update(value=None, visible=False)
67
  else:
68
+ # Para imagens, trava o valor em 1, pois não há paralelismo de dados
69
  return gr.update(value=1, interactive=False), None, None, None, gr.update(value=None, visible=False)
70
 
71
  # --- FUNÇÃO PRINCIPAL DE INFERÊNCIA DA UI ---
 
93
 
94
  if not input_file_path:
95
  gr.Warning("Please upload a media file first.")
 
96
  yield (gr.update(interactive=True, value="Restore Media"), None, None, None, gr.update(visible=False))
97
  return
98
 
 
103
  try:
104
  # Define um callback que será chamado pelo backend para atualizar o progresso e o log
105
  def progress_callback_wrapper(step: float, desc: str):
 
106
  nonlocal last_log_message
107
  # Só adiciona ao log se a mensagem for nova, para evitar poluição visual
108
  if desc != last_log_message:
 
112
  progress(step, desc=desc)
113
 
114
  # 2. Executa a Inferência
115
+ # Chama o método do servidor, passando o nosso callback.
116
+ video_result_path = server.run_inference(
117
  file_path=input_file_path,
118
  seed=42, # Semente fixa conforme solicitado
119
  res_h=int(resolution),
120
  res_w=int(resolution), # Largura igual à altura
121
  sp_size=int(sp_size),
122
  fps=float(fps) if fps and fps > 0 else None,
123
+ progress=progress_callback_wrapper,
124
  )
125
 
126
  progress(1.0, desc="Complete!")
 
182
  with gr.Accordion("Generation Parameters", open=True):
183
  resolution_select = gr.Dropdown(
184
  label="Resolution",
185
+ choices=["480", "560", "720", "960", "1024"],
186
  value="480",
187
  info="Sets the output height and width to this value."
188
  )
189
 
190
  sp_size_slider = gr.Slider(
191
  label="Frames per Batch (sp_size)",
192
+ minimum=1, maximum=16, step=1, value=4,
193
  info="For multi-GPU videos. Automatically set to 1 for images."
194
  )
195
 
 
239
  )
240
 
241
  if __name__ == "__main__":
242
+ # Garante que o start_method do multiprocessing seja 'spawn'
243
+ # É uma boa prática definir isso no ponto de entrada principal.
244
+ mp.set_start_method('spawn', force=True)
245
+
246
  demo.launch(
247
  server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
248
  server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")),