EuuIia commited on
Commit
66bcb74
·
verified ·
1 Parent(s): 7dd33fe

Update app_seedvr.py

Browse files
Files changed (1) hide show
  1. app_seedvr.py +175 -68
app_seedvr.py CHANGED
@@ -1,31 +1,38 @@
1
  # app_seedvr.py
2
 
3
  import os
 
4
  from pathlib import Path
5
  from typing import Optional
6
  import gradio as gr
7
  import cv2
8
 
 
 
9
  try:
10
- # Importa a classe de servidor que agora é uma biblioteca local
11
  from api.seedvr_server import SeedVRServer
12
  except ImportError as e:
13
- print(f"ERRO FATAL: Não foi possível importar o SeedVRServer. Detalhes: {e}")
14
- # Se a importação falhar, a aplicação não pode continuar.
15
  raise
16
 
17
- # Cria uma instância única do servidor. A inicialização (clonar repo, baixar modelos) acontece aqui.
 
 
18
  server = SeedVRServer()
19
 
 
 
20
  def _is_video(path: str) -> bool:
21
- """Verifica se um caminho de arquivo corresponde a um tipo de vídeo."""
22
  if not path: return False
23
  import mimetypes
24
  mime, _ = mimetypes.guess_type(path)
25
  return (mime or "").startswith("video")
26
 
27
  def _extract_first_frame(video_path: str) -> Optional[str]:
28
- """Extrai o primeiro frame de um vídeo e o salva como uma imagem JPG."""
29
  if not video_path or not os.path.exists(video_path): return None
30
  try:
31
  vid_cap = cv2.VideoCapture(video_path)
@@ -34,106 +41,206 @@ def _extract_first_frame(video_path: str) -> Optional[str]:
34
  vid_cap.release()
35
  if not success: return None
36
 
37
- # Salva o frame no mesmo diretório do vídeo, com extensão .jpg
38
  image_path = Path(video_path).with_suffix(".jpg")
39
  cv2.imwrite(str(image_path), image)
40
  return str(image_path)
41
  except Exception as e:
42
- print(f"Erro ao extrair o primeiro frame: {e}")
43
  return None
44
 
45
- def ui_infer(
46
- input_path: Optional[str],
47
- seed: int, res_h: int, res_w: int,
48
- sp_size: int, fps: float,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  progress=gr.Progress(track_tqdm=True)
50
  ):
51
  """
52
- Função de callback principal do Gradio. Agora chama a lógica de inferência diretamente.
 
53
  """
54
- if not input_path:
55
- gr.Warning("Por favor, faça o upload de um arquivo.")
56
- return None, None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- was_input_video = _is_video(input_path)
59
-
60
  try:
61
- # Desabilita o botão enquanto processa
62
- yield gr.update(interactive=False, value="Processando..."), None, None, None
 
63
 
64
- # Chama o método direto do servidor, passando o objeto de progresso do Gradio
65
  video_result_path = server.run_inference_direct(
66
- file_path=input_path,
67
- seed=int(seed),
68
- res_h=int(res_h),
69
- res_w=int(res_w),
70
  sp_size=int(sp_size),
71
  fps=float(fps) if fps and fps > 0 else None,
72
- progress=progress,
73
  )
74
 
75
- progress(1.0, desc="Concluído!")
76
-
 
 
77
  final_image, final_video = None, None
78
  if was_input_video:
79
  final_video = video_result_path
80
- else: # Se a entrada foi uma imagem
 
81
  final_image = _extract_first_frame(video_result_path)
82
- final_video = video_result_path
83
-
84
- # Retorna o resultado e reabilita o botão
 
85
  yield (
86
- gr.update(interactive=True, value="Restaurar Mídia"),
87
  gr.update(value=final_image, visible=final_image is not None),
88
  gr.update(value=final_video, visible=final_video is not None),
89
- gr.update(value=video_result_path, visible=video_result_path is not None)
 
90
  )
91
 
92
  except Exception as e:
93
- error_message = f"A inferência falhou: {e}"
94
  gr.Error(error_message)
95
  print(error_message)
96
  import traceback
