nishanth-saka committed on
Commit
2a8000d
·
verified ·
1 Parent(s): ef6efdb

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +152 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import gradio as gr
3
+ import torch
4
+ import torch.nn as nn
5
+ import timm
6
+ import cv2
7
+ import numpy as np
8
+ from PIL import Image
9
+ import matplotlib.pyplot as plt
10
+ import os
11
+
12
# ===============================
# SIMPLE DPT MODEL (DEPTH ESTIMATION)
# ===============================
class SimpleDPT(nn.Module):
    """Minimal DPT-style monocular depth estimator.

    A timm ViT backbone (``features_only=True``) supplies feature maps; only
    the deepest one is passed through a small convolutional decoder that
    regresses a single-channel depth map, which is then bilinearly resized
    to the requested output size.
    """

    def __init__(self, backbone_name='vit_base_patch16_384'):
        super().__init__()
        self.backbone = timm.create_model(backbone_name, pretrained=True, features_only=True)
        # Per-stage channel counts reported by the backbone; the decoder
        # consumes only the deepest stage.
        stage_channels = [info['num_chs'] for info in self.backbone.feature_info]
        self.decoder = nn.Sequential(
            nn.Conv2d(stage_channels[-1], 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, kernel_size=1),
        )

    def forward(self, x, target_size):
        """Predict depth for ``x`` and upsample it to ``target_size`` (H, W)."""
        deepest = self.backbone(x)[-1]
        depth = self.decoder(deepest)
        return nn.functional.interpolate(
            depth, size=target_size, mode='bilinear', align_corners=False
        )
36
+
37
# ===============================
# DEPTH → NORMAL MAP
# ===============================
def depth_to_normal(depth):
    """Convert a 2-D depth map into an RGB-encoded surface-normal map.

    The normal at each pixel is (-dz/dx, -dz/dy, 1), unit-normalized, then
    remapped from [-1, 1] to [0, 1] for visualization. Returns an array of
    shape (H, W, 3).
    """
    grad_y, grad_x = np.gradient(depth)
    nmap = np.dstack((-grad_x, -grad_y, np.ones_like(depth)))
    # Normalize to unit length; epsilon avoids division by zero.
    magnitude = np.linalg.norm(nmap, axis=2, keepdims=True)
    nmap = nmap / (magnitude + 1e-8)
    return (nmap + 1) / 2
47
+
48
# ===============================
# MAIN PROCESSING FUNCTION
# ===============================

# Model cache: constructing SimpleDPT downloads pretrained ViT weights, so
# building it inside every request (as the original code did) is far too
# expensive. The model is created once, lazily, and reused.
_DEPTH_MODEL = None
_DEPTH_DEVICE = None


def _get_depth_model():
    """Return (model, device), creating the depth model once and reusing it."""
    global _DEPTH_MODEL, _DEPTH_DEVICE
    if _DEPTH_MODEL is None:
        _DEPTH_DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        _DEPTH_MODEL = SimpleDPT(backbone_name='vit_base_patch16_384').to(_DEPTH_DEVICE)
        _DEPTH_MODEL.eval()
    return _DEPTH_MODEL, _DEPTH_DEVICE


