Spaces:
Running
Running
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| import os | |
| from PIL import Image | |
| from stitching import Stitcher | |
| import torch | |
| from diffusers import StableDiffusionInpaintPipeline | |
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Cached Stable Diffusion inpainting pipeline. It stays None until the first
# call to load_models(), so importing this module does not load model weights.
inpaint_pipe = None
def load_models():
    """Load the Stable Diffusion inpainting pipeline once and cache it globally.

    The pipeline is stored in the module-level ``inpaint_pipe``; later calls
    return the cached object without reloading. Half precision is used when
    ``device`` is ``"cuda"``, float32 otherwise.

    Returns:
        The cached ``StableDiffusionInpaintPipeline``. Existing callers that
        ignore the return value are unaffected.
    """
    global inpaint_pipe
    if inpaint_pipe is None:
        model_id = "stabilityai/stable-diffusion-2-inpainting"
        # Derive the dtype from the same `device` the pipeline is moved to, so
        # the two decisions cannot disagree.
        dtype = torch.float16 if device == "cuda" else torch.float32
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            model_id,
            torch_dtype=dtype,
        )
        # Publish to the global only after .to(device) succeeds. Otherwise a
        # failed move (e.g. GPU out of memory) would leave a half-initialised
        # fp16 pipeline cached on the CPU, and later calls would never retry.
        inpaint_pipe = pipe.to(device)
    return inpaint_pipe
def stitch_images(image_paths, max_dim=None):
    """Stitch overlapping photos into a single panorama.

    Args:
        image_paths: Filesystem paths of the images to stitch.
        max_dim: Optional cap on the longer side, in pixels. Any image larger
            than this is downscaled (aspect ratio preserved) before stitching,
            which bounds memory use and runtime for very large photos.
            ``None`` (the default) keeps full resolution.

    Returns:
        ``(panorama, None)`` on success, where ``panorama`` is the result of
        ``Stitcher.stitch`` (expected to be a BGR image array, since the inputs
        come from ``cv2.imread``). On failure returns ``(None, message)``.
    """
    try:
        # Built inside the try so a bad detector/setting is reported through
        # the same (None, message) contract as any other stitching failure.
        stitcher = Stitcher(detector="sift", confidence_threshold=0.01)
        images = []
        for path in image_paths:
            img = cv2.imread(path)
            if img is None:
                return None, f"Could not read image at {path}. Ensure it is a valid image file."
            if max_dim is not None:
                height, width = img.shape[:2]
                longest = max(height, width)
                if longest > max_dim:
                    scale = max_dim / longest
                    new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
                    # INTER_AREA is OpenCV's recommended filter for shrinking.
                    img = cv2.resize(img, new_size, interpolation=cv2.INTER_AREA)
            images.append(img)
        panorama = stitcher.stitch(images)
        return panorama, None
    except Exception as e:
        # Top-level boundary for a UI handler: report instead of crashing.
        print(f"Stitching failed: {e}")
        return None, str(e)
def outpaint_panorama(
    panorama_np,
    *,
    target_width=1024,
    target_height=512,
    num_inference_steps=20,
):
    """Fill the empty (black) regions of a stitched panorama with SD inpainting.

    The stitched image is resized to ``target_width`` x ``target_height``
    (2:1, the equirectangular aspect ratio, by default). Every pixel whose
    grayscale value is 0 or 1 is treated as missing and handed to the
    inpainting pipeline. All other pixels are kept from the input.

    NOTE(review): resizing stretches a long, thin stitch to 2:1 rather than
    padding it, so real content is distorted and no extra ceiling/floor band
    is created. Only existing black borders get filled.

    Args:
        panorama_np: Stitched panorama as an OpenCV BGR image array.
        target_width: Output width in pixels (Stable Diffusion needs a
            multiple of 8).
        target_height: Output height in pixels (multiple of 8).
        num_inference_steps: Denoising steps for the diffusion model.

    Returns:
        A ``PIL.Image.Image`` of size ``(target_width, target_height)``.
    """
    panorama_rgb = cv2.cvtColor(panorama_np, cv2.COLOR_BGR2RGB)

    # Near-black pixels are the empty canvas left around the stitched photos.
    gray = cv2.cvtColor(panorama_np, cv2.COLOR_BGR2GRAY)
    _, content_mask = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
    fill_mask = cv2.bitwise_not(content_mask)  # 255 = area to paint

    size = (target_width, target_height)
    pano_img = Image.fromarray(panorama_rgb).resize(size, Image.Resampling.LANCZOS)
    # NEAREST keeps the mask strictly binary. LANCZOS would introduce gray and
    # ringing values along the paint/keep boundary.
    mask_img = Image.fromarray(fill_mask).resize(size, Image.Resampling.NEAREST)

    if mask_img.getbbox() is None:
        # Nothing to fill: avoid loading and running the diffusion model.
        return pano_img

    load_models()

    prompt = "seamless 360 degree panoramic environment, equirectangular projection, highly detailed, photorealistic"
    negative_prompt = "distortion, blur, low quality, unnatural, artifacts, seams"

    generated = inpaint_pipe(
        prompt=prompt,
        image=pano_img,
        mask_image=mask_img,
        negative_prompt=negative_prompt,
        # Without explicit height/width the pipeline falls back to the UNet's
        # default resolution (512x512 for SD2), discarding the 2:1 panorama.
        height=target_height,
        width=target_width,
        num_inference_steps=num_inference_steps,
    ).images[0]

    if generated.size != pano_img.size:
        generated = generated.resize(pano_img.size, Image.Resampling.LANCZOS)
    # Keep the real photo pixels and take generated content only inside the mask.
    return Image.composite(generated, pano_img, mask_img)
