"""Flask front-end for UNet car segmentation.

Workflow: a POST with an ``image`` file stores it as ``/tmp/input.jpg``; a
POST with a ``segment`` form field runs the UNet on that file and writes a
binary mask (``/tmp/mask.png``) and a red overlay (``/tmp/overlay.png``).
All three artifacts are served back through the ``/tmp/<filename>`` route.

NOTE(review): the artifact names are fixed and shared by every client, so
concurrent users overwrite -- and, on GET, delete -- each other's files.
Per-session file names would be needed before serving multiple users.
"""

import os

import numpy as np
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from flask import Flask, abort, render_template, request, send_file
from PIL import Image

app = Flask(__name__)
device = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_PATH = "unet_car_final.pth"
TMP_FOLDER = "/tmp"
INPUT_NAME = "input.jpg"
MASK_NAME = "mask.png"
OVERLAY_NAME = "overlay.png"
# The only files the /tmp route will serve. /tmp is shared with other
# processes, so exposing arbitrary names there would leak their files.
SERVABLE_FILES = frozenset({INPUT_NAME, MASK_NAME, OVERLAY_NAME})


def load_model():
    """Build the UNet defined in ``unet.py`` and load its trained weights.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        FileNotFoundError: if ``MODEL_PATH`` does not exist.
        Exception: anything raised while importing ``unet`` or loading the
            weights is printed and re-raised unchanged.
    """
    try:
        from unet import UNet  # project-local module

        net = UNet().to(device)
        if not os.path.exists(MODEL_PATH):
            raise FileNotFoundError(f"Model file {MODEL_PATH} not found")
        # torch.load unpickles the checkpoint: only load files you trust.
        net.load_state_dict(torch.load(MODEL_PATH, map_location=device))
        net.eval()
        return net
    except Exception as e:
        print(f"Error loading model: {e}")
        raise


try:
    model = load_model()
except Exception as e:
    # Keep the UI usable without a model; segmentation requests report it.
    print(f"Model loading failed: {e}")
    model = None

# Model preprocessing: fixed 256x256 input, normalized with the
# ImageNet channel mean/std.
img_transform = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
])

os.makedirs(TMP_FOLDER, exist_ok=True)


def _tmp_path(name):
    """Return the filesystem path of artifact *name* inside TMP_FOLDER."""
    return os.path.join(TMP_FOLDER, name)


def _tmp_url(name):
    """Return the URL at which serve_tmp_file exposes artifact *name*."""
    return f"/tmp/{name}"


def _render(orig=None, mask=None, overlay=None, error=None):
    """Render index.html with the artifact URLs (or None) and an error message."""
    return render_template("index.html", error=error, orig=orig,
                           mask=mask, overlay=overlay)


@app.route("/tmp/<filename>")
def serve_tmp_file(filename):
    """Serve one generated artifact (input, mask or overlay) from TMP_FOLDER.

    Names outside SERVABLE_FILES, and artifacts that do not exist yet, get a
    404. The allowlist keeps clients from reading other files in /tmp.
    """
    if filename not in SERVABLE_FILES:
        print(f"Refusing to serve: {filename}")
        abort(404)
    file_path = _tmp_path(filename)
    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        abort(404)
    return send_file(file_path)


def _clear_artifacts():
    """Best-effort delete of the input, mask and overlay files."""
    for name in (INPUT_NAME, MASK_NAME, OVERLAY_NAME):
        file_path = _tmp_path(name)
        if os.path.exists(file_path):
            try:
                os.remove(file_path)
                print(f"Cleared file: {file_path}")
            except OSError as e:
                # Deliberately non-fatal: the page still renders.
                print(f"Error clearing file {file_path}: {e}")


