File size: 10,956 Bytes
3cdac71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a57ce4a
 
 
 
 
 
570027d
a57ce4a
 
 
 
 
 
 
570027d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a57ce4a
 
 
 
 
 
 
 
 
6a3c00f
 
a57ce4a
570027d
 
 
 
 
 
 
01af49e
570027d
 
a57ce4a
 
6a3c00f
 
 
570027d
6a3c00f
3cdac71
 
 
 
 
 
 
 
 
 
 
 
01af49e
 
 
 
 
a57ce4a
 
 
 
 
01af49e
 
 
a57ce4a
570027d
 
 
01af49e
 
6a3c00f
570027d
 
 
 
01af49e
570027d
 
01af49e
 
a57ce4a
 
 
 
01af49e
 
 
570027d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01af49e
570027d
 
 
 
 
 
 
 
01af49e
570027d
 
 
 
 
01af49e
 
570027d
 
a57ce4a
 
05fba58
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
# import os
# import torch
# import torchvision.transforms as T
# import torchvision.transforms.functional as TF
# import numpy as np
# from PIL import Image
# from flask import Flask, render_template, request, send_file, abort

# app = Flask(__name__)

# device = "cuda" if torch.cuda.is_available() else "cpu"

# # Load model (assuming UNet is defined in unet.py)
# def load_model():
#     try:
#         from unet import UNet
#         model = UNet().to(device)
#         model_path = "unet_car_final.pth"
#         if not os.path.exists(model_path):
#             raise FileNotFoundError(f"Model file {model_path} not found")
#         model.load_state_dict(torch.load(model_path, map_location=device))
#         model.eval()
#         return model
#     except Exception as e:
#         print(f"Error loading model: {e}")
#         raise

# try:
#     model = load_model()
# except Exception as e:
#     print(f"Model loading failed: {e}")
#     model = None

# # Image transforms
# img_transform = T.Compose([
#     T.Resize((256, 256)),
#     T.ToTensor(),
#     T.Normalize(mean=[0.485, 0.456, 0.406],
#                 std=[0.229, 0.224, 0.225])
# ])

# TMP_FOLDER = "/tmp"
# os.makedirs(TMP_FOLDER, exist_ok=True)

# # Route to serve files from /tmp
# @app.route('/tmp/<filename>')
# def serve_tmp_file(filename):
#     file_path = os.path.join(TMP_FOLDER, filename)
#     if os.path.exists(file_path):
#         return send_file(file_path)
#     else:
#         print(f"File not found: {file_path}")
#         abort(404)

# @app.route("/", methods=["GET", "POST"])
# def index():
#     orig = None
#     mask = None
#     overlay = None
#     error = None
    
#     # Check for existing input image
#     img_path = os.path.join(TMP_FOLDER, "input.jpg")
#     if os.path.exists(img_path):
#         orig = "/tmp/input.jpg"
#         print(f"Found existing image: {img_path}")

#     if request.method == "POST":
#         # Handle image upload
#         if "image" in request.files:
#             file = request.files["image"]
#             if file.filename == "":
#                 error = "No file selected"
#                 print(error)
#                 return render_template("index.html", error=error, orig=orig, mask=mask, overlay=overlay)

#             try:
#                 # Save uploaded image to /tmp
#                 file.save(img_path)
#                 print(f"Image saved to: {img_path}")
#                 orig = "/tmp/input.jpg"

#                 # Clear previous results in /tmp
#                 for path in [os.path.join(TMP_FOLDER, "mask.png"), os.path.join(TMP_FOLDER, "overlay.png")]:
#                     if os.path.exists(path):
#                         os.remove(path)
#                         print(f"Removed: {path}")
#             except Exception as e:
#                 error = f"Error uploading image: {str(e)}"
#                 print(f"Upload error: {e}")
#                 return render_template("index.html", error=error, orig=orig, mask=mask, overlay=overlay)

#         # Handle segmentation
#         if "segment" in request.form:
#             if not os.path.exists(img_path):
#                 error = "No image available for segmentation"
#                 print(f"Segmentation error: Image not found at {img_path}")
#                 return render_template("index.html", error=error, orig=orig, mask=mask, overlay=overlay)

#             try:
#                 if model is None:
#                     raise ValueError("Model not loaded")
                
#                 image = Image.open(img_path).convert("RGB")
#                 input_tensor = img_transform(image).unsqueeze(0).to(device)

#                 # Predict
#                 with torch.no_grad():
#                     output = model(input_tensor)
#                     pred_mask = torch.sigmoid(output)
#                     pred_mask = (pred_mask > 0.5).float()

#                 # Resize mask back to original image size
#                 mask_resized = TF.resize(
#                     TF.to_pil_image(pred_mask.squeeze().cpu()),
#                     size=image.size[::-1],
#                     interpolation=Image.NEAREST
#                 )

