# 360Env / app.py
# (Hugging Face Space page residue converted to comments so the file parses:
#  "donemilioos's picture" / "new app.py" / commit "4c115ca verified")
import base64
import io
import os

import cv2
import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
from stitching import Stitcher
# Run on GPU when available; diffusion inference on CPU is very slow.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Initialize models lazily or globally depending on the space environment
# Using a lightweight inpainting model
inpaint_pipe = None  # populated on first use by load_models()
def load_models():
    """Lazily create the Stable Diffusion inpainting pipeline (module singleton).

    Idempotent: if the pipeline already exists, this is a no-op. Uses fp16
    weights on CUDA, fp32 otherwise, and moves the pipeline to ``device``.
    """
    global inpaint_pipe
    if inpaint_pipe is not None:
        return
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    pipeline = StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        torch_dtype=dtype,
    )
    inpaint_pipe = pipeline.to(device)
def stitch_images(image_paths):
    """Stitch a list of overlapping image files into a single panorama.

    Returns a ``(panorama, error)`` pair: ``(ndarray, None)`` on success,
    ``(None, message)`` when a file is unreadable or stitching fails.
    """
    stitcher = Stitcher(detector="sift", confidence_threshold=0.01)
    try:
        loaded = []
        for path in image_paths:
            frame = cv2.imread(path)
            if frame is None:
                return None, f"Could not read image at {path}. Ensure it is a valid image file."
            # Optionally resize for performance/stability if images are extremely large
            # height, width = frame.shape[:2]
            # max_dim = 2000
            # if max(height, width) > max_dim:
            #     scale = max_dim / max(height, width)
            #     frame = cv2.resize(frame, (int(width * scale), int(height * scale)))
            loaded.append(frame)
        return stitcher.stitch(loaded), None
    except Exception as e:
        # Stitching raises when there is too little overlap between images;
        # surface the message to the UI instead of crashing.
        print(f"Stitching failed: {e}")
        return None, str(e)
def outpaint_panorama(panorama_np):
    """Fill black stitching gaps using Stable Diffusion inpainting.

    Parameters
    ----------
    panorama_np : numpy.ndarray
        BGR panorama from OpenCV stitching; pure-black pixels are treated
        as holes to be painted.

    Returns
    -------
    PIL.Image.Image
        A 1024x512 RGB image with the masked regions generated by the model.
    """
    load_models()
    # Convert OpenCV BGR to RGB
    panorama_rgb = cv2.cvtColor(panorama_np, cv2.COLOR_BGR2RGB)
    # Create mask: where the image is completely black (or transparent) is where we want to paint
    # Assuming black borders from stitching
    gray = cv2.cvtColor(panorama_np, cv2.COLOR_BGR2GRAY)
    _, mask = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
    mask_inv = cv2.bitwise_not(mask)  # white pixels (255) in mask_inv are the areas to paint
    # Resize to standard aspect ratio for diffusion (e.g., 1024x512)
    # Stitched panoramas might be very long and thin
    target_width = 1024
    target_height = 512
    pano_img = Image.fromarray(panorama_rgb).resize(
        (target_width, target_height), Image.Resampling.LANCZOS
    )
    # BUGFIX: resize the mask with NEAREST, not LANCZOS. LANCZOS interpolation
    # produces intermediate gray values along hole edges, which the inpainting
    # pipeline interprets as partially-masked pixels and yields halo artifacts.
    # NEAREST keeps the mask strictly binary.
    mask_img = Image.fromarray(mask_inv).resize(
        (target_width, target_height), Image.Resampling.NEAREST
    )
    prompt = "seamless 360 degree panoramic environment, equirectangular projection, highly detailed, photorealistic"
    negative_prompt = "distortion, blur, low quality, unnatural, artifacts, seams"
    output = inpaint_pipe(
        prompt=prompt,
        image=pano_img,
        mask_image=mask_img,
        negative_prompt=negative_prompt,
        num_inference_steps=20,
    ).images[0]
    return output
def process_images(image_file_paths):
    """End-to-end pipeline for the Gradio button: stitch, outpaint, build viewer.

    Parameters
    ----------
    image_file_paths : list[str] | None
        Filepaths of the uploaded images (``gr.File`` with ``type="filepath"``).

    Returns
    -------
    tuple
        ``(final PIL image or None, status message, viewer HTML)`` matching
        the three Gradio output components.
    """
    if not image_file_paths or len(image_file_paths) < 2:
        return None, "Please upload at least 2 images for stitching.", ""
    paths = list(image_file_paths)
    # 1. Stitch Images
    panorama_np, err = stitch_images(paths)
    if panorama_np is None:
        return None, f"Failed to stitch images: {err}. Please ensure there is enough overlap.", ""
    # 2. Outpaint missing areas
    final_image = outpaint_panorama(panorama_np)
    # Kept for users who want to download the raw result from the Space.
    final_image.save("temp_pano.jpg")
    # 3. Interactive HTML Viewer (Pannellum)
    # BUGFIX: the previous version pointed Pannellum at a hard-coded
    # "spaces/example" URL that does not exist. Embed the panorama as a
    # base64 data URI instead, so the viewer works wherever the app runs.
    buffer = io.BytesIO()
    final_image.save(buffer, format="JPEG")
    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    data_uri = f"data:image/jpeg;base64,{encoded}"
    html_viewer = f"""
    <iframe width="100%" height="400px" allowfullscreen style="border-style:none;"
    src="https://cdn.pannellum.org/2.5/pannellum.htm#panorama={data_uri}&autoLoad=true">
    </iframe>
    """
    return final_image, "Panorama successfully created!", html_viewer
# Gradio UI: two-column layout (uploads + button on the left, status and the
# generated equirectangular image on the right), viewer iframe underneath.
with gr.Blocks(title="Image to 360 Environment") as demo:
    gr.Markdown("# 🌐 Image to 360 Environment Converter")
    gr.Markdown("Upload overlapping photos of an environment, and this tool will stitch them and use AI to extrapolate the missing areas (like the ceiling and floor) to create a seamless 360 equirectangular panorama.")
    with gr.Row():
        with gr.Column():
            # type="filepath" means process_images receives a list of path strings.
            image_inputs = gr.File(file_count="multiple", type="filepath", label="Upload Overlapping Images")
            submit_btn = gr.Button("Generate 360 Environment", variant="primary")
        with gr.Column():
            status_output = gr.Textbox(label="Status")
            raw_output = gr.Image(label="Generated Equirectangular Image", type="pil")
    with gr.Row():
        html_output = gr.HTML(label="Interactive 360 Viewer")
    # Output order must match process_images' 3-tuple: (image, status, html).
    submit_btn.click(
        fn=process_images,
        inputs=[image_inputs],
        outputs=[raw_output, status_output, html_output]
    )
if __name__ == "__main__":
    demo.launch()