Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import tempfile | |
| from PIL import Image | |
| import spaces | |
| from tqdm.auto import tqdm | |
| from diffusers import DDIMScheduler, AutoencoderKL | |
| from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection | |
| from GeoWizard.geowizard.models.unet_2d_condition import UNet2DConditionModel | |
| from GeoWizard.geowizard.models.geowizard_pipeline import DepthNormalEstimationPipeline | |
# Device setup: use CUDA whenever PyTorch reports an available GPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
CHECKPOINT_PATH = "GonzaloMG/geowizard-e2e-ft"

# Load pretrained components, each from its own subfolder of the checkpoint repo.
vae = AutoencoderKL.from_pretrained(CHECKPOINT_PATH, subfolder='vae')
scheduler = DDIMScheduler.from_pretrained(
    CHECKPOINT_PATH,
    subfolder='scheduler',
    timestep_spacing="trailing",
)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    CHECKPOINT_PATH,
    subfolder="image_encoder",
)
feature_extractor = CLIPImageProcessor.from_pretrained(
    CHECKPOINT_PATH,
    subfolder="feature_extractor",
)
unet = UNet2DConditionModel.from_pretrained(CHECKPOINT_PATH, subfolder="unet")

# Instantiate the pipeline from the loaded parts, move it to DEVICE, and put
# the UNet in eval mode.
pipe = DepthNormalEstimationPipeline(
    vae=vae,
    image_encoder=image_encoder,
    feature_extractor=feature_extractor,
    unet=unet,
    scheduler=scheduler,
)
pipe = pipe.to(DEVICE)
pipe.unet.eval()
# UI texts shown as Markdown at the top of the page.
title = "# End-to-End Fine-Tuned GeoWizard Video"
description = (
    "\n"
    "Please refer to our [paper](https://arxiv.org/abs/2409.11355) and\n"
    "[GitHub](https://vision.rwth-aachen.de/diffusion-e2e-ft) for more details.\n"
)
@spaces.GPU
def predict(image: Image.Image, processing_res_choice: int):
    """
    Run single-step normal estimation on one frame.

    Decorated with ``spaces.GPU`` so that on Hugging Face ZeroGPU hardware a
    GPU is attached for the duration of each call. The decorator has no
    effect in non-ZeroGPU environments.

    Args:
        image: The frame to process (the video handler passes an RGB image).
        processing_res_choice: Value forwarded as ``processing_res``. The UI
            offers 768 ("Recommended") or 0 ("Native (original)").

    Returns:
        The pipeline output (a DepthNormalPipelineOutput); callers read its
        ``normal_colored`` attribute.
    """
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        return pipe(
            image,
            denoising_steps=1,  # the end-to-end fine-tuned model is run in a single step
            ensemble_size=1,
            noise="zeros",
            processing_res=processing_res_choice,
            match_input_res=True,
        )
def on_submit_video(video_path: str, processing_res_choice: int):
    """
    Process each frame of the input video, generating a normal map video.

    Every decoded frame is converted to an RGB PIL image and passed through
    ``predict``. Its ``normal_colored`` output is appended to a temporary MP4
    file (``mp4v`` fourcc) written at the input's frame rate, or at 30 fps
    when the container reports none.

    Args:
        video_path: Path of the uploaded video, or None if nothing was uploaded.
        processing_res_choice: Forwarded to ``predict`` as the processing
            resolution. The UI offers 768 or 0 for native.

    Returns:
        Path of the generated normals video, or None when no video was
        uploaded, the file could not be opened, or no frame could be decoded.
    """
    if video_path is None:
        print("No video uploaded.")
        return None

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Could not open video: {video_path}")
        cap.release()
        return None

    fps = cap.get(cv2.CAP_PROP_FPS) or 30
    # The frame-count property is container metadata and may be 0 or only an
    # estimate. It therefore only sizes the progress bar; the loop below
    # reads until the decoder runs out of frames.
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')

    out_normal = None  # created lazily from the first decoded frame
    output_path = None
    frame_size = None  # (width, height) the writer was opened with
    try:
        with tqdm(
            total=frame_count if frame_count > 0 else None,
            desc="Processing frames",
        ) as progress:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                if out_normal is None:
                    # Temporary output file for the normals video. Only the
                    # path is needed, so close the handle immediately rather
                    # than leaking it (delete=False keeps the file).
                    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_normal:
                        output_path = tmp_normal.name
                    # Size the writer from the decoded frame instead of the
                    # width/height properties, which may not match the frames
                    # actually returned.
                    frame_size = (frame.shape[1], frame.shape[0])
                    out_normal = cv2.VideoWriter(output_path, fourcc, fps, frame_size)

                # Convert frame to PIL image and predict normals
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                result = predict(Image.fromarray(rgb), processing_res_choice)

                # Write normal frame. OpenCV's VideoWriter may silently skip
                # frames whose size differs from the size it was opened with,
                # so enforce that size.
                normal_frame = np.array(result.normal_colored)
                normal_bgr = cv2.cvtColor(normal_frame, cv2.COLOR_RGB2BGR)
                if (normal_bgr.shape[1], normal_bgr.shape[0]) != frame_size:
                    normal_bgr = cv2.resize(normal_bgr, frame_size)
                out_normal.write(normal_bgr)
                progress.update(1)
    finally:
        # Release resources even if prediction fails part-way.
        cap.release()
        if out_normal is not None:
            out_normal.release()

    if out_normal is None:
        print("No frames could be read from the video.")
        return None
    # Return video path for download
    return output_path
# Build Gradio interface.
# The page layout is expressed entirely by the order of the statements and
# the nesting of the `with` contexts below.
with gr.Blocks() as demo:
    # Header: title, paper/GitHub links, and section heading.
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown("### Normals Prediction on Video")
    # Input row: the uploaded video next to a column of controls.
    with gr.Row():
        input_video = gr.Video(label="Input Video", elem_id='video-display-input')
        with gr.Column():
            # Choices are (label, value) pairs. The selected value reaches
            # on_submit_video as processing_res_choice; 0 is labelled
            # "Native (original)".
            processing_res_choice = gr.Radio(
                [
                    ("Recommended (768)", 768),
                    ("Native (original)", 0),
                ],
                label="Processing resolution",
                value=768,
            )
            submit = gr.Button(value="Compute Normals")
    # Output row: the generated normal-map video.
    with gr.Row():
        output_normal_video = gr.Video(label="Normal Video", elem_id='download')
    # Clicking the button calls on_submit_video(input video, resolution) and
    # shows its return value in the output video component.
    submit.click(
        fn=on_submit_video,
        inputs=[input_video, processing_res_choice],
        outputs=[output_normal_video]
    )
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server with share=True.
    # NOTE(review): share links are presumably ignored when hosted on HF Spaces; confirm.
    queued_demo = demo.queue()
    queued_demo.launch(share=True)