97
  traceback.print_exc()
98
- # Limpa os resultados e reabilita o botão em caso de erro
99
- yield gr.update(interactive=True, value="Restaurar Mídia"), None, None, None
100
-
101
- # --- Construção da Interface Gráfica ---
102
- with gr.Blocks(title="SeedVR (Aduc-SDR)", theme=gr.themes.Soft()) as demo:
103
- gr.HTML("""
104
- <div style='text-align:center; margin-bottom: 20px;'>
105
- <h1>SeedVR - Restauração de Imagem e Vídeo</h1>
106
- <p>Implementação com backend Aduc-SDR</p>
 
 
 
 
 
 
 
 
 
107
  </div>
108
- """)
109
-
 
110
  with gr.Row():
 
111
  with gr.Column(scale=1):
112
- inp = gr.File(label="Arquivo de Entrada (Vídeo .mp4 ou Imagem)", type="filepath")
 
113
 
114
- with gr.Accordion("Parâmetros de Geração", open=True):
115
- with gr.Row():
116
- seed = gr.Number(label="Seed", value=42, precision=0)
117
- fps_out = gr.Number(label="FPS de Saída (para Vídeos)", value=24, precision=0, info="0 para usar o FPS original.")
118
- with gr.Row():
119
- res_h = gr.Number(label="Altura (Height)", value=720, precision=0)
120
- res_w = gr.Number(label="Largura (Width)", value=1280, precision=0)
121
-
122
- sp_size = gr.Slider(label="Paralelismo de Sequência (sp_size)", minimum=1, maximum=160, step=4, value=4, info="Para vídeos em multi-GPU. Use 1 para imagens.")
123
-
124
- run_button = gr.Button("Restaurar Mídia", variant="primary")
125
-
 
 
 
 
 
 
 
 
126
  with gr.Column(scale=2):
127
- gr.Markdown("### Resultado")
128
- out_image = gr.Image(label="Resultado (Imagem)", show_download_button=True, type="filepath", visible=True)
129
- out_video = gr.Video(label="Resultado (Vídeo)")
130
- out_download = gr.File(label="Baixar Resultado (Vídeo)")
131
-
132
- # A função click agora é um gerador.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  run_button.click(
134
- fn=ui_infer,
135
- inputs=[inp, seed, res_h, res_w, sp_size, fps_out],
136
- outputs=[run_button, out_image, out_video, out_download],
137
  )
138
 
139
  if __name__ == "__main__":
 
1
  # app_seedvr.py
2
 
3
  import os
4
+ import sys
5
  from pathlib import Path
6
  from typing import Optional
7
  import gradio as gr
8
  import cv2
9
 
10
+ # --- SERVER LOGIC INTEGRATION ---
11
+ # This section ensures we can import and use the SeedVR engine directly.
12
  try:
13
+ # We need the SeedVRServer class which handles the inference logic.
14
  from api.seedvr_server import SeedVRServer
15
  except ImportError as e:
16
+ print(f"FATAL ERROR: Could not import SeedVRServer. Details: {e}")
17
+ # The application cannot run without the server logic.
18
  raise
19
 
20
+ # --- INITIALIZATION ---
21
+ # Create a single, persistent instance of the server.
22
+ # This clones the repo and downloads models only once at startup.
23
  server = SeedVRServer()
24
 
25
+ # --- HELPER FUNCTIONS ---
26
+
27
  def _is_video(path: str) -> bool:
28
+ """Checks if a file path corresponds to a video type."""
29
  if not path: return False
30
  import mimetypes
31
  mime, _ = mimetypes.guess_type(path)
32
  return (mime or "").startswith("video")
33
 
34
  def _extract_first_frame(video_path: str) -> Optional[str]:
35
+ """Extracts the first frame from a video and saves it as a JPG image."""
36
  if not video_path or not os.path.exists(video_path): return None
37
  try:
38
  vid_cap = cv2.VideoCapture(video_path)
 
41
  vid_cap.release()
42
  if not success: return None
43
 
 
44
  image_path = Path(video_path).with_suffix(".jpg")
45
  cv2.imwrite(str(image_path), image)
