# HuggingFace Spaces demo ("Running on Zero" — ZeroGPU hardware).
| import cv2 | |
| import flow_vis | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| import sys | |
| import subprocess | |
| import sys | |
| import subprocess | |
| import importlib | |
| from functools import lru_cache | |
| import spaces | |
# Bootstrap: the uniception/uniflowmatch packages are built on first run
# (ZeroGPU Spaces have no persistent build step), then imported normally.
try:
    import uniception
    import uniflowmatch
except ImportError:
    # Build and install the packages from the bundled script.
    subprocess.check_call(["/bin/bash", "./install_package.sh"])

    # Refresh the import machinery so the freshly installed packages are
    # visible without restarting the interpreter.
    import site

    importlib.invalidate_caches()
    site.main()  # re-scan site-packages

    # Retry the imports now that installation finished.
    uniception = importlib.import_module("uniception")
    uniflowmatch = importlib.import_module("uniflowmatch")

from uniflowmatch.models.ufm import (
    UniFlowMatchClassificationRefinement,
    UniFlowMatchConfidence,
)
from uniflowmatch.utils.viz import warp_image_with_flow

# Legacy toggle kept for compatibility; the active model is chosen per request
# from the UI radio button, not from this flag.
USE_REFINEMENT_MODEL = False
# Cache of successfully loaded models keyed by the ``use_refinement`` flag,
# so repeated UI requests do not re-download/rebuild the same checkpoint.
_MODEL_CACHE = {}


def initialize_model(use_refinement: bool = False):
    """Load (or fetch from cache) a UFM model.

    Args:
        use_refinement: When True, load the classification-refinement
            checkpoint ("infinity1096/UFM-Refine"); otherwise load the base
            confidence checkpoint ("infinity1096/UFM-Base").

    Returns:
        Tuple ``(ok, model)``: ``(True, model)`` on success, or
        ``(False, None)`` if loading failed for any reason.
    """
    if use_refinement in _MODEL_CACHE:
        return True, _MODEL_CACHE[use_refinement]
    try:
        if use_refinement:
            print("Loading UFM Refinement model from infinity1096/UFM-Refine...")
            model_obj = UniFlowMatchClassificationRefinement.from_pretrained("infinity1096/UFM-Refine")
        else:
            print("Loading UFM Base model from infinity1096/UFM-Base...")
            model_obj = UniFlowMatchConfidence.from_pretrained("infinity1096/UFM-Base")
        # Inference only: disable dropout / batch-norm updates.
        if hasattr(model_obj, "eval"):
            model_obj.eval()
        print("Model loaded successfully!")
        # Only successful loads are cached, so a transient failure (e.g. a
        # network hiccup during download) can be retried on the next call.
        _MODEL_CACHE[use_refinement] = model_obj
        return True, model_obj
    except Exception as e:  # boundary: report and signal failure to the caller
        print(f"Error loading model: {e}")
        return False, None
def _ensure_rgb(image_np):
    """Coerce an image array to 3-channel RGB shape (H, W, 3).

    Grayscale (H, W) input is replicated across three channels and RGBA
    (H, W, 4) input has its alpha channel dropped; 3-channel input is assumed
    to already be RGB (the PIL/Gradio convention).
    """
    if image_np.ndim == 2:
        return np.stack([image_np] * 3, axis=-1)
    if image_np.shape[2] == 4:
        return image_np[..., :3]
    return image_np


