AhsanAftab committed on
Commit
64e9eb5
·
verified ·
1 Parent(s): 112d6ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +169 -161
app.py CHANGED
@@ -1,162 +1,170 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import torchvision.models as models
5
- from flask import Flask, request, jsonify
6
- from flask_cors import CORS
7
- import numpy as np
8
- import cv2
9
- import base64
10
- from io import BytesIO
11
- from PIL import Image
12
-
13
- app = Flask(__name__)
14
- CORS(app)
15
-
16
- DEVICE = torch.device('cpu') # Force CPU for HF Free Tier
17
- MODEL_PATH = "best_model_fixed.pth" # Upload your trained .pth file
18
-
19
- class InpaintingGenerator(nn.Module):
20
- def __init__(self, input_channels=4):
21
- super().__init__()
22
- resnet = models.resnet34(weights=None)
23
-
24
- self.enc1 = nn.Sequential(
25
- nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False),
26
- resnet.bn1, resnet.relu
27
- )
28
- self.enc2 = resnet.layer1
29
- self.enc3 = resnet.layer2
30
- self.enc4 = resnet.layer3
31
- self.enc5 = resnet.layer4
32
-
33
- self.bottleneck = nn.Sequential(
34
- nn.Conv2d(512, 512, 3, padding=1), nn.BatchNorm2d(512), nn.ReLU(True),
35
- nn.Conv2d(512, 512, 3, padding=1), nn.BatchNorm2d(512), nn.ReLU(True)
36
- )
37
-
38
- self.up1 = self._make_decoder_block(512, 256)
39
- self.up2 = self._make_decoder_block(512, 128) # 256+256
40
- self.up3 = self._make_decoder_block(256, 64) # 128+128
41
- self.up4 = self._make_decoder_block(128, 32) # 64+64
42
-
43
- self.texture_refine = nn.Sequential(
44
- nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(True),
45
- nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(True)
46
- )
47
-
48
- self.final = nn.Sequential(
49
- nn.Conv2d(32, 16, 3, padding=1), nn.ReLU(True),
50
- nn.Conv2d(16, 3, 3, padding=1), nn.Tanh()
51
- )
52
-
53
- def _make_decoder_block(self, in_channels, out_channels):
54
- return nn.Sequential(
55
- nn.ConvTranspose2d(in_channels, out_channels, 4, stride=2, padding=1),
56
- nn.BatchNorm2d(out_channels), nn.ReLU(True),
57
- nn.Conv2d(out_channels, out_channels, 3, padding=1),
58
- nn.BatchNorm2d(out_channels), nn.ReLU(True),
59
- nn.Conv2d(out_channels, out_channels, 3, padding=1),
60
- nn.BatchNorm2d(out_channels), nn.ReLU(True)
61
- )
62
-
63
- def forward(self, img, mask):
64
- x = torch.cat([img, mask], dim=1)
65
- x1 = self.enc1(x)
66
- x2 = self.enc2(x1)
67
- x3 = self.enc3(x2)
68
- x4 = self.enc4(x3)
69
- x5 = self.enc5(x4)
70
-
71
- x = self.bottleneck(x5)
72
-
73
- x = self.up1(x)
74
- x = torch.cat([x, x4], dim=1)
75
- x = self.up2(x)
76
- x = torch.cat([x, x3], dim=1)
77
- x = self.up3(x)
78
- x = torch.cat([x, x2], dim=1)
79
- x = self.up4(x)
80
-
81
- x = self.texture_refine(x)
82
- return self.final(x)
83
-
84
-
85
- print("Loading Inpainting Model...")
86
- model = InpaintingGenerator().to(DEVICE)
87
-
88
- try:
89
- # Set weights_only=False to avoid numpy errors
90
- checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
91
-
92
- # Handle DataParallel wrapping
93
- if 'generator' in checkpoint:
94
- state_dict = checkpoint['generator']
95
- else:
96
- state_dict = checkpoint
97
-
98
- new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
99
- model.load_state_dict(new_state_dict, strict=False)
100
- model.eval()
101
- print("Model loaded successfully!")
102
- except Exception as e:
103
- print(f"Error loading model: {e}")
104
-
105
-
106
- def to_base64(image_array):
107
- img = Image.fromarray(image_array)
108
- buffer = BytesIO()
109
- img.save(buffer, format="PNG")
110
- return base64.b64encode(buffer.getvalue()).decode('utf-8')
111
-
112
- @app.route('/')
113
- def home():
114
- return "Inpainting API is Running!"
115
-
116
- @app.route('/inpaint', methods=['POST'])
117
- def inpaint():
118
- if 'image' not in request.files or 'mask' not in request.files:
119
- return jsonify({'error': 'Please upload both image and mask'}), 400
120
-
121
- try:
122
- img_file = request.files['image']
123
- img_arr = np.frombuffer(img_file.read(), np.uint8)
124
- img_cv = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
125
- img_cv = cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)
126
-
127
- # 2. Read Mask
128
- mask_file = request.files['mask']
129
- mask_arr = np.frombuffer(mask_file.read(), np.uint8)
130
- mask_cv = cv2.imdecode(mask_arr, cv2.IMREAD_GRAYSCALE)
131
-
132
- # 3. Preprocess
133
- img_h, img_w = img_cv.shape[:2]
134
- # Resize to 512x512 for model
135
- img_resized = cv2.resize(img_cv, (512, 512))
136
- mask_resized = cv2.resize(mask_cv, (512, 512))
137
-
138
- # Normalize
139
- img_tensor = (torch.tensor(img_resized).float() / 127.5) - 1.0
140
- img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
141
-
142
- mask_tensor = (torch.tensor(mask_resized).float() > 127).float().unsqueeze(0).unsqueeze(0).to(DEVICE)
143
-
144
- # 4. Inference
145
- with torch.no_grad():
146
- output = model(img_tensor, mask_tensor)
147
-
148
- # 5. Post-process
149
- output_np = output.squeeze().permute(1, 2, 0).cpu().numpy()
150
- output_np = (output_np + 1.0) * 127.5
151
- output_np = np.clip(output_np, 0, 255).astype(np.uint8)
152
-
153
- # Resize back to original dimensions
154
- output_final = cv2.resize(output_np, (img_w, img_h))
155
-
156
- return jsonify({'result': f"data:image/png;base64,{to_base64(output_final)}"})
157
-
158
- except Exception as e:
159
- return jsonify({'error': str(e)}), 500
160
-
161
- if __name__ == '__main__':
 
 
 
 
 
 
 
 
162
  app.run(host='0.0.0.0', port=7860)
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision.models as models
5
+ from flask import Flask, request, jsonify
6
+ from flask_cors import CORS
7
+ import numpy as np
8
+ import cv2
9
+ import base64
10
+ from io import BytesIO
11
+ from PIL import Image
12
+
13
+ app = Flask(__name__)
14
+ CORS(app)
15
+
16
+ DEVICE = torch.device('cpu') # Force CPU for HF Free Tier
17
+ MODEL_PATH = "best_model_fixed.pth" # Upload your trained .pth file
18
+
19
class InpaintingGenerator(nn.Module):
    """U-Net-style image inpainting generator with a ResNet-34 encoder.

    The network concatenates a 3-channel image with a 1-channel binary mask
    (4 input channels total) and predicts a full RGB image through a final
    Tanh, so outputs lie in [-1, 1].

    NOTE: the attribute names (enc1..enc5, bottleneck, up1..up4,
    texture_refine, final) and the exact module composition of each
    nn.Sequential form the state_dict key contract with the trained
    checkpoint — do not rename or reorder them.
    """

    def __init__(self, input_channels=4):
        super().__init__()
        # Randomly initialised backbone; real weights arrive via the checkpoint.
        backbone = models.resnet34(weights=None)

        # Replacement stem: first conv widened to accept image + mask channels.
        self.enc1 = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False),
            backbone.bn1, backbone.relu
        )
        self.enc2 = backbone.layer1  # 64 ch,  stride 1
        self.enc3 = backbone.layer2  # 128 ch, stride 2
        self.enc4 = backbone.layer3  # 256 ch, stride 2
        self.enc5 = backbone.layer4  # 512 ch, stride 2

        self.bottleneck = nn.Sequential(
            nn.Conv2d(512, 512, 3, padding=1), nn.BatchNorm2d(512), nn.ReLU(True),
            nn.Conv2d(512, 512, 3, padding=1), nn.BatchNorm2d(512), nn.ReLU(True)
        )

        # Decoder: up2..up4 input widths include the concatenated skip features.
        self.up1 = self._make_decoder_block(512, 256)
        self.up2 = self._make_decoder_block(512, 128)  # 256 upsampled + 256 skip
        self.up3 = self._make_decoder_block(256, 64)   # 128 upsampled + 128 skip
        self.up4 = self._make_decoder_block(128, 32)   # 64 upsampled + 64 skip

        self.texture_refine = nn.Sequential(
            nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(True),
            nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(True)
        )

        self.final = nn.Sequential(
            nn.Conv2d(32, 16, 3, padding=1), nn.ReLU(True),
            nn.Conv2d(16, 3, 3, padding=1), nn.Tanh()
        )

    def _make_decoder_block(self, in_channels, out_channels):
        """Build one decoder stage: 2x upsample (transposed conv) + two refinement convs."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels), nn.ReLU(True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels), nn.ReLU(True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels), nn.ReLU(True)
        )

    def forward(self, img, mask):
        """Run the generator.

        Args:
            img:  (B, 3, H, W) image tensor, expected in [-1, 1].
            mask: (B, 1, H, W) mask tensor.

        H and W must be divisible by 16: the encoder halves the resolution
        four times (strided convs) and each decoder stage doubles it, so the
        skip concatenations only align for multiple-of-16 sizes.
        """
        stacked = torch.cat([img, mask], dim=1)

        # Encoder with intermediate features kept for the skip connections.
        e1 = self.enc1(stacked)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)
        e5 = self.enc5(e4)

        out = self.bottleneck(e5)

        # Decoder: upsample, then fuse the matching encoder feature map.
        out = torch.cat([self.up1(out), e4], dim=1)
        out = torch.cat([self.up2(out), e3], dim=1)
        out = torch.cat([self.up3(out), e2], dim=1)
        out = self.up4(out)

        out = self.texture_refine(out)
        return self.final(out)
83
+
84
+
85
print("Loading Inpainting Model...")
model = InpaintingGenerator().to(DEVICE)

try:
    # SECURITY NOTE: weights_only=False runs the full pickle loader, which can
    # execute arbitrary code. Only acceptable because MODEL_PATH is a trusted
    # file bundled with this Space — never point it at untrusted input.
    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)

    # Checkpoints may nest the generator weights under a key, or be a raw state_dict.
    if 'generator' in checkpoint:
        state_dict = checkpoint['generator']
    else:
        state_dict = checkpoint

    # Strip the 'module.' prefix that nn.DataParallel prepends to every key.
    new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
    # strict=False tolerates key mismatches (e.g. the replaced 4-channel stem);
    # it also silently skips genuinely missing weights — watch the logs.
    model.load_state_dict(new_state_dict, strict=False)
    print("Model loaded successfully!")
except Exception as e:
    # Deliberately non-fatal so the container still starts and the health
    # endpoint stays reachable; the model will produce garbage (random
    # weights) until a valid checkpoint is supplied.
    print(f"Error loading model: {e}")

# [FIX] eval() was previously inside the try block, so a failed load left the
# model in train mode (BatchNorm using per-batch statistics at inference).
# Switch to eval mode unconditionally.
model.eval()
104
+
105
+
106
def to_base64(image_array):
    """Encode a uint8 image array as a base64 PNG string (no data-URI prefix)."""
    buf = BytesIO()
    Image.fromarray(image_array).save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode('utf-8')
111
+
112
@app.route('/')
def home():
    """Liveness check: confirms the API process is up and routing works."""
    return "Inpainting API is Running!"
115
+
116
@app.route('/inpaint', methods=['POST'])
def inpaint():
    """Inpaint the masked region of an uploaded image.

    Expects multipart form files:
        image: a colour image (any format cv2.imdecode understands).
        mask:  an image whose non-zero pixels mark the region to fill.

    Returns JSON {'result': <base64 PNG data URI>} on success, or
    {'error': <message>} with status 400 (bad input) / 500 (unexpected).
    """
    if 'image' not in request.files or 'mask' not in request.files:
        return jsonify({'error': 'Please upload both image and mask'}), 400

    try:
        # 1. Decode image — reject undecodable uploads with a clean 400
        # instead of letting cvtColor(None) fall into the 500 handler.
        img_arr = np.frombuffer(request.files['image'].read(), np.uint8)
        img_cv = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
        if img_cv is None:
            return jsonify({'error': 'Could not decode image file'}), 400
        img_cv = cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)

        # 2. Decode mask. IMREAD_UNCHANGED preserves low class-index values
        # (1, 2, 3...) that a lossy grayscale conversion could distort.
        mask_arr = np.frombuffer(request.files['mask'].read(), np.uint8)
        mask_cv = cv2.imdecode(mask_arr, cv2.IMREAD_UNCHANGED)
        if mask_cv is None:
            return jsonify({'error': 'Could not decode mask file'}), 400

        # Collapse colour/alpha masks to one channel. [FIX] a PNG with alpha
        # decodes to 4 channels and COLOR_BGR2GRAY raises on 4-channel input,
        # so BGRA needs its own conversion code.
        if mask_cv.ndim == 3:
            gray_code = cv2.COLOR_BGRA2GRAY if mask_cv.shape[2] == 4 else cv2.COLOR_BGR2GRAY
            mask_cv = cv2.cvtColor(mask_cv, gray_code)

        # 3. Preprocess — the model runs at a fixed 512x512 resolution.
        img_h, img_w = img_cv.shape[:2]
        img_resized = cv2.resize(img_cv, (512, 512))
        # Nearest neighbour preserves exact class IDs (no interpolated values).
        mask_resized = cv2.resize(mask_cv, (512, 512), interpolation=cv2.INTER_NEAREST)

        # Map uint8 [0, 255] -> float [-1, 1] to match the generator's Tanh range.
        img_tensor = (torch.tensor(img_resized).float() / 127.5) - 1.0
        img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0).to(DEVICE)

        # Any non-zero pixel (class index 1, 2, ...) becomes 1.0 in the binary mask.
        mask_tensor = (torch.tensor(mask_resized).float() > 0).float().unsqueeze(0).unsqueeze(0).to(DEVICE)

        # 4. Inference
        with torch.no_grad():
            output = model(img_tensor, mask_tensor)

        # 5. Post-process: [-1, 1] -> uint8 RGB, back at the original resolution.
        output_np = output.squeeze().permute(1, 2, 0).cpu().numpy()
        output_np = np.clip((output_np + 1.0) * 127.5, 0, 255).astype(np.uint8)
        output_final = cv2.resize(output_np, (img_w, img_h))

        return jsonify({'result': f"data:image/png;base64,{to_base64(output_final)}"})

    except Exception as e:
        # Catch-all keeps a single bad request from crashing the worker; the
        # message is surfaced to the client to aid debugging.
        return jsonify({'error': str(e)}), 500
168
+
169
if __name__ == '__main__':
    # 0.0.0.0:7860 is the conventional bind address/port for Hugging Face Spaces.
    app.run(host='0.0.0.0', port=7860)