46
  return str(image_path)
47
  except Exception as e:
48
+ print(f"Error extracting first frame: {e}")
49
  return None
50
 
51
+ def on_file_upload(file_obj):
52
+ """
53
+ Callback triggered when a user uploads a file.
54
+ It checks if the file is a video and suggests an appropriate `sp_size`.
55
+ """
56
+ if file_obj is None:
57
+ return 1 # Default to 1 if file is cleared
58
+
59
+ if _is_video(file_obj.name):
60
+ # For videos, suggest a default value suitable for multi-GPU
61
+ return gr.update(value=4, interactive=True)
62
+ else:
63
+ # For images, lock the value to 1
64
+ return gr.update(value=1, interactive=False)
65
+
66
+ # --- CORE INFERENCE FUNCTION ---
67
+
68
+ def run_inference_ui(
69
+ input_file_path: Optional[str],
70
+ resolution: str,
71
+ sp_size: int,
72
+ fps: float,
73
  progress=gr.Progress(track_tqdm=True)
74
  ):
75
  """
76
+ The main callback function for Gradio. This is a generator (`yield`)
77
+ to allow for real-time UI updates during the long-running task.
78
  """
79
+ # 1. Initial State & Validation
80
+ # On start, disable the button, clear previous results, and make the log visible.
81
+ yield (
82
+ gr.update(interactive=False, value="Processing... 🚀"),
83
+ gr.update(value=None, visible=False),
84
+ gr.update(value=None, visible=False),
85
+ gr.update(value=None, visible=False),
86
+ gr.update(value="Waiting for logs...", visible=True)
87
+ )
88
+
89
+ if not input_file_path:
90
+ gr.Warning("Please upload a media file first.")
91
+ # Re-enable button and hide outputs
92
+ yield (
93
+ gr.update(interactive=True, value="Restore Media"),
94
+ None, None, None, gr.update(visible=False)
95
+ )
96
+ return
97
+
98
+ # Use a simple list to act as a log buffer that can be updated by a callback
99
+ log_buffer = ["▶ Starting inference process...\n"]
100
+ yield gr.update(), None, None, None, ''.join(log_buffer)
101
+
102
+ def progress_callback(step: float, desc: str):
103
+ """A simple callback to append messages to our log buffer."""
104
+ # This function can be passed to the backend if it supports it.
105
+ # For now, we'll call it manually from this UI function.
106
+ log_buffer.append(f"⏳ [{int(step*100)}%] {desc}\n")
107
+ progress.update(amount=step, desc=desc)
108
+
109
+ was_input_video = _is_video(input_file_path)
110
 
 
 
111
  try:
112
+ # 2. Execute Inference
113
+ progress_callback(0.1, "Calling backend engine...")
114
+ yield gr.update(), None, None, None, ''.join(log_buffer)
115
 
116
+ # Call the server's direct inference method. This is a blocking call.
117
  video_result_path = server.run_inference_direct(
118
+ file_path=input_file_path,
119
+ seed=42, # Using a fixed seed as requested
120
+ res_h=int(resolution),
121
+ res_w=int(resolution), # Set width equal to height
122
  sp_size=int(sp_size),
123
  fps=float(fps) if fps and fps > 0 else None,
124
+ progress=progress, # Pass the Gradio progress object
125
  )
126
 
127
+ progress_callback(1.0, "Inference complete! Processing final output...")
128
+ yield gr.update(), None, None, None, ''.join(log_buffer)
129
+
130
+ # 3. Process and Display Results
131
  final_image, final_video = None, None
132
  if was_input_video:
133
  final_video = video_result_path
134
+ log_buffer.append(f"✅ Video result is ready.\n")
135
+ else: # If input was an image
136
  final_image = _extract_first_frame(video_result_path)
137
+ final_video = video_result_path # Also provide the 1-frame video
138
+ log_buffer.append(f"✅ Image result extracted from video.\n")
139
+
140
+ # Final yield to show the results and re-enable the button
141
  yield (
142
+ gr.update(interactive=True, value="Restore Media"),
143
  gr.update(value=final_image, visible=final_image is not None),
144
  gr.update(value=final_video, visible=final_video is not None),
145
+ gr.update(value=video_result_path, visible=video_result_path is not None),
146
+ ''.join(log_buffer)
147
  )
