hobrt commited on
Commit
6ade92c
·
verified ·
1 Parent(s): 71edca1

Create simple_lama.py

Browse files
Files changed (1) hide show
  1. simple_lama.py +104 -0
simple_lama.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import onnxruntime
5
+ import urllib.request
6
+
7
+ class SimpleLama:
8
+ def __init__(self, device='cpu'):
9
+ self.model_url = "https://huggingface.co/Carve/LaMa-ONNX/resolve/main/lama_fp32.onnx"
10
+ self.model_path = "big-lama.onnx"
11
+
12
+ if not os.path.exists(self.model_path):
13
+ print(f"Downloading LaMa model to {self.model_path} (this happens once)...")
14
+ try:
15
+ # Add headers to avoid 403 Forbidden errors if sites block script bots
16
+ opener = urllib.request.build_opener()
17
+ opener.addheaders = [('User-agent', 'Mozilla/5.0')]
18
+ urllib.request.install_opener(opener)
19
+
20
+ urllib.request.urlretrieve(self.model_url, self.model_path)
21
+ print("Download complete.")
22
+ except Exception as e:
23
+ print(f"Failed to download model: {e}")
24
+ raise
25
+
26
+ print("Loading LaMa model...")
27
+ providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if device == 'cuda' else ['CPUExecutionProvider']
28
+ self.session = onnxruntime.InferenceSession(self.model_path, providers=providers)
29
+
30
+ def pad_img_to_modulo(self, img, mod):
31
+ h, w = img.shape[:2]
32
+ h = int(np.ceil(h / mod) * mod)
33
+ w = int(np.ceil(w / mod) * mod)
34
+ return cv2.copyMakeBorder(img, 0, h - img.shape[0], 0, w - img.shape[1], cv2.BORDER_REFLECT)
35
+
36
+ def predict(self, image, mask):
37
+ # 1. Find bounding box of the mask
38
+ # If mask is empty, return original
39
+ if np.max(mask) == 0:
40
+ return image
41
+
42
+ rows, cols = np.where(mask > 0)
43
+ y1, y2 = np.min(rows), np.max(rows)
44
+ x1, x2 = np.min(cols), np.max(cols)
45
+
46
+ # 2. Add padding to context
47
+ # Increased padding from 50 to 200 to give the AI more context of the background
48
+ # This helps preventing "flat color" patches on gradients or textured backgrounds.
49
+ pad = 200
50
+ y1 = max(0, y1 - pad)
51
+ y2 = min(image.shape[0], y2 + pad)
52
+ x1 = max(0, x1 - pad)
53
+ x2 = min(image.shape[1], x2 + pad)
54
+
55
+ # 3. Crop
56
+ crop_img = image[y1:y2, x1:x2]
57
+ crop_mask = mask[y1:y2, x1:x2]
58
+
59
+ crop_h, crop_w = crop_img.shape[:2]
60
+
61
+ # 4. Resize to 512x512 (Model Expectation)
62
+ target_size = (512, 512)
63
+ img_resized = cv2.resize(crop_img, target_size, interpolation=cv2.INTER_AREA)
64
+ mask_resized = cv2.resize(crop_mask, target_size, interpolation=cv2.INTER_NEAREST)
65
+
66
+ # Prepare for ONNX
67
+ img_onnx = img_resized.astype(np.float32) / 255.0
68
+ img_onnx = img_onnx.transpose(2, 0, 1) # HWC -> CHW
69
+ img_onnx = img_onnx[None, ...] # Add batch dim
70
+
71
+ mask_onnx = mask_resized.astype(np.float32) / 255.0
72
+ mask_onnx = (mask_onnx > 0).astype(np.float32)
73
+ if len(mask_onnx.shape) == 2:
74
+ mask_onnx = mask_onnx[None, ...]
75
+ mask_onnx = mask_onnx[None, ...] # Add batch/channel dim
76
+
77
+ # Run inference
78
+ outputs = self.session.run(None, {'image': img_onnx, 'mask': mask_onnx})
79
+ output = outputs[0][0] # Remove batch dim
80
+
81
+ # Post-process
82
+ output = output.transpose(1, 2, 0) # CHW -> HWC
83
+
84
+ # Check if output is already 0-255 (The model from HF seems to output 0-255 directly)
85
+ if output.max() > 2.0:
86
+ # Already 0-255
87
+ output = output.clip(0, 255).astype(np.uint8)
88
+ else:
89
+ # It's 0-1, so scale it
90
+ output = (output * 255.0).clip(0, 255).astype(np.uint8)
91
+
92
+ # 5. Resize back to crop size
93
+ output_resized = cv2.resize(output, (crop_w, crop_h), interpolation=cv2.INTER_CUBIC)
94
+
95
+ # 6. Paste back
96
+ result = image.copy()
97
+ # We only want to paste the part that was masked + blended, but simple paste is okay for now
98
+ # Ideally we blend, but LaMa does inpainting so direct paste usually works.
99
+ # To avoid seams on the square border, we can use the mask to only update the painted area + some context,
100
+ # but let's just paste the whole square context for now as LaMa regenerates the whole context.
101
+
102
+ result[y1:y2, x1:x2] = output_resized
103
+
104
+ return result