File size: 8,198 Bytes
c8b42eb
 
3991736
c8b42eb
 
 
3991736
94dea5c
 
 
8683722
 
 
7f9c5ee
b91a473
8683722
94dea5c
1b3abf3
 
b7008ff
7b12cd5
 
 
 
57a236c
7b12cd5
 
 
 
 
94dea5c
 
c8b42eb
 
 
 
 
 
 
 
 
7f9c5ee
c8b42eb
 
 
7f9c5ee
c8b42eb
7f9c5ee
c8b42eb
 
7f9c5ee
c8b42eb
 
7f9c5ee
 
c8b42eb
 
7f9c5ee
c8b42eb
 
7f9c5ee
c8b42eb
b91a473
c8b42eb
 
 
 
 
 
 
 
 
7f9c5ee
 
c8b42eb
 
 
 
b91a473
 
 
c8b42eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74be47f
 
c8b42eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a64fe9
c8b42eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
import cv2
import flow_vis
import gradio as gr
import numpy as np
import torch
from PIL import Image

import importlib
import site
import subprocess
import sys
from functools import lru_cache

import spaces

# uniception/uniflowmatch may not be present on a fresh deployment (e.g. a
# new HuggingFace Space). If importing them fails, run the install script
# once and retry the imports.
# NOTE: only ImportError triggers the install — a bare `except:` would also
# swallow KeyboardInterrupt/SystemExit and re-run the installer on
# unrelated failures.
try:
    import uniception
    import uniflowmatch
except ImportError:
    subprocess.check_call(["/bin/bash", "./install_package.sh"])

    # Refresh the import machinery so the freshly installed packages are
    # visible (only needed if the install added a new site-packages path).
    importlib.invalidate_caches()
    site.main()  # re-process site-packages directories

    # Retry the imports now that the packages are installed.
    uniception = importlib.import_module("uniception")
    uniflowmatch = importlib.import_module("uniflowmatch")


from uniflowmatch.models.ufm import (
    UniFlowMatchClassificationRefinement,
    UniFlowMatchConfidence,
)
from uniflowmatch.utils.viz import warp_image_with_flow

# Global model variable
USE_REFINEMENT_MODEL = False

# Cache of successfully loaded models, keyed by the use_refinement flag.
# A plain dict is used instead of @lru_cache so that *failed* loads are not
# memoized: with lru_cache, a single transient error (e.g. a network hiccup
# while downloading weights) would be cached as (False, None) forever.
_MODEL_CACHE = {}

def initialize_model(use_refinement: bool = False):
    """Load a UFM model, reusing a previously loaded instance when possible.

    Args:
        use_refinement: when True, load the classification-refinement model
            ("infinity1096/UFM-Refine"); otherwise load the base confidence
            model ("infinity1096/UFM-Base").

    Returns:
        A ``(success, model)`` tuple: ``(True, model)`` on success,
        ``(False, None)`` when loading raised. Only successful loads are
        cached, so a failed load is retried on the next call.
    """
    cached = _MODEL_CACHE.get(use_refinement)
    if cached is not None:
        return True, cached

    try:
        if use_refinement:
            print("Loading UFM Refinement model from infinity1096/UFM-Refine...")
            model_obj = UniFlowMatchClassificationRefinement.from_pretrained("infinity1096/UFM-Refine")
        else:
            print("Loading UFM Base model from infinity1096/UFM-Base...")
            model_obj = UniFlowMatchConfidence.from_pretrained("infinity1096/UFM-Base")

        # Inference only — freeze dropout/batch-norm behavior.
        if hasattr(model_obj, "eval"):
            model_obj.eval()

        print("Model loaded successfully!")
        _MODEL_CACHE[use_refinement] = model_obj
        return True, model_obj
    except Exception as e:
        # Best-effort contract: callers check the returned model for None.
        print(f"Error loading model: {e}")
        return False, None

def _ensure_rgb(image_np):
    """Coerce an uploaded image array to 3-channel RGB uint8 layout.

    PIL hands us RGB arrays for color images, but Gradio uploads can also be
    grayscale (2-D or HxWx1) or RGBA (4 channels). The previous code ran
    cv2.COLOR_BGR2RGB on any non-3-channel array, which raises for those
    shapes — convert them explicitly instead.
    """
    if image_np.ndim == 2 or image_np.shape[2] == 1:
        return cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
    if image_np.shape[2] == 4:
        return cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)
    return image_np