def process_images(source_image, target_image, model_type_choice):
    """Run UFM correspondence prediction on an uploaded image pair.

    Args:
        source_image: PIL source image, or None.
        target_image: PIL target image, or None.
        model_type_choice: "Base Model" or "Refinement Model" (UI radio value).

    Returns:
        Tuple ``(flow_pil, covisibility_pil, warped_pil, status_message)``;
        the three image slots are None when inputs are missing or on error.
    """
    if source_image is None or target_image is None:
        return None, None, None, "Please upload both images."

    # Load the requested model variant for this request.
    use_refinement = model_type_choice == "Refinement Model"
    print(f"Switching to {model_type_choice}...")
    loaded, model = initialize_model(use_refinement)
    if not loaded or model is None:
        return None, None, None, "Model not loaded. Please restart the application."

    use_gpu = torch.cuda.is_available()
    device = "cuda" if use_gpu else "cpu"
    model = model.to(device)

    try:
        # Normalize both inputs to RGB numpy arrays. The previous code used
        # cv2.COLOR_BGR2RGB for any non-3-channel input, which raises on
        # grayscale and RGBA images; _ensure_rgb handles both.
        source_rgb = _ensure_rgb(np.array(source_image))
        target_rgb = _ensure_rgb(np.array(target_image))

        print(f"Processing images with shapes: Source {source_rgb.shape}, Target {target_rgb.shape}")

        # === Predict Correspondences ===
        with torch.no_grad():
            result = model.predict_correspondences_batched(
                source_image=torch.from_numpy(source_rgb).to(device),
                target_image=torch.from_numpy(target_rgb).to(device),
            )

        # First batch element: dense flow (channels-first) and covisibility mask.
        flow_output = result.flow.flow_output[0].cpu().numpy()
        covisibility = result.covisibility.mask[0].cpu().numpy()

        print(f"Flow output shape: {flow_output.shape}")
        print(f"Covisibility shape: {covisibility.shape}")

        # === Create Visualizations ===
        # 1. Flow field rendered with the standard flow color wheel
        #    (flow_vis expects channels-last, hence the transpose).
        flow_vis_image = flow_vis.flow_to_color(flow_output.transpose(1, 2, 0))
        flow_pil = Image.fromarray(flow_vis_image.astype(np.uint8))

        # 2. Covisibility as a grayscale mask (assumes values in [0, 1]).
        covisibility_gray = (covisibility * 255).astype(np.uint8)
        covisibility_pil = Image.fromarray(covisibility_gray, mode="L")

        # 3. Warp via the predicted flow, then blend: non-covisible pixels
        #    are pushed to white so occlusions are visually obvious.
        warped_image = warp_image_with_flow(source_rgb, None, target_rgb, flow_output.transpose(1, 2, 0))
        warped_image = covisibility[..., None] * warped_image + (1 - covisibility[..., None]) * 255 * np.ones_like(
            warped_image
        )
        warped_image = (warped_image / 255.0).clip(0, 1)
        warped_pil = Image.fromarray((warped_image * 255).astype(np.uint8))

        status_msg = f"Processing completed with {model_type_choice}"
        return flow_pil, covisibility_pil, warped_pil, status_msg

    except Exception as e:
        error_msg = f"Error processing images: {str(e)}"
        print(error_msg)
        return None, None, None, error_msg
def create_demo():
    """Build and return the Gradio Blocks UI for the UniFlowMatch demo."""
    with gr.Blocks(title="UniFlowMatch Demo") as demo:
        gr.Markdown("# UniFlowMatch Demo")
        gr.Markdown("Upload two images to see optical flow visualization")

        # --- Inputs ---
        with gr.Row():
            src_image = gr.Image(label="Source Image", type="pil")
            tgt_image = gr.Image(label="Target Image", type="pil")

        model_selector = gr.Radio(choices=["Base Model", "Refinement Model"], value="Base Model", label="Model Type")
        run_button = gr.Button("Process Images")
        status_box = gr.Textbox(label="Status", interactive=False)

        # --- Outputs ---
        with gr.Row():
            flow_view = gr.Image(label="Flow Visualization")
            covis_view = gr.Image(label="Covisibility Mask")
            warp_view = gr.Image(label="Warped Target Image")

        # Bundled example image pairs shipped with the repository.
        gr.Examples(
            examples=[
                ["examples/image_pairs/fire_academy_0.png", "examples/image_pairs/fire_academy_1.png"],
                ["examples/image_pairs/scene_0.png", "examples/image_pairs/scene_1.png"],
                ["examples/image_pairs/bike_0.png", "examples/image_pairs/bike_1.png"],
                ["examples/image_pairs/cook_0.png", "examples/image_pairs/cook_1.png"],
                ["examples/image_pairs/building_0.png", "examples/image_pairs/building_1.png"],
            ],
            inputs=[src_image, tgt_image],
            label="Example Image Pairs",
        )

        # Explicit "Process" button.
        run_button.click(
            fn=process_images,
            inputs=[src_image, tgt_image, model_selector],
            outputs=[flow_view, covis_view, warp_view, status_box],
        )

        def _maybe_process(src, tgt, choice):
            # Auto-run as soon as both images are present; otherwise prompt.
            if src is None or tgt is None:
                return None, None, None, "Upload both images to start processing."
            return process_images(src, tgt, choice)

        # Re-run automatically whenever any input changes.
        for component in (src_image, tgt_image, model_selector):
            component.change(
                fn=_maybe_process,
                inputs=[src_image, tgt_image, model_selector],
                outputs=[flow_view, covis_view, warp_view, status_box],
            )

    return demo
if __name__ == "__main__":
    # Warm up the base model once at startup so the first request is fast.
    print("Initializing UniFlowMatch model...")
    # BUG FIX: initialize_model returns a (bool, model) tuple; the previous
    # `if not initialize_model(...)` check was always False because a
    # non-empty tuple is truthy, so the failure branch was unreachable.
    model_ok, _model = initialize_model(use_refinement=False)  # start with base model
    if not model_ok:
        print("Error: Model failed to load. Please check your model installation and HuggingFace access.")
        print("Make sure you have:")
        print("1. Installed uniflowmatch package")
        print("2. Have internet access for downloading pretrained models")
        print("3. All required dependencies installed")
        exit(1)

    # Create and launch demo
    demo = create_demo()
    demo.launch(
        share=True,  # Set to True to create a public link
        server_name="0.0.0.0",  # Allow external connections
        server_port=7860,  # Default Gradio port
        show_error=True,
    )