EuuIia commited on
Commit
ef738e5
·
verified ·
1 Parent(s): d411dcc

Update app_seedvr.py

Browse files
Files changed (1) hide show
  1. app_seedvr.py +90 -55
app_seedvr.py CHANGED
@@ -7,27 +7,31 @@ from typing import Optional
7
  import gradio as gr
8
  import cv2
9
 
10
- # --- SERVER LOGIC INTEGRATION ---
11
  try:
 
12
  from api.seedvr_server import SeedVRServer
13
  except ImportError as e:
14
- print(f"FATAL ERROR: Could not import SeedVRServer. Details: {e}")
 
15
  raise
16
 
17
- # --- INITIALIZATION ---
 
 
18
  server = SeedVRServer()
19
 
20
- # --- HELPER FUNCTIONS ---
21
 
22
  def _is_video(path: str) -> bool:
23
- """Checks if a file path corresponds to a video type."""
24
  if not path: return False
25
  import mimetypes
26
  mime, _ = mimetypes.guess_type(path)
27
  return (mime or "").startswith("video")
28
 
29
  def _extract_first_frame(video_path: str) -> Optional[str]:
30
- """Extracts the first frame from a video and saves it as a JPG image."""
31
  if not video_path or not os.path.exists(video_path): return None
32
  try:
33
  vid_cap = cv2.VideoCapture(video_path)
@@ -35,23 +39,32 @@ def _extract_first_frame(video_path: str) -> Optional[str]:
35
  success, image = vid_cap.read()
36
  vid_cap.release()
37
  if not success: return None
 
 
38
  image_path = Path(video_path).with_suffix(".jpg")
39
  cv2.imwrite(str(image_path), image)
40
  return str(image_path)
41
  except Exception as e:
42
- print(f"Error extracting first frame: {e}")
43
  return None
44
 
45
  def on_file_upload(file_obj):
46
- """Callback triggered when a user uploads a file."""
 
 
 
47
  if file_obj is None:
48
- return 1
 
 
49
  if _is_video(file_obj.name):
50
- return gr.update(value=4, interactive=True)
 
51
  else:
52
- return gr.update(value=1, interactive=False)
 
53
 
54
- # --- CORE INFERENCE FUNCTION ---
55
 