#                 # Save mask to /tmp
#                 mask_path = os.path.join(TMP_FOLDER, "mask.png")
#                 mask_resized.save(mask_path)
#                 print(f"Mask saved to: {mask_path}")

#                 # Create overlay
#                 mask_np = np.array(mask_resized)
#                 overlay = np.array(image).copy()
#                 overlay[mask_np > 128] = [255, 0, 0]
#                 overlay_img = Image.fromarray(overlay)
#                 overlay_path = os.path.join(TMP_FOLDER, "overlay.png")
#                 overlay_img.save(overlay_path)
#                 print(f"Overlay saved to: {overlay_path}")

#                 mask = "/tmp/mask.png"
#                 overlay = "/tmp/overlay.png"
#             except Exception as e:
#                 error = f"Error during segmentation: {str(e)}"
#                 print(f"Segmentation error: {e}")
#                 return render_template("index.html", error=error, orig=orig, mask=mask, overlay=overlay)

#     return render_template("index.html", orig=orig, mask=mask, overlay=overlay, error=error)

# if __name__ == "__main__":
#     app.run(debug=True)



import os
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import numpy as np
from PIL import Image
from flask import Flask, render_template, request, send_file, abort

# Flask application object; the routes below are registered on it.
app = Flask(__name__)

# Inference device: the default CUDA GPU when PyTorch can see one, else the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Load model (assuming UNet is defined in unet.py)
def load_model(model_path="unet_car_final.pth"):
    """Build the UNet, load its trained weights and switch it to eval mode.

    Args:
        model_path: Path to the saved ``state_dict`` checkpoint. Relative
            paths are resolved against the current working directory.

    Returns:
        The UNet instance, moved to ``device`` and in evaluation mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        Any exception from importing ``unet`` or loading the checkpoint is
        printed and re-raised unchanged.
    """
    try:
        # Check the checkpoint before instantiating the network so a missing
        # file fails fast without allocating the model on the device.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file {model_path} not found")
        from unet import UNet
        model = UNet().to(device)
        # NOTE(review): torch.load unpickles the checkpoint, which can run
        # arbitrary code; only load trusted files (consider weights_only=True
        # on torch versions that support it).
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        raise

# Load the network once at import time. A failure is reported but not fatal:
# `model` stays None so request handlers can detect the missing model.
model = None
try:
    model = load_model()
except Exception as exc:
    print(f"Model loading failed: {exc}")