def _viewer_html(jpeg_bytes):
    """Return an iframe with a Pannellum 360 viewer showing the given JPEG bytes.

    The image is inlined as a base64 ``data:`` URI inside the iframe's
    ``srcdoc``. Nothing is written to disk, and no public URL is required.
    """
    import base64
    import html
    import json

    data_uri = "data:image/jpeg;base64," + base64.b64encode(jpeg_bytes).decode("ascii")
    config = json.dumps(
        {"type": "equirectangular", "panorama": data_uri, "autoLoad": True}
    )
    page = f"""<!DOCTYPE html>
<html>
<head>
<link rel="stylesheet" href="https://cdn.pannellum.org/2.5/pannellum.css"/>
<script src="https://cdn.pannellum.org/2.5/pannellum.js"></script>
<style>html, body, #panorama {{ margin: 0; width: 100%; height: 100%; }}</style>
</head>
<body>
<div id="panorama"></div>
<script>pannellum.viewer("panorama", {config});</script>
</body>
</html>"""
    # The page is placed in an HTML attribute, so it must be attribute-escaped.
    return (
        '<iframe width="100%" height="400px" allowfullscreen style="border-style:none;" '
        f'srcdoc="{html.escape(page, quote=True)}"></iframe>'
    )


def process_images(image_file_paths):
    """Gradio click handler: stitch uploads, outpaint gaps, and build a 360 viewer.

    Args:
        image_file_paths: Uploaded files from ``gr.File(file_count="multiple")``.
            Each item is either a path string or a file wrapper whose ``.name``
            is the path.

    Returns:
        ``(image, status_message, viewer_html)``, matching the three Gradio
        outputs. On failure ``image`` is ``None`` and ``viewer_html`` is ``""``.
    """
    import io

    if not image_file_paths or len(image_file_paths) < 2:
        return None, "Please upload at least 2 images for stitching.", ""

    # Depending on the Gradio version, uploads arrive as plain path strings or
    # as tempfile-like objects exposing the path via ``.name``.
    paths = [getattr(item, "name", item) for item in image_file_paths]

    # 1. Stitch images
    panorama_np, err = stitch_images(paths)
    if panorama_np is None:
        return None, f"Failed to stitch images: {err}. Please ensure there is enough overlap.", ""

    # 2. Outpaint missing areas. Model download/CUDA errors are reported to the
    # user as a status message instead of surfacing as an unhandled exception.
    try:
        final_image = outpaint_panorama(panorama_np)
    except Exception as exc:
        print(f"Outpainting failed: {exc}")
        return None, f"Failed to generate missing areas: {exc}", ""

    # 3. Interactive viewer. Encode in memory rather than writing a shared
    # temp_pano.jpg, which concurrent requests would overwrite and which the
    # viewer had no working URL for.
    buffer = io.BytesIO()
    final_image.convert("RGB").save(buffer, format="JPEG")
    html_viewer = _viewer_html(buffer.getvalue())

    return final_image, "Panorama successfully created!", html_viewer
# Header text for the app page.
_TITLE_MD = "# 🌐 Image to 360 Environment Converter"
_INTRO_MD = (
    "Upload overlapping photos of an environment, and this tool will stitch them "
    "and use AI to extrapolate the missing areas (like the ceiling and floor) "
    "to create a seamless 360 equirectangular panorama."
)

# Layout: inputs on the left, status and raw image on the right, and the
# interactive viewer spanning a full-width row underneath.
with gr.Blocks(title="Image to 360 Environment") as demo:
    gr.Markdown(_TITLE_MD)
    gr.Markdown(_INTRO_MD)

    with gr.Row():
        with gr.Column():
            image_inputs = gr.File(
                label="Upload Overlapping Images",
                file_count="multiple",
                type="filepath",
            )
            submit_btn = gr.Button("Generate 360 Environment", variant="primary")
        with gr.Column():
            status_output = gr.Textbox(label="Status")
            raw_output = gr.Image(
                type="pil",
                label="Generated Equirectangular Image",
            )

    with gr.Row():
        html_output = gr.HTML(label="Interactive 360 Viewer")

    # Output order must match the (image, status, html) tuple process_images returns.
    submit_btn.click(
        process_images,
        inputs=[image_inputs],
        outputs=[raw_output, status_output, html_output],
    )

if __name__ == "__main__":
    demo.launch()