def _save_upload(file, img_path):
    """Validate an uploaded image, store it at *img_path*, drop stale results.

    Args:
        file: the uploaded ``werkzeug`` FileStorage from ``request.files``.
        img_path: destination path for the stored input image.

    Returns:
        None on success, otherwise a user-facing error message. On a
        rejected upload the previously stored input (if any) is kept.
    """
    if not file.filename:
        error = "No file selected"
        print(error)
        return error

    try:
        # Reject non-images up front rather than overwriting a good input
        # and failing later during segmentation. PIL does not close a
        # caller-supplied stream, so the upload can be rewound and saved.
        with Image.open(file.stream) as probe:
            probe.verify()
        file.stream.seek(0)
    except Exception as e:  # PIL raises assorted types for bad input
        print(f"Upload error: {e}")
        return "Uploaded file is not a valid image"

    try:
        file.save(img_path)
        print(f"Image saved to: {img_path}")
        # A new input invalidates any previous mask/overlay.
        for name in (MASK_NAME, OVERLAY_NAME):
            path = _tmp_path(name)
            if os.path.exists(path):
                os.remove(path)
                print(f"Removed: {path}")
    except OSError as e:
        print(f"Upload error: {e}")
        return f"Error uploading image: {str(e)}"
    return None


def _segment(img_path):
    """Run the model on *img_path* and write mask.png and overlay.png.

    Returns:
        ``(mask_url, overlay_url)`` for the two written artifacts.

    Raises:
        ValueError: if the model failed to load at startup.
        Exception: any error from PIL/torch while reading, predicting or
            saving propagates to the caller.
    """
    if model is None:
        raise ValueError("Model not loaded")

    with Image.open(img_path) as src:
        image = src.convert("RGB")
    input_tensor = img_transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(input_tensor)
    # Threshold the sigmoid probabilities at 0.5 -> tensor of 0.0 / 1.0.
    binary = (torch.sigmoid(logits) > 0.5).float()

    # to_pil_image scales float 0/1 to uint8 0/255. Resize back to the
    # original (width, height) with NEAREST so the mask stays binary.
    mask_img = TF.to_pil_image(binary.squeeze().cpu()).resize(
        image.size, Image.NEAREST
    )
    mask_path = _tmp_path(MASK_NAME)
    mask_img.save(mask_path)
    print(f"Mask saved to: {mask_path}")

    # Paint masked pixels pure red on a copy of the original RGB image.
    overlay_arr = np.array(image)
    overlay_arr[np.array(mask_img) > 128] = [255, 0, 0]
    overlay_path = _tmp_path(OVERLAY_NAME)
    Image.fromarray(overlay_arr).save(overlay_path)
    print(f"Overlay saved to: {overlay_path}")

    return _tmp_url(MASK_NAME), _tmp_url(OVERLAY_NAME)


@app.route("/", methods=["GET", "POST"])
def index():
    """Main page.

    GET clears any previous input/mask/overlay. POST may carry an ``image``
    file upload and/or a ``segment`` form field. When both are present the
    upload is handled first, so one request can upload and segment. Errors
    are shown on the page through the template's ``error`` variable.
    """
    img_path = _tmp_path(INPUT_NAME)

    if request.method == "GET":
        _clear_artifacts()

    orig = None
    if os.path.exists(img_path):
        orig = _tmp_url(INPUT_NAME)
        print(f"Found existing image: {img_path}")

    if request.method != "POST":
        return _render(orig=orig)

    if "image" in request.files:
        error = _save_upload(request.files["image"], img_path)
        # Show whatever input is on disk now: the new upload on success,
        # the previous one (if any) when the upload was rejected.
        orig = _tmp_url(INPUT_NAME) if os.path.exists(img_path) else None
        if error is not None:
            return _render(orig=orig, error=error)

    if "segment" not in request.form:
        return _render(orig=orig)

    if not os.path.exists(img_path):
        print(f"Segmentation error: Image not found at {img_path}")
        return _render(orig=orig, error="No image available for segmentation")

    try:
        mask, overlay = _segment(img_path)
    except Exception as e:  # request boundary: report any failure to the user
        print(f"Segmentation error: {e}")
        return _render(orig=orig, error=f"Error during segmentation: {str(e)}")

    return _render(orig=orig, mask=mask, overlay=overlay)


if __name__ == "__main__":
    # Development server only: debug=True enables the Werkzeug debugger,
    # which permits arbitrary code execution if reachable by other hosts.
    app.run(debug=True)