enoky committed
Commit 53f760e · verified · 1 Parent(s): f98a0fe

run the LaMa model locally

Files changed (2)
  1. app.py +199 -223
  2. requirements.txt +2 -2
app.py CHANGED
@@ -1,224 +1,200 @@
- import gradio as gr
- import torch
- import numpy as np
- import cv2
- from PIL import Image
- from transformers import DPTForDepthEstimation, DPTImageProcessor
- from gradio_client import Client, handle_file
- import tempfile
- import os
-
- # === DEVICE ===
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- # === DEPTH MODEL ===
- def load_depth_model():
-     # DPTImageProcessor is the modern replacement for FeatureExtractor
-     model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
-     processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
-     return model, processor
-
- @torch.no_grad()
- def estimate_depth(image_pil, model, processor):
-     # Keep original size for restoration later
-     original_size = image_pil.size  # (width, height)
-
-     # Preprocess (processor handles resizing internally for the model)
-     inputs = processor(images=image_pil, return_tensors="pt").to(device)
-
-     depth = model(**inputs).predicted_depth
-
-     # Interpolate depth back to ORIGINAL image size
-     depth = torch.nn.functional.interpolate(
-         depth.unsqueeze(1),
-         size=(original_size[1], original_size[0]),  # torch expects (H, W)
-         mode="bicubic",
-         align_corners=False,
-     ).squeeze().detach().cpu().numpy()
-
-     # Normalize
-     depth_min, depth_max = depth.min(), depth.max()
-     if depth_max - depth_min > 0:
-         return (depth - depth_min) / (depth_max - depth_min)
-     return depth
-
- def generate_right_and_mask(image, shift_map):
-     """
-     Vectorized shift operation.
-     shift_map: 2D array indicating how many pixels to shift left (positive) or right (negative).
-     """
-     height, width = image.shape[:2]
-
-     # Create a grid of coordinates
-     x_coords, y_coords = np.meshgrid(np.arange(width), np.arange(height))
-
-     # Calculate target coordinates (shift pixels to the left for right eye)
-     shift = shift_map.astype(int)
-     target_x = x_coords - shift
-
-     # Initialize output and mask
-     right = np.zeros_like(image)
-     mask = np.ones((height, width), dtype=np.uint8) * 255  # 255 = hole/inpainting area
-
-     # Valid indices mask (ensure pixels land within image bounds)
-     valid_mask = (target_x >= 0) & (target_x < width)
-
-     # Flatten arrays for advanced indexing
-     flat_y = y_coords[valid_mask]
-     flat_x_target = target_x[valid_mask]
-     flat_x_source = x_coords[valid_mask]
-
-     # Assign pixels
-     # Note: simple overwriting handles occlusions naively but effectively for this use case
-     right[flat_y, flat_x_target] = image[flat_y, flat_x_source]
-
-     # Update Mask: Areas that were written to are NOT holes (0)
-     mask[flat_y, flat_x_target] = 0
-
-     return right, mask
-
- def make_anaglyph(left, right):
-     """
-     Creates a Red-Cyan anaglyph.
-     Left image provides the Red channel.
-     Right image provides the Green and Blue channels.
-     """
-     # Convert to arrays
-     l_arr = np.array(left)
-     r_arr = np.array(right)
-
-     # Create output array (same shape)
-     anaglyph = np.zeros_like(l_arr)
-
-     # Red channel from Left
-     anaglyph[:, :, 0] = l_arr[:, :, 0]
-
-     # Green and Blue channels from Right
-     anaglyph[:, :, 1] = r_arr[:, :, 1]
-     anaglyph[:, :, 2] = r_arr[:, :, 2]
-
-     return Image.fromarray(anaglyph)
-
- # === LAMA INPAINTING (Via Gradio Client) ===
- # Note: You need a valid Space that accepts image + mask.
- try:
-     lama_client = Client("asif-k/LaMa-Inpainting")
- except Exception as e:
-     print(f"Could not connect to external LaMa client: {e}")
-     lama_client = None
-
- def run_lama_inpainting(image_bgr, mask):
-     if lama_client is None:
-         print("LaMa client unavailable, returning unfilled image.")
-         return image_bgr
-
-     # Prepare files for Gradio Client
-     # Convert BGR (OpenCV) to RGB for PIL
-     img_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
-
-     with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f_img, \
-          tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f_mask:
-
-         Image.fromarray(img_rgb).save(f_img.name)
-         Image.fromarray(mask).save(f_mask.name)
-
-     try:
-         # Predict using the external space
-         result_path = lama_client.predict(
-             image=handle_file(f_img.name),
-             mask=handle_file(f_mask.name),
-             api_name="/predict"
-         )
-
-         # Result is a filepath
-         res_img = Image.open(result_path)
-         return cv2.cvtColor(np.array(res_img), cv2.COLOR_RGB2BGR)
-
-     except Exception as e:
-         print(f"Inpainting failed: {e}")
-         return image_bgr  # Return original with holes if fail
-     finally:
-         # Cleanup
-         os.remove(f_img.name)
-         os.remove(f_mask.name)
-
- # === APP LOGIC ===
- depth_model, depth_processor = load_depth_model()
-
- def stereo_pipeline(image_pil, divergence, convergence):
-     if image_pil is None:
-         return None, None
-
-     image_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
-
-     # 1. Estimate Depth (0.0 far to 1.0 near)
-     depth = estimate_depth(image_pil, depth_model, depth_processor)
-
-     # 2. Calculate Shift Map
-     # Divergence: Overall separation strength (pixels)
-     # Convergence: The depth plane that stays still (0.0 - 1.0)
-     # Result:
-     #   Positive shift (Leftwards) = Pop out of screen (Near objects)
-     #   Negative shift (Rightwards) = Go into screen (Far objects)
-     shift = (depth - convergence) * divergence
-
-     # 3. Shift Pixels
-     right_img, mask = generate_right_and_mask(image_cv, shift)
-
-     # 4. Inpaint Holes
-     # Pass the mask where 255 indicates holes to be filled
-     right_filled = run_lama_inpainting(right_img, mask)
-
-     left = image_pil
-     right = Image.fromarray(cv2.cvtColor(right_filled, cv2.COLOR_BGR2RGB))
-
-     # === Combine into Side-by-Side ===
-     width, height = left.size
-     combined_image = Image.new('RGB', (width * 2, height))
-     combined_image.paste(left, (0, 0))
-     combined_image.paste(right, (width, 0))
-
-     # === Create Anaglyph ===
-     anaglyph_image = make_anaglyph(left, right)
-
-     return combined_image, anaglyph_image
-
- # === GRADIO UI ===
- with gr.Blocks(title="2D to 3D Stereo") as demo:
-     gr.Markdown("## 2D to 3D Stereo Generator")
-     gr.Markdown("Generates a side-by-side stereo pair and anaglyph using Depth Estimation and LaMa Inpainting.")
-
-     with gr.Row():
-         with gr.Column(scale=1):
-             input_img = gr.Image(type="pil", label="Input Image", height=480)
-
-             # === Controls ===
-             with gr.Group():
-                 gr.Markdown("### 3D Controls")
-                 divergence_slider = gr.Slider(
-                     minimum=0, maximum=100, value=30, step=1,
-                     label="3D Strength (Divergence)",
-                     info="Max pixel separation. Higher = Deeper 3D effect."
-                 )
-                 convergence_slider = gr.Slider(
-                     minimum=0.0, maximum=1.0, value=0.1, step=0.05,
-                     label="Focus Plane (Convergence)",
-                     info="0.0 = Background at screen depth. 0.5 = Mid-range at screen. 1.0 = Foreground at screen."
-                 )
-
-             btn = gr.Button("Generate 3D", variant="primary")
-
-         with gr.Column(scale=1):
-             out_anaglyph = gr.Image(label="Anaglyph (Red/Cyan)", height=480)
-
-     with gr.Row():
-         out_stereo = gr.Image(label="Side-by-Side Stereo Pair", height=400)
-
-     btn.click(
-         fn=stereo_pipeline,
-         inputs=[input_img, divergence_slider, convergence_slider],
-         outputs=[out_stereo, out_anaglyph]
-     )
-
- if __name__ == "__main__":
 
+ import gradio as gr
+ import torch
+ import numpy as np
+ import cv2
+ from PIL import Image
+ from transformers import DPTForDepthEstimation, DPTImageProcessor
+ from huggingface_hub import hf_hub_download
+ import os
+
+ # === DEVICE ===
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(f"Running on device: {device}")
+
+ # === LOAD MODELS ===
+ def load_models():
+     print("Loading Depth Model...")
+     # 1. Depth Model
+     depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
+     depth_processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
+
+     print("Loading LaMa Inpainting Model...")
+     # 2. LaMa Inpainting Model (TorchScript)
+     # We download the JIT-traced model, which is self-contained
+     model_path = hf_hub_download(repo_id="smartywu/big-lama", filename="big-lama.pt")
+     lama_model = torch.jit.load(model_path).to(device)
+     lama_model.eval()
+
+     return depth_model, depth_processor, lama_model
+
+ # Load models once at startup
+ depth_model, depth_processor, lama_model = load_models()
+
+ # === DEPTH ESTIMATION ===
+ @torch.no_grad()
+ def estimate_depth(image_pil, model, processor):
+     original_size = image_pil.size
+     inputs = processor(images=image_pil, return_tensors="pt").to(device)
+     depth = model(**inputs).predicted_depth
+
+     depth = torch.nn.functional.interpolate(
+         depth.unsqueeze(1),
+         size=(original_size[1], original_size[0]),
+         mode="bicubic",
+         align_corners=False,
+     ).squeeze().detach().cpu().numpy()
+
+     depth_min, depth_max = depth.min(), depth.max()
+     if depth_max - depth_min > 0:
+         return (depth - depth_min) / (depth_max - depth_min)
+     return depth
+
+ # === STEREO GENERATION LOGIC ===
+ def generate_right_and_mask(image, shift_map):
+     height, width = image.shape[:2]
+     x_coords, y_coords = np.meshgrid(np.arange(width), np.arange(height))
+     shift = shift_map.astype(int)
+     target_x = x_coords - shift
+
+     right = np.zeros_like(image)
+     # Mask convention: 1.0 means HOLE/MISSING info, 0.0 means valid.
+     # Initialize as all holes (1.0)
+     mask = np.ones((height, width), dtype=np.float32)
+
+     valid_mask = (target_x >= 0) & (target_x < width)
+     flat_y = y_coords[valid_mask]
+     flat_x_target = target_x[valid_mask]
+     flat_x_source = x_coords[valid_mask]
+
+     right[flat_y, flat_x_target] = image[flat_y, flat_x_source]
+     # Mark written pixels as valid (0.0)
+     mask[flat_y, flat_x_target] = 0.0
+
+     return right, mask
+
+ # === LOCAL INPAINTING ===
+ @torch.no_grad()
+ def run_local_lama(image_bgr, mask_float):
+     """
+     Runs LaMa locally.
+     image_bgr: HxWx3 uint8 numpy array
+     mask_float: HxW float32 numpy array (1.0 = hole, 0.0 = valid)
+     """
+     # 1. Resize so both sides are divisible by 8 (LaMa requirement)
+     h, w = image_bgr.shape[:2]
+     new_h = (h // 8) * 8
+     new_w = (w // 8) * 8
+
+     img_resized = cv2.resize(image_bgr, (new_w, new_h))
+     mask_resized = cv2.resize(mask_float, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
+
+     # 2. Convert to Torch tensors
+     # Image: (1, 3, H, W), RGB, 0-1
+     img_t = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
+     # Swap BGR to RGB
+     img_t = img_t[:, [2, 1, 0], :, :]
+
+     # Mask: (1, 1, H, W), 0-1
+     mask_t = torch.from_numpy(mask_resized).float().unsqueeze(0).unsqueeze(0)
+     # Binary threshold, just in case
+     mask_t = (mask_t > 0.5).float()
+
+     img_t = img_t.to(device)
+     mask_t = mask_t.to(device)
+
+     # 3. Inference
+     inpainted_t = lama_model(img_t, mask_t)
+
+     # 4. Post-process
+     inpainted = inpainted_t[0].permute(1, 2, 0).cpu().numpy()
+     inpainted = np.clip(inpainted * 255, 0, 255).astype(np.uint8)
+
+     # Swap back RGB to BGR
+     inpainted = cv2.cvtColor(inpainted, cv2.COLOR_RGB2BGR)
+
+     # Resize back to the original size if needed
+     if new_h != h or new_w != w:
+         inpainted = cv2.resize(inpainted, (w, h))
+
+     return inpainted
+
+ def make_anaglyph(left, right):
+     l_arr = np.array(left)
+     r_arr = np.array(right)
+     anaglyph = np.zeros_like(l_arr)
+     anaglyph[:, :, 0] = l_arr[:, :, 0]
+     anaglyph[:, :, 1] = r_arr[:, :, 1]
+     anaglyph[:, :, 2] = r_arr[:, :, 2]
+     return Image.fromarray(anaglyph)
+
+ # === PIPELINE ===
+ def stereo_pipeline(image_pil, divergence, convergence):
+     if image_pil is None:
+         return None, None
+
+     # Convert to BGR for OpenCV processing
+     image_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
+
+     # 1. Depth
+     depth = estimate_depth(image_pil, depth_model, depth_processor)
+
+     # 2. Shift Map
+     shift = (depth - convergence) * divergence
+
+     # 3. Warping
+     right_img, mask = generate_right_and_mask(image_cv, shift)
+
+     # 4. Inpainting (Local)
+     right_filled = run_local_lama(right_img, mask)
+
+     left = image_pil
+     right = Image.fromarray(cv2.cvtColor(right_filled, cv2.COLOR_BGR2RGB))
+
+     # 5. Composition
+     width, height = left.size
+     combined_image = Image.new('RGB', (width * 2, height))
+     combined_image.paste(left, (0, 0))
+     combined_image.paste(right, (width, 0))
+
+     anaglyph_image = make_anaglyph(left, right)
+
+     return combined_image, anaglyph_image
+
+ # === GRADIO UI ===
+ with gr.Blocks(title="2D to 3D Stereo") as demo:
+     gr.Markdown("## 2D to 3D Stereo Generator (Fully Local)")
+     gr.Markdown("Generates stereo pairs using Depth Estimation and **Local LaMa Inpainting**. No external APIs required.")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             input_img = gr.Image(type="pil", label="Input Image", height=480)
+
+             with gr.Group():
+                 gr.Markdown("### 3D Controls")
+                 divergence_slider = gr.Slider(
+                     minimum=0, maximum=100, value=30, step=1,
+                     label="3D Strength (Divergence)",
+                     info="Max pixel separation."
+                 )
+                 convergence_slider = gr.Slider(
+                     minimum=0.0, maximum=1.0, value=0.1, step=0.05,
+                     label="Focus Plane (Convergence)",
+                     info="0.0 = Background at screen. 1.0 = Foreground at screen."
+                 )
+
+             btn = gr.Button("Generate 3D", variant="primary")
+
+         with gr.Column(scale=1):
+             out_anaglyph = gr.Image(label="Anaglyph (Red/Cyan)", height=480)
+
+     with gr.Row():
+         out_stereo = gr.Image(label="Side-by-Side Stereo Pair", height=400)
+
+     btn.click(
+         fn=stereo_pipeline,
+         inputs=[input_img, divergence_slider, convergence_slider],
+         outputs=[out_stereo, out_anaglyph]
+     )
+
+ if __name__ == "__main__":
      demo.launch()
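
The new local inpainting path can be exercised on its own before running the full app. The following is a minimal smoke-test sketch, not part of the commit: the repo id, filename, tensor layout (image (1, 3, H, W) in 0-1, mask (1, 1, H, W) with 1.0 = hole), and the model(img, mask) call signature are taken from app.py above, while the synthetic 512x512 input and the smoke_test_lama name are illustrative assumptions.

import torch
from huggingface_hub import hf_hub_download

def smoke_test_lama():
    # Download and load the self-contained TorchScript LaMa model (as in load_models)
    model_path = hf_hub_download(repo_id="smartywu/big-lama", filename="big-lama.pt")
    model = torch.jit.load(model_path, map_location="cpu")
    model.eval()

    # Synthetic input: RGB noise in 0-1; 512 is already divisible by 8
    img = torch.rand(1, 3, 512, 512)
    mask = torch.zeros(1, 1, 512, 512)
    mask[:, :, 200:312, 200:312] = 1.0  # punch a square hole to fill

    with torch.no_grad():
        out = model(img, mask)

    # app.py scales the output by 255, so values are expected in roughly 0-1
    print(out.shape)  # expected: torch.Size([1, 3, 512, 512])

smoke_test_lama()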
requirements.txt CHANGED
@@ -1,8 +1,8 @@
  gradio
- gradio_client
  torch
  numpy
  opencv-python
  pillow
  transformers
- scipy
+ scipy
+ huggingface_hub
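
The dependency swap (gradio_client out, huggingface_hub in) can be verified with a one-shot import check; a sketch, noting that the distribution names opencv-python and pillow import as cv2 and PIL:

# Import check for the updated requirements.txt
import gradio, torch, numpy, cv2, PIL, transformers, scipy, huggingface_hub
print("All requirements import cleanly.")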