def _estimate_depth(img_pil):
    """Run depth inference on an RGB PIL image; return (H, W) floats in [0, 1]."""
    img_resized = img_pil.resize((384, 384))
    img_tensor = torch.from_numpy(np.array(img_resized)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
    mean = torch.as_tensor([0.5, 0.5, 0.5], device=img_tensor.device).view(1, 3, 1, 1)
    std = torch.as_tensor([0.5, 0.5, 0.5], device=img_tensor.device).view(1, 3, 1, 1)
    img_tensor = (img_tensor - mean) / std

    model, device = _get_depth_model()
    with torch.no_grad():
        # PIL .size is (W, H); interpolate expects (H, W), hence the reversal.
        target_size = img_pil.size[::-1]
        depth_map = model(img_tensor.to(device), target_size=target_size)
    depth_map = depth_map.squeeze().cpu().numpy()

    # Min-max normalize; the epsilon guards against a constant depth map
    # (max == min), which previously divided by zero.
    depth_range = depth_map.max() - depth_map.min()
    return (depth_map - depth_map.min()) / (depth_range + 1e-8)


def _shading_map(img_np):
    """CLAHE-equalized L channel of the image as (H, W) floats in [0, 1]."""
    img_lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
    l_channel, _, _ = cv2.split(img_lab)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    return clahe.apply(l_channel) / 255.0


def _foreground_mask(img_np):
    """GrabCut foreground mask (uint8, 0/255) seeded with a 5%-border rect."""
    img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
    grabcut_mask = np.zeros(img_bgr.shape[:2], np.uint8)
    height, width = img_bgr.shape[:2]
    margin = int(min(width, height) * 0.05)
    rect = (margin, margin, width - 2 * margin, height - 2 * margin)
    bgd_model = np.zeros((1, 65), np.float64)
    fgd_model = np.zeros((1, 65), np.float64)
    cv2.grabCut(img_bgr, grabcut_mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_RECT)
    return np.where(
        (grabcut_mask == cv2.GC_FGD) | (grabcut_mask == cv2.GC_PR_FGD), 255, 0
    ).astype(np.uint8)


def _tile_pattern(pattern_np, target_h, target_w):
    """Tile ``pattern_np`` (H, W, 3 uint8) to cover a (target_h, target_w) canvas."""
    pattern_h, pattern_w = pattern_np.shape[:2]
    tiled = np.zeros((target_h, target_w, 3), dtype=np.uint8)
    for y in range(0, target_h, pattern_h):
        for x in range(0, target_w, pattern_w):
            end_y = min(y + pattern_h, target_h)
            end_x = min(x + pattern_w, target_w)
            tiled[y:end_y, x:end_x] = pattern_np[0:(end_y - y), 0:(end_x - x)]
    return tiled


def process_saree(base_image: Image.Image, pattern_image: Image.Image):
    """Drape ``pattern_image`` over ``base_image`` using depth-aware shading.

    Pipeline: estimate depth → derive a normal map → CLAHE shading map →
    GrabCut foreground mask → tile the pattern → modulate it by shading and
    surface orientation → feather the mask. Returns an RGBA PIL image whose
    alpha channel is the feathered foreground mask.
    """
    img_pil = base_image.convert("RGB")
    img_np = np.array(img_pil)

    depth_vis = _estimate_depth(img_pil)
    normal_map = depth_to_normal(depth_vis)
    shading_map = _shading_map(img_np)
    mask = _foreground_mask(img_np)

    pattern_np = np.array(pattern_image.convert("RGB"))
    target_h, target_w = img_np.shape[:2]
    pattern_tiled = _tile_pattern(pattern_np, target_h, target_w)

    # Modulate the tiled pattern: alpha-blended shading darkens folds, and
    # the normal map's z component boosts surfaces facing the camera.
    normal_map_loaded = normal_map.astype(np.float32)
    shading_map_loaded = np.stack([shading_map] * 3, axis=-1)
    alpha = 0.7
    blended_shading = alpha * shading_map_loaded + (1 - alpha)
    pattern_folded = pattern_tiled.astype(np.float32) / 255.0 * blended_shading
    normal_boost = 0.5 + 0.5 * normal_map_loaded[..., 2:3]
    pattern_folded = np.clip(pattern_folded * normal_boost, 0, 1)

    # Clean the mask (open/close removes specks, dilate recovers edges) and
    # feather it so the composite has no hard black outline.
    kernel = np.ones((3, 3), np.uint8)
    mask_clean = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
    mask_clean = cv2.morphologyEx(mask_clean, cv2.MORPH_CLOSE, kernel)
    mask_clean = cv2.dilate(mask_clean, kernel, iterations=1)
    mask_blurred = cv2.GaussianBlur(mask_clean, (15, 15), sigmaX=5, sigmaY=5)
    mask_blurred[mask_blurred < 25] = 0  # kill faint halo pixels
    mask_blurred = mask_blurred.astype(np.float32) / 255.0

    # Compose final RGBA: pattern masked to the foreground, feathered alpha.
    mask_stack = np.stack([mask_blurred] * 3, axis=-1)
    pattern_final = pattern_folded * mask_stack
    pattern_rgb = (pattern_final * 255).astype(np.uint8)
    alpha_channel = (mask_blurred * 255).astype(np.uint8)
    pattern_rgba = np.dstack((pattern_rgb, alpha_channel))
    return Image.fromarray(pattern_rgba, mode="RGBA")
138
+
139
# ===============================
# GRADIO INTERFACE
# ===============================
# Two PIL image inputs (base garment photo + repeating pattern swatch) and
# one PIL image output produced by process_saree. The interface object is
# created at import time so hosting platforms can discover it.
iface = gr.Interface(
    fn=process_saree,
    inputs=[
        gr.Image(type="pil", label="Base Saree Image"),
        gr.Image(type="pil", label="Pattern Image"),
    ],
    outputs=gr.Image(type="pil", label="Final Saree Output"),
    title="Saree Depth + Pattern Draping",
    description="Upload base saree & pattern images to get depth-aware draped output (transparent edges, no black outline).",
)

if __name__ == "__main__":
    iface.launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ timm
4
+ opencv-python
5
+ Pillow
6
+ matplotlib
7
+ tqdm
8
+ gradio