# Pre-processing applied to every image before it is fed to the network.
img_transform = T.Compose(
    [
        T.Resize((256, 256)),  # fixed input resolution for the network
        T.ToTensor(),  # PIL image -> float tensor scaled to [0, 1]
        # Standard ImageNet channel mean/std -- presumably what the model was
        # trained with; confirm against the training script.
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Directory holding the uploaded image and the generated mask/overlay files.
TMP_FOLDER = "/tmp"
os.makedirs(TMP_FOLDER, exist_ok=True)

# Route to serve files from /tmp
@app.route('/tmp/<filename>')
def serve_tmp_file(filename):
    """Serve one of this app's own image files from ``TMP_FOLDER``.

    Only the files the app itself writes (the uploaded ``input.jpg`` and the
    generated ``mask.png`` / ``overlay.png``) are served. Any other name,
    including unrelated files that live in the shared temp directory or
    names such as ``..``, gets a 404. The route therefore cannot be used to
    read arbitrary files on the server.

    Args:
        filename: Final path segment of the request URL.

    Returns:
        The response produced by ``send_file`` for the requested file.

    Aborts with HTTP 404 when the name is not served or the file is missing.
    """
    # Allowlist instead of serving whatever exists: /tmp is system-wide and
    # may contain files belonging to other users or processes.
    served_files = {"input.jpg", "mask.png", "overlay.png"}
    file_path = os.path.join(TMP_FOLDER, filename)
    # isfile (not exists) so a directory can never reach send_file.
    if filename not in served_files or not os.path.isfile(file_path):
        print(f"File not found: {file_path}")
        abort(404)
    return send_file(file_path)

def _clear_tmp_artifacts():
    """Best-effort removal of the input image and previous results from TMP_FOLDER."""
    for filename in ["input.jpg", "mask.png", "overlay.png"]:
        file_path = os.path.join(TMP_FOLDER, filename)
        if os.path.exists(file_path):
            try:
                os.remove(file_path)
                print(f"Cleared file: {file_path}")
            except OSError as e:
                # Deliberately non-fatal: the page still renders.
                print(f"Error clearing file {file_path}: {e}")


def _segment_image(img_path):
    """Run the model on the image at *img_path* and write mask/overlay PNGs.

    Args:
        img_path: Path of the image to segment.

    Returns:
        ``(mask_url, overlay_url)``: the URLs of the two written files.

    Raises:
        ValueError: If the model failed to load at startup.
        Any error from opening the image, running inference or saving the
        outputs propagates to the caller.
    """
    if model is None:
        raise ValueError("Model not loaded")

    # The context manager closes the file handle; convert() returns a new,
    # fully loaded image that remains usable after the file is closed.
    with Image.open(img_path) as src:
        image = src.convert("RGB")
    input_tensor = img_transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_tensor)
        # Sigmoid followed by a 0.5 threshold gives a {0., 1.} mask.
        # NOTE(review): assumes a single-channel output so squeeze() below
        # leaves an (H, W) tensor -- confirm against unet.py.
        pred_mask = (torch.sigmoid(output) > 0.5).float()

    # to_pil_image scales the float mask to an 8-bit {0, 255} image.
    # Nearest-neighbour resizing back to (height, width) keeps it binary.
    mask_resized = TF.resize(
        TF.to_pil_image(pred_mask.squeeze().cpu()),
        size=image.size[::-1],
        interpolation=TF.InterpolationMode.NEAREST,
    )

    mask_path = os.path.join(TMP_FOLDER, "mask.png")
    mask_resized.save(mask_path)
    print(f"Mask saved to: {mask_path}")

    # Paint masked pixels pure red on a copy of the original image. This uses
    # its own name so a failure here can never leak an array into the
    # template's `overlay` URL variable.
    overlay_arr = np.array(image)
    overlay_arr[np.array(mask_resized) > 128] = [255, 0, 0]
    overlay_path = os.path.join(TMP_FOLDER, "overlay.png")
    Image.fromarray(overlay_arr).save(overlay_path)
    print(f"Overlay saved to: {overlay_path}")

    return "/tmp/mask.png", "/tmp/overlay.png"


@app.route("/", methods=["GET", "POST"])
def index():
    """Render the upload/segmentation page and handle its form actions.

    GET clears any previous input/mask/overlay from ``TMP_FOLDER``, so each
    visit starts with an empty page.

    POST handles the form actions:

    * If the request carries an ``image`` file, it is saved as ``input.jpg``
      and any stale mask/overlay is removed.
    * If the form carries ``segment``, the saved image is run through the
      model and ``mask.png`` / ``overlay.png`` are written.

    NOTE(review): artifacts use fixed paths shared by every client, so
    concurrent users overwrite each other's files.

    Returns:
        ``index.html`` rendered with ``orig``, ``mask`` and ``overlay`` (URL
        strings or None) and ``error`` (message string or None).
    """
    orig = None
    mask = None
    overlay = None
    error = None

    if request.method == "GET":
        # Each fresh visit starts from a clean slate.
        _clear_tmp_artifacts()

    # On GET the input was normally just cleared, so this usually matches
    # only on POST (an image uploaded by an earlier request).
    img_path = os.path.join(TMP_FOLDER, "input.jpg")
    if os.path.exists(img_path):
        orig = "/tmp/input.jpg"
        print(f"Found existing image: {img_path}")

    if request.method == "POST":
        # Handle image upload
        if "image" in request.files:
            file = request.files["image"]
            if file.filename == "":
                error = "No file selected"
                print(error)
                return render_template("index.html", error=error, orig=orig, mask=mask, overlay=overlay)

            try:
                file.save(img_path)
                print(f"Image saved to: {img_path}")
                orig = "/tmp/input.jpg"

                # Results from a previous image no longer apply.
                for path in [os.path.join(TMP_FOLDER, "mask.png"), os.path.join(TMP_FOLDER, "overlay.png")]:
                    if os.path.exists(path):
                        os.remove(path)
                        print(f"Removed: {path}")
            except Exception as e:
                error = f"Error uploading image: {str(e)}"
                print(f"Upload error: {e}")
                return render_template("index.html", error=error, orig=orig, mask=mask, overlay=overlay)

        # Handle segmentation
        if "segment" in request.form:
            if not os.path.exists(img_path):
                error = "No image available for segmentation"
                print(f"Segmentation error: Image not found at {img_path}")
                return render_template("index.html", error=error, orig=orig, mask=mask, overlay=overlay)

            try:
                # Tuple assignment happens only on success, so on failure
                # mask/overlay are still None when the template is rendered.
                mask, overlay = _segment_image(img_path)
            except Exception as e:
                error = f"Error during segmentation: {str(e)}"
                print(f"Segmentation error: {e}")
                return render_template("index.html", error=error, orig=orig, mask=mask, overlay=overlay)

    return render_template("index.html", orig=orig, mask=mask, overlay=overlay, error=error)

if __name__ == "__main__":
    # Debug mode enables the auto-reloader and the interactive Werkzeug
    # debugger. It stays on by default for local development; set
    # FLASK_DEBUG=0 (or false/no) to disable it. Never enable it on a server
    # reachable by untrusted clients.
    debug_enabled = os.environ.get("FLASK_DEBUG", "1").strip().lower() not in {"0", "false", "no"}
    app.run(debug=debug_enabled)