|
|
|
|
|
|
|
|
import os |
|
|
import sys |
|
|
from pathlib import Path |
|
|
from typing import Optional |
|
|
import gradio as gr |
|
|
import cv2 |
|
|
|
|
|
|
|
|
# The UI is useless without the inference backend, so a failed import is
# reported and re-raised instead of letting the app start half-working.
try:
    from api.seedvr_server import SeedVRServer
except ImportError as e:
    # Fatal diagnostics go to stderr so they are not mixed into regular output.
    print(f"FATAL ERROR: Could not import SeedVRServer. Details: {e}", file=sys.stderr)
    raise

# Module-level backend instance used by the inference callback.
server = SeedVRServer()
|
|
|
|
|
|
|
|
|
|
|
def _is_video(path: str) -> bool: |
|
|
"""Checks if a file path corresponds to a video type.""" |
|
|
if not path: return False |
|
|
import mimetypes |
|
|
mime, _ = mimetypes.guess_type(path) |
|
|
return (mime or "").startswith("video") |
|
|
|
|
|
def _extract_first_frame(video_path: str) -> Optional[str]: |
|
|
"""Extracts the first frame from a video and saves it as a JPG image.""" |
|
|
if not video_path or not os.path.exists(video_path): return None |
|
|
try: |
|
|
vid_cap = cv2.VideoCapture(video_path) |
|
|
if not vid_cap.isOpened(): return None |
|
|
success, image = vid_cap.read() |
|
|
vid_cap.release() |
|
|
if not success: return None |
|
|
image_path = Path(video_path).with_suffix(".jpg") |
|
|
cv2.imwrite(str(image_path), image) |
|
|
return str(image_path) |
|
|
except Exception as e: |
|
|
print(f"Error extracting first frame: {e}") |
|
|
return None |
|
|
|
|
|
def on_file_upload(file_obj):
    """Callback triggered when a user uploads a file.

    Args:
        file_obj: The uploaded file as delivered by the file component:
            either a path string or an object exposing the path through a
            ``name`` attribute. None when no file is present.

    Returns:
        The value/update for the sequence-parallelism slider: plain ``1``
        when there is no file, an update setting 4 (editable) for videos,
        and an update setting 1 (locked) for anything else.
    """
    if file_obj is None:
        return 1

    # A component configured with type="filepath" hands over a plain path
    # string, while file wrappers expose the path via `.name`; accept both
    # instead of assuming the attribute exists.
    file_path = str(getattr(file_obj, "name", file_obj))

    if _is_video(file_path):
        return gr.update(value=4, interactive=True)
    return gr.update(value=1, interactive=False)
|
|
|
|
|
|
|
|
|
|
|
def run_inference_ui(
    input_file_path: Optional[str],
    resolution: str,
    sp_size: int,
    fps: float,
    progress=gr.Progress(track_tqdm=True),
):
    """Run restoration on the uploaded file, streaming UI updates via ``yield``.

    Every yielded value is a 5-tuple of component values/updates, in this
    order: run button, image result, video result, download file, log text.

    Args:
        input_file_path: Path of the uploaded media; None or empty when
            nothing was uploaded.
        resolution: Target size as a numeric string; it is used for both
            the output height and the output width.
        sp_size: Sequence-parallelism size forwarded to the backend.
        fps: Output frame rate; a falsy or non-positive value is forwarded
            to the backend as None.
        progress: Gradio progress tracker; it is updated here and also
            forwarded to the backend.

    Yields:
        Tuples of component values/updates as described above. Backend
        failures are caught, logged, and reported in the log text rather
        than propagated.
    """
    # Lock the button and clear previous results while the job runs.
    yield (
        gr.update(interactive=False, value="Processing... 🚀"),
        gr.update(value=None, visible=False),
        gr.update(value=None, visible=False),
        gr.update(value=None, visible=False),
        gr.update(value="Waiting for logs...", visible=True),
    )

    if not input_file_path:
        gr.Warning("Please upload a media file first.")
        yield (
            gr.update(interactive=True, value="Restore Media"),
            None, None, None, gr.update(visible=False),
        )
        return

    log_buffer = ["▶ Starting inference process...\n"]
    yield gr.update(), None, None, None, "".join(log_buffer)

    def progress_callback(step: float, desc: str):
        """Append a progress line to the log buffer and update the tracker."""
        log_buffer.append(f"⏳ [{int(step*100)}%] {desc}\n")
        progress(step, desc=desc)

    was_input_video = _is_video(input_file_path)

    try:
        progress_callback(0.1, "Calling backend engine...")
        yield gr.update(), None, None, None, "".join(log_buffer)

        video_result_path = server.run_inference_direct(
            file_path=input_file_path,
            seed=42,
            res_h=int(resolution),
            res_w=int(resolution),
            sp_size=int(sp_size),
            fps=float(fps) if fps and fps > 0 else None,
            progress=progress,
        )

        progress_callback(1.0, "Inference complete! Processing final output...")
        yield gr.update(), None, None, None, "".join(log_buffer)

        final_image, final_video = None, None
        if was_input_video:
            final_video = video_result_path
            log_buffer.append("✅ Video result is ready.\n")
        else:
            # Image inputs still produce a video result; show its first
            # frame as the image and keep the video available as well.
            final_image = _extract_first_frame(video_result_path)
            final_video = video_result_path
            if final_image is not None:
                log_buffer.append("✅ Image result extracted from video.\n")
            else:
                # Do not claim success when the frame could not be extracted.
                log_buffer.append("⚠️ Could not extract an image from the result video.\n")

        yield (
            gr.update(interactive=True, value="Restore Media"),
            gr.update(value=final_image, visible=final_image is not None),
            gr.update(value=final_video, visible=final_video is not None),
            gr.update(value=video_result_path, visible=video_result_path is not None),
            "".join(log_buffer),
        )

    except Exception as e:
        error_message = f"❌ Inference failed: {e}"
        # gr.Error only takes effect when raised, which would abort this
        # generator before the UI is reset below. gr.Warning shows a toast
        # without interrupting the final yield.
        gr.Warning(error_message)
        print(error_message)
        import traceback
        traceback.print_exc()

        # Re-enable the button and surface the error in the log window.
        yield (
            gr.update(interactive=True, value="Restore Media"),
            None, None, None,
            gr.update(value=f"{''.join(log_buffer)}\n{error_message}", visible=True),
        )