@spaces.GPU
def process_images(source_image, target_image, model_type_choice):
    """Run UFM correspondence prediction on a pair of uploaded images.

    Args:
        source_image: PIL image (or None) for the source view.
        target_image: PIL image (or None) for the target view.
        model_type_choice: "Base Model" or "Refinement Model" radio value.

    Returns:
        ``(flow_pil, covisibility_pil, warped_pil, status_message)``; the
        image slots are None when an input is missing or processing fails.
    """
    if source_image is None or target_image is None:
        return None, None, None, "Please upload both images."

    # (Re)load the requested model variant; initialize_model caches loads.
    current_refinement = model_type_choice == "Refinement Model"
    print(f"Switching to {model_type_choice}...")
    ret, model = initialize_model(current_refinement)

    if model is None:
        return None, None, None, "Model not loaded. Please restart the application."

    use_gpu = torch.cuda.is_available()
    device = "cuda" if use_gpu else "cpu"
    model = model.to(device)

    try:
        # Convert PIL images to numpy arrays, normalized to 3-channel RGB.
        source_rgb = _ensure_rgb(np.array(source_image))
        target_rgb = _ensure_rgb(np.array(target_image))

        print(f"Processing images with shapes: Source {source_rgb.shape}, Target {target_rgb.shape}")

        # === Predict Correspondences ===
        with torch.no_grad():
            result = model.predict_correspondences_batched(
                source_image=torch.from_numpy(source_rgb).to(device),
                target_image=torch.from_numpy(target_rgb).to(device),
            )

            # Take the first (only) batch element back to numpy on CPU.
            flow_output = result.flow.flow_output[0].cpu().numpy()
            covisibility = result.covisibility.mask[0].cpu().numpy()

        print(f"Flow output shape: {flow_output.shape}")
        print(f"Covisibility shape: {covisibility.shape}")

        # === Create Visualizations ===

        # 1. Flow visualization (flow is CHW; flow_vis expects HWC).
        flow_vis_image = flow_vis.flow_to_color(flow_output.transpose(1, 2, 0))
        flow_pil = Image.fromarray(flow_vis_image.astype(np.uint8))

        # 2. Covisibility visualization — direct gray image.
        covisibility_gray = (covisibility * 255).astype(np.uint8)
        covisibility_pil = Image.fromarray(covisibility_gray, mode="L")

        # 3. Warp the source into the target frame, blending non-covisible
        #    regions to white using the covisibility mask.
        warped_image = warp_image_with_flow(source_rgb, None, target_rgb, flow_output.transpose(1, 2, 0))
        warped_image = covisibility[..., None] * warped_image + (1 - covisibility[..., None]) * 255 * np.ones_like(
            warped_image
        )
        warped_image = (warped_image / 255.0).clip(0, 1)
        warped_pil = Image.fromarray((warped_image * 255).astype(np.uint8))

        status_msg = f"Processing completed with {model_type_choice}"

        return flow_pil, covisibility_pil, warped_pil, status_msg

    except Exception as e:
        error_msg = f"Error processing images: {str(e)}"
        print(error_msg)
        return None, None, None, error_msg


def create_demo():
    """Assemble and return the Gradio Blocks UI for the demo."""

    with gr.Blocks(title="UniFlowMatch Demo") as demo:
        gr.Markdown("# UniFlowMatch Demo")
        gr.Markdown("Upload two images to see optical flow visualization")

        # --- Input images ---
        with gr.Row():
            src_image = gr.Image(label="Source Image", type="pil")
            tgt_image = gr.Image(label="Target Image", type="pil")

        # --- Model variant selector ---
        model_selector = gr.Radio(choices=["Base Model", "Refinement Model"], value="Base Model", label="Model Type")

        # --- Trigger button and status line ---
        run_button = gr.Button("Process Images")
        status_box = gr.Textbox(label="Status", interactive=False)

        # --- Result images ---
        with gr.Row():
            flow_display = gr.Image(label="Flow Visualization")
            covis_display = gr.Image(label="Covisibility Mask")
            warp_display = gr.Image(label="Warped Target Image")

        # --- Bundled example pairs ---
        example_pairs = [
            ["examples/image_pairs/fire_academy_0.png", "examples/image_pairs/fire_academy_1.png"],
            ["examples/image_pairs/scene_0.png", "examples/image_pairs/scene_1.png"],
            ["examples/image_pairs/bike_0.png", "examples/image_pairs/bike_1.png"],
            ["examples/image_pairs/cook_0.png", "examples/image_pairs/cook_1.png"],
            ["examples/image_pairs/building_0.png", "examples/image_pairs/building_1.png"],
        ]
        gr.Examples(
            examples=example_pairs,
            inputs=[src_image, tgt_image],
            label="Example Image Pairs",
        )

        demo_inputs = [src_image, tgt_image, model_selector]
        demo_outputs = [flow_display, covis_display, warp_display, status_box]

        # Explicit processing on button press.
        run_button.click(
            fn=process_images,
            inputs=demo_inputs,
            outputs=demo_outputs,
        )

        # Kick off processing automatically once both images are present.
        def auto_process(source, target, model_choice):
            if source is None or target is None:
                return None, None, None, "Upload both images to start processing."
            return process_images(source, target, model_choice)

        for widget in (src_image, tgt_image, model_selector):
            widget.change(
                fn=auto_process,
                inputs=demo_inputs,
                outputs=demo_outputs,
            )

    return demo


if __name__ == "__main__":
    # Warm up the base model so the first request doesn't pay the load cost.
    print("Initializing UniFlowMatch model...")
    # initialize_model returns (success, model). The previous code tested the
    # whole tuple, which is always truthy, so load failures were never caught
    # here — unpack the success flag instead.
    model_loaded, _model = initialize_model(use_refinement=False)

    if not model_loaded:
        print("Error: Model failed to load. Please check your model installation and HuggingFace access.")
        print("Make sure you have:")
        print("1. Installed uniflowmatch package")
        print("2. Have internet access for downloading pretrained models")
        print("3. All required dependencies installed")
        sys.exit(1)  # sys.exit instead of the interactive-only exit()

    # Create and launch demo
    demo = create_demo()
    demo.launch(
        share=True,  # create a public link
        server_name="0.0.0.0",  # allow external connections
        server_port=7860,  # default Gradio port
        show_error=True,
    )