148
 
149
  except Exception as e:
150
+ error_message = f" Inference failed: {e}"
151
  gr.Error(error_message)
152
  print(error_message)
153
  import traceback
154
  traceback.print_exc()
155
+
156
+ # Yield an error state and re-enable the button
157
+ yield (
158
+ gr.update(interactive=True, value="Restore Media"),
159
+ None, None, None,
160
+ gr.update(value=f"{''.join(log_buffer)}\n{error_message}", visible=True)
161
+ )
162
+
163
+
164
+ # --- GRADIO UI LAYOUT ---
165
+
166
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Restoration") as demo:
167
+ # Header
168
+ gr.Markdown(
169
+ """
170
+ <div style='text-align: center; margin-bottom: 20px;'>
171
+ <h1>📸 SeedVR - Image & Video Restoration 🚀</h1>
172
+ <p>High-quality media upscaling powered by SeedVR-3B. Upload your file and see the magic.</p>
173
  </div>
174
+ """
175
+ )
176
+
177
  with gr.Row():
178
+ # --- Left Column: Inputs & Controls ---
179
  with gr.Column(scale=1):
180
+ gr.Markdown("### 1. Upload Media")
181
+ input_media = gr.File(label="Input File (Video or Image)", type="filepath")
182
 
183
+ gr.Markdown("### 2. Configure Settings")
184
+ with gr.Accordion("Generation Parameters", open=True):
185
+ resolution_select = gr.Dropdown(
186
+ label="Resolution (Short Edge)",
187
+ choices=["480", "560", "720", "960", "1024"],
188
+ value="480",
189
+ info="The output height and width will be set to this value."
190
+ )
191
+
192
+ sp_size_slider = gr.Slider(
193
+ label="Sequence Parallelism (sp_size)",
194
+ minimum=1, maximum=16, step=1, value=4,
195
+ info="For multi-GPU videos. This will be set to 1 for images."
196
+ )
197
+
198
+ fps_out = gr.Number(label="Output FPS (for Videos)", value=24, precision=0, info="Set to 0 to use the original FPS.")
199
+
200
+ run_button = gr.Button("Restore Media", variant="primary", icon="✨")
201
+
202
+ # --- Right Column: Outputs ---
203
  with gr.Column(scale=2):
204
+ gr.Markdown("### 3. Results")
205
+
206
+ # Log window
207
+ log_window = gr.Textbox(
208
+ label="Inference Log 📝",
209
+ lines=8,
210
+ max_lines=15,
211
+ interactive=False,
212
+ visible=False, # Starts hidden
213
+ autoscroll=True,
214
+ )
215
+
216
+ # Output components start hidden and are made visible upon completion
217
+ output_image = gr.Image(label="Image Result", show_download_button=True, type="filepath", visible=False)
218
+ output_video = gr.Video(label="Video Result", visible=False)
219
+ output_download = gr.File(label="Download Full Result (Video)", visible=False)
220
+
221
+ # --- Footer ---
222
+ gr.Markdown(
223
+ """
224
+ ---
225
+ *Space and Docker were developed by Carlex.*
226
+ *Contact: Email: Carlex22@gmail.com | GitHub: [carlex22](https://github.com/carlex22)*
227
+ """
228
+ )
229
+
230
+ # --- Event Handlers ---
231
+
232
+ # When a file is uploaded, automatically adjust the sp_size slider
233
+ input_media.upload(
234
+ fn=on_file_upload,
235
+ inputs=[input_media],
236
+ outputs=[sp_size_slider]
237
+ )
238
+
239
+ # When the "Restore Media" button is clicked, run the main inference function
240
  run_button.click(
241
+ fn=run_inference_ui,
242
+ inputs=[input_media, resolution_select, sp_size_slider, fps_out],
243
+ outputs=[run_button, output_image, output_video, output_download, log_window],
244
  )
245
 
246
  if __name__ == "__main__":