56
  def run_inference_ui(
57
  input_file_path: Optional[str],
@@ -61,66 +74,63 @@ def run_inference_ui(
61
  progress=gr.Progress(track_tqdm=True)
62
  ):
63
  """
64
- The main callback function for Gradio, using generators (`yield`)
65
- for real-time UI updates.
66
  """
67
- # 1. Initial State & Validation
 
68
  yield (
69
  gr.update(interactive=False, value="Processing... 🚀"),
70
  gr.update(value=None, visible=False),
71
  gr.update(value=None, visible=False),
72
  gr.update(value=None, visible=False),
73
- gr.update(value="Waiting for logs...", visible=True)
74
  )
75
 
76
  if not input_file_path:
77
  gr.Warning("Please upload a media file first.")
78
- yield (
79
- gr.update(interactive=True, value="Restore Media"),
80
- None, None, None, gr.update(visible=False)
81
- )
82
  return
83
 
84
  log_buffer = ["▶ Starting inference process...\n"]
85
- yield gr.update(), None, None, None, ''.join(log_buffer)
86
-
87
- # CORREÇÃO APLICADA AQUI
88
- def progress_callback(step: float, desc: str):
89
- """A simple callback to append messages to our log buffer."""
90
- log_buffer.append(f"⏳ [{int(step*100)}%] {desc}\n")
91
- # A chamada correta para a API de progresso do Gradio
92
- progress(step, desc=desc)
93
-
94
  was_input_video = _is_video(input_file_path)
95
 
96
  try:
97
- # 2. Execute Inference
98
- progress_callback(0.1, "Calling backend engine...")
99
- yield gr.update(), None, None, None, ''.join(log_buffer)
100
-
 
 
 
 
 
 
101
  video_result_path = server.run_inference_direct(
102
  file_path=input_file_path,
103
- seed=42,
104
  res_h=int(resolution),
105
- res_w=int(resolution),
106
  sp_size=int(sp_size),
107
  fps=float(fps) if fps and fps > 0 else None,
108
- progress=progress,
109
  )
110
 
111
- progress_callback(1.0, "Inference complete! Processing final output...")
112
- yield gr.update(), None, None, None, ''.join(log_buffer)
113
 
114
- # 3. Process and Display Results
115
  final_image, final_video = None, None
116
  if was_input_video:
117
  final_video = video_result_path
118
- log_buffer.append(f"✅ Video result is ready.\n")
119
- else:
120
  final_image = _extract_first_frame(video_result_path)
121
- final_video = video_result_path
122
- log_buffer.append(f"✅ Image result extracted from video.\n")
123
 
 
124
  yield (
125
  gr.update(interactive=True, value="Restore Media"),
126
  gr.update(value=final_image, visible=final_image is not None),
@@ -132,20 +142,21 @@ def run_inference_ui(
132
  except Exception as e:
133
  error_message = f"❌ Inference failed: {e}"
134
  gr.Error(error_message)
135
- print(error_message)
136
  import traceback
137
  traceback.print_exc()
138
 
 
139
  yield (
140
  gr.update(interactive=True, value="Restore Media"),
141
  None, None, None,
142
- gr.update(value=f"{''.join(log_buffer)}\n{error_message}", visible=True)
143
  )
144
 
 
145
 
146
- # --- GRADIO UI LAYOUT ---
147
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Restoration") as demo:
148
- # Header
149
  gr.Markdown(
150
  """
151
  <div style='text-align: center; margin-bottom: 20px;'>
@@ -154,34 +165,50 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Res
154
  </div>
155
  """
156
  )
 
157
  with gr.Row():
 
158
  with gr.Column(scale=1):
159
  gr.Markdown("### 1. Upload Media")
160
- input_media = gr.File(label="Input File (Video or Image)", type="filepath")
 
 
161
  gr.Markdown("### 2. Configure Settings")
162
  with gr.Accordion("Generation Parameters", open=True):
163
  resolution_select = gr.Dropdown(
164
- label="Resolution (Short Edge)",
165
  choices=["480", "560", "720", "960", "1024"],
166
  value="480",
167
- info="The output height and width will be set to this value."
168
  )
 
169
  sp_size_slider = gr.Slider(
170
- label="Sequence Parallelism (sp_size)",
171
  minimum=1, maximum=16, step=1, value=4,
172
- info="For multi-GPU videos. This will be set to 1 for images."
173
  )
 
174
  fps_out = gr.Number(label="Output FPS (for Videos)", value=24, precision=0, info="Set to 0 to use the original FPS.")
 
175
  run_button = gr.Button("Restore Media", variant="primary", icon="✨")
 
 
176
  with gr.Column(scale=2):
177
  gr.Markdown("### 3. Results")
 
 
178
  log_window = gr.Textbox(
179
- label="Inference Log 📝", lines=8, max_lines=15,
180
- interactive=False, visible=False, autoscroll=True,
 
181
  )
 
 
182
  output_image = gr.Image(label="Image Result", show_download_button=True, type="filepath", visible=False)
183
  output_video = gr.Video(label="Video Result", visible=False)
184
  output_download = gr.File(label="Download Full Result (Video)", visible=False)
 
 
185
  gr.Markdown(
186
  """
187
  ---
@@ -190,8 +217,16 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Res
190
  """
191
  )
192
 
193
- input_media.upload(fn=on_file_upload, inputs=[input_media], outputs=[sp_size_slider])
 
 
 
 
 
 
 
194
 
 
195
  run_button.click(
196
  fn=run_inference_ui,
197
  inputs=[input_media, resolution_select, sp_size_slider, fps_out],
 
7
  import gradio as gr
8
  import cv2
9
 
10
+ # --- INTEGRAÇÃO COM A LÓGICA DO SERVIDOR ---
11
  try:
12
+ # Importa a classe SeedVRServer que agora atua como nossa biblioteca de inferência.
13
  from api.seedvr_server import SeedVRServer
14
  except ImportError as e:
15
+ print(f"ERRO FATAL: Não foi possível importar o SeedVRServer. Detalhes: {e}")
16
+ # A aplicação não pode rodar sem a lógica do servidor.
17
  raise
18
 
19
+ # --- INICIALIZAÇÃO ---
20
+ # Cria uma instância única e persistente do servidor.
21
+ # A inicialização (clonar repo, baixar modelos) acontece apenas uma vez, no início.
22
  server = SeedVRServer()
23
 
24
+ # --- FUNÇÕES AUXILIARES ---
25
 
26
  def _is_video(path: str) -> bool:
27
+ """Verifica se um caminho de arquivo corresponde a um tipo de vídeo."""
28
  if not path: return False
29
  import mimetypes
30
  mime, _ = mimetypes.guess_type(path)
31
  return (mime or "").startswith("video")
32
 
33
  def _extract_first_frame(video_path: str) -> Optional[str]:
34
+ """Extrai o primeiro frame de um vídeo e o salva como uma imagem JPG."""
35
  if not video_path or not os.path.exists(video_path): return None
36
  try:
37
  vid_cap = cv2.VideoCapture(video_path)
 
39
  success, image = vid_cap.read()
40
  vid_cap.release()
41
  if not success: return None
42
+
43
+ # Salva o frame no mesmo diretório do vídeo, com extensão .jpg
44
  image_path = Path(video_path).with_suffix(".jpg")
45
  cv2.imwrite(str(image_path), image)
46
  return str(image_path)
47
  except Exception as e:
48
+ print(f"Erro ao extrair o primeiro frame: {e}")
49
  return None
50
 
51
  def on_file_upload(file_obj):
52
+ """
53
+ Callback acionado quando o usuário faz o upload de um arquivo.
54
+ Verifica se o arquivo é um vídeo e sugere um `sp_size` apropriado.
55
+ """
56
  if file_obj is None:
57
+ # Limpa os resultados e o log se o arquivo for removido
58
+ return 1, None, None, None, None
59
+
60
  if _is_video(file_obj.name):
61
+ # Para vídeos, sugere um valor padrão para multi-GPU e torna o slider interativo
62
+ return gr.update(value=4, interactive=True), None, None, None, None
63
  else:
64
+ # Para imagens, trava o valor em 1
65
+ return gr.update(value=1, interactive=False), None, None, None, None
66
 
67
+ # --- FUNÇÃO PRINCIPAL DE INFERÊNCIA DA UI ---
68
 
69
  def run_inference_ui(
70
  input_file_path: Optional[str],
 
74
  progress=gr.Progress(track_tqdm=True)
75
  ):
76
  """
77
+ A função de callback principal do Gradio. Usa geradores (`yield`)
78
+ para permitir atualizações da UI em tempo real durante a tarefa de longa duração.
79
  """
80
+ # 1. Estado Inicial e Validação
81
+ # No início, desabilita o botão, limpa resultados anteriores e mostra a janela de log.
82
  yield (
83
  gr.update(interactive=False, value="Processing... 🚀"),
84
  gr.update(value=None, visible=False),
85
  gr.update(value=None, visible=False),
86
  gr.update(value=None, visible=False),
87
+ gr.update(value=" Starting inference process...\n", visible=True)
88
  )
89
 
90
  if not input_file_path:
91
  gr.Warning("Please upload a media file first.")
92
+ # Reabilita o botão e esconde os componentes de saída
93
+ yield (gr.update(interactive=True, value="Restore Media"), None, None, None, gr.update(visible=False))
 
 
94
  return
95
 
96
  log_buffer = ["▶ Starting inference process...\n"]
 
 
 
 
 
 
 
 
 
97
  was_input_video = _is_video(input_file_path)
98
 
99
  try:
100
+ # Define um callback que será chamado pelo backend para atualizar o progresso e o log
101
+ def progress_callback_wrapper(step: float, desc: str):
102
+ """ Wrapper para formatar logs e atualizar o progresso. """
103
+ # Adiciona a nova mensagem de log ao buffer
104
+ log_buffer.append(f"⏳ [{int(step*100)}%] {desc}\n")
105
+ # Atualiza o objeto de progresso do Gradio
106
+ progress(step, desc=desc)
107
+
108
+ # 2. Executa a Inferência
109
+ # Chama o método direto do servidor, passando o nosso callback.
110
  video_result_path = server.run_inference_direct(
111
  file_path=input_file_path,
112
+ seed=42, # Semente fixa conforme solicitado
113
  res_h=int(resolution),
114
+ res_w=int(resolution), # Largura igual à altura
115
  sp_size=int(sp_size),
116
  fps=float(fps) if fps and fps > 0 else None,
117
+ progress=progress_callback_wrapper, # Passa nossa função de callback
118
  )
119
 
120
+ progress(1.0, desc="Complete!")
121
+ log_buffer.append("✅ Inference complete! Processing final output...\n")
122
 
123
+ # 3. Processa e Exibe os Resultados
124
  final_image, final_video = None, None
125
  if was_input_video:
126
  final_video = video_result_path
127
+ log_buffer.append("✅ Video result is ready.\n")
128
+ else: # Se a entrada foi uma imagem
129
  final_image = _extract_first_frame(video_result_path)
130
+ final_video = video_result_path # Também disponibiliza o vídeo de 1 frame
131
+ log_buffer.append("✅ Image result extracted from video.\n")
132
 
133
+ # Yield final para mostrar os resultados e reabilitar o botão
134
  yield (
135
  gr.update(interactive=True, value="Restore Media"),
136
  gr.update(value=final_image, visible=final_image is not None),
 
142
  except Exception as e:
143
  error_message = f"❌ Inference failed: {e}"
144
  gr.Error(error_message)
145
+ log_buffer.append(f"\n{error_message}")
146
  import traceback
147
  traceback.print_exc()
148
 
149
+ # Yield para estado de erro: reabilita o botão e mostra o log com o erro
150
  yield (
151
  gr.update(interactive=True, value="Restore Media"),
152
  None, None, None,
153
+ gr.update(value=''.join(log_buffer), visible=True)
154
  )
155
 
156
+ # --- LAYOUT DA INTERFACE GRÁFICA (GRADIO) ---
157
 
 
158
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Restoration") as demo:
159
+ # Cabeçalho
160
  gr.Markdown(
161
  """
162
  <div style='text-align: center; margin-bottom: 20px;'>
 
165
  </div>
166
  """
167
  )
168
+
169
  with gr.Row():
170
+ # --- Coluna da Esquerda: Entradas e Controles ---
171
  with gr.Column(scale=1):
172
  gr.Markdown("### 1. Upload Media")
173
+ # Componente de upload agora mostra apenas o link, não a pré-visualização.
174
+ input_media = gr.File(label="Input File (Video or Image)", type="filepath", interactive=True)
175
+
176
  gr.Markdown("### 2. Configure Settings")
177
  with gr.Accordion("Generation Parameters", open=True):
178
  resolution_select = gr.Dropdown(
179
+ label="Resolution",
180
  choices=["480", "560", "720", "960", "1024"],
181
  value="480",
182
+ info="Sets the output height and width to this value."
183
  )
184
+
185
  sp_size_slider = gr.Slider(
186
+ label="Frames per Batch (sp_size)",
187
  minimum=1, maximum=16, step=1, value=4,
188
+ info="For multi-GPU videos. Automatically set to 1 for images."
189
  )
190
+
191
  fps_out = gr.Number(label="Output FPS (for Videos)", value=24, precision=0, info="Set to 0 to use the original FPS.")
192
+
193
  run_button = gr.Button("Restore Media", variant="primary", icon="✨")
194
+
195
+ # --- Coluna da Direita: Resultados ---
196
  with gr.Column(scale=2):
197
  gr.Markdown("### 3. Results")
198
+
199
+ # Janela de Log
200
  log_window = gr.Textbox(
201
+ label="Inference Log 📝",
202
+ lines=8, max_lines=15,
203
+ interactive=False, visible=False, autoscroll=True
204
  )
205
+
206
+ # Componentes de saída (começam invisíveis)
207
  output_image = gr.Image(label="Image Result", show_download_button=True, type="filepath", visible=False)
208
  output_video = gr.Video(label="Video Result", visible=False)
209
  output_download = gr.File(label="Download Full Result (Video)", visible=False)
210
+
211
+ # --- Rodapé ---
212
  gr.Markdown(
213
  """
214
  ---
 
217
  """
218
  )
219
 
220
+ # --- Lógica de Eventos da UI ---
221
+
222
+ # Ao fazer upload de um arquivo, ajusta o slider `sp_size` e limpa saídas antigas.
223
+ input_media.upload(
224
+ fn=on_file_upload,
225
+ inputs=[input_media],
226
+ outputs=[sp_size_slider, output_image, output_video, output_download, log_window]
227
+ )
228
 
229
+ # Ao clicar no botão, executa a função de inferência principal.
230
  run_button.click(
231
  fn=run_inference_ui,
232
  inputs=[input_media, resolution_select, sp_size_slider, fps_out],