|
|
|
|
|
|
|
|
|
|
|
# Declarative UI layout; components are created in display order.
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="blue"),
    title="SeedVR Media Restoration",
) as demo:
    # Page header.
    gr.Markdown(
        """
        <div style='text-align: center; margin-bottom: 20px;'>
        <h1>📸 SeedVR - Image & Video Restoration 🚀</h1>
        <p>High-quality media upscaling powered by SeedVR-3B. Upload your file and see the magic.</p>
        </div>
        """
    )

    with gr.Row():
        # Left column: upload and generation settings.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Upload Media")
            input_media = gr.File(type="filepath", label="Input File (Video or Image)")

            gr.Markdown("### 2. Configure Settings")
            with gr.Accordion("Generation Parameters", open=True):
                resolution_select = gr.Dropdown(
                    choices=["480", "560", "720", "960", "1024"],
                    value="480",
                    label="Resolution (Short Edge)",
                    info="The output height and width will be set to this value.",
                )
                sp_size_slider = gr.Slider(
                    minimum=1,
                    maximum=16,
                    step=1,
                    value=4,
                    label="Sequence Parallelism (sp_size)",
                    info="For multi-GPU videos. This will be set to 1 for images.",
                )
                fps_out = gr.Number(
                    value=24,
                    precision=0,
                    label="Output FPS (for Videos)",
                    info="Set to 0 to use the original FPS.",
                )

            # NOTE(review): `icon` is documented as a URL/path to an icon
            # file; confirm the emoji renders as intended.
            run_button = gr.Button("Restore Media", variant="primary", icon="✨")

        # Right column: log and result widgets, hidden until a run fills them.
        with gr.Column(scale=2):
            gr.Markdown("### 3. Results")
            log_window = gr.Textbox(
                label="Inference Log 📝",
                lines=8,
                max_lines=15,
                interactive=False,
                visible=False,
                autoscroll=True,
            )
            output_image = gr.Image(
                label="Image Result",
                type="filepath",
                show_download_button=True,
                visible=False,
            )
            output_video = gr.Video(label="Video Result", visible=False)
            output_download = gr.File(label="Download Full Result (Video)", visible=False)

    # Footer credits.
    gr.Markdown(
        """
        ---
        *Space and Docker were developed by Carlex.*
        *Contact: Email: Carlex22@gmail.com | GitHub: [carlex22](https://github.com/carlex22)*
        """
    )

    # Uploading a file adjusts the sequence-parallelism slider.
    input_media.upload(fn=on_file_upload, inputs=[input_media], outputs=[sp_size_slider])

    # The output order must match the tuples yielded by run_inference_ui.
    inference_inputs = [input_media, resolution_select, sp_size_slider, fps_out]
    inference_outputs = [run_button, output_image, output_video, output_download, log_window]
    run_button.click(
        fn=run_inference_ui,
        inputs=inference_inputs,
        outputs=inference_outputs,
    )
|
|
|
|
|
if __name__ == "__main__":
    # Bind address and port may be overridden through the environment.
    bind_host = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
    bind_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
    demo.launch(server_name=bind_host, server_port=bind_port, show_error=True)