Spaces:
Sleeping
Sleeping
| # import os | |
| # import torch | |
| # import torchvision.transforms as T | |
| # import torchvision.transforms.functional as TF | |
| # import numpy as np | |
| # from PIL import Image | |
| # from flask import Flask, render_template, request, send_file, abort | |
| # app = Flask(__name__) | |
| # device = "cuda" if torch.cuda.is_available() else "cpu" | |
| # # Load model (assuming UNet is defined in unet.py) | |
| # def load_model(): | |
| # try: | |
| # from unet import UNet | |
| # model = UNet().to(device) | |
| # model_path = "unet_car_final.pth" | |
| # if not os.path.exists(model_path): | |
| # raise FileNotFoundError(f"Model file {model_path} not found") | |
| # model.load_state_dict(torch.load(model_path, map_location=device)) | |
| # model.eval() | |
| # return model | |
| # except Exception as e: | |
| # print(f"Error loading model: {e}") | |
| # raise | |
| # try: | |
| # model = load_model() | |
| # except Exception as e: | |
| # print(f"Model loading failed: {e}") | |
| # model = None | |
| # # Image transforms | |
| # img_transform = T.Compose([ | |
| # T.Resize((256, 256)), | |
| # T.ToTensor(), | |
| # T.Normalize(mean=[0.485, 0.456, 0.406], | |
| # std=[0.229, 0.224, 0.225]) | |
| # ]) | |
| # TMP_FOLDER = "/tmp" | |
| # os.makedirs(TMP_FOLDER, exist_ok=True) | |
| # # Route to serve files from /tmp | |
| # @app.route('/tmp/<filename>') | |
| # def serve_tmp_file(filename): | |
| # file_path = os.path.join(TMP_FOLDER, filename) | |
| # if os.path.exists(file_path): | |
| # return send_file(file_path) | |
| # else: | |
| # print(f"File not found: {file_path}") | |
| # abort(404) | |
| # @app.route("/", methods=["GET", "POST"]) | |
| # def index(): | |
| # orig = None | |
| # mask = None | |
| # overlay = None | |
| # error = None | |
| # # Check for existing input image | |
| # img_path = os.path.join(TMP_FOLDER, "input.jpg") | |
| # if os.path.exists(img_path): | |
| # orig = "/tmp/input.jpg" | |
| # print(f"Found existing image: {img_path}") | |
| # if request.method == "POST": | |
| # # Handle image upload | |
| # if "image" in request.files: | |
| # file = request.files["image"] | |
| # if file.filename == "": | |
| # error = "No file selected" | |
| # print(error) | |
| # return render_template("index.html", error=error, orig=orig, mask=mask, overlay=overlay) | |
| # try: | |
| # # Save uploaded image to /tmp | |
| # file.save(img_path) | |
| # print(f"Image saved to: {img_path}") | |
| # orig = "/tmp/input.jpg" | |
| # # Clear previous results in /tmp | |
| # for path in [os.path.join(TMP_FOLDER, "mask.png"), os.path.join(TMP_FOLDER, "overlay.png")]: | |
| # if os.path.exists(path): | |
| # os.remove(path) | |
| # print(f"Removed: {path}") | |
| # except Exception as e: | |
| # error = f"Error uploading image: {str(e)}" | |
| # print(f"Upload error: {e}") | |
| # return render_template("index.html", error=error, orig=orig, mask=mask, overlay=overlay) | |
| # # Handle segmentation | |
| # if "segment" in request.form: | |
| # if not os.path.exists(img_path): | |
| # error = "No image available for segmentation" | |
| # print(f"Segmentation error: Image not found at {img_path}") | |
| # return render_template("index.html", error=error, orig=orig, mask=mask, overlay=overlay) | |
| # try: | |
| # if model is None: | |
| # raise ValueError("Model not loaded") | |
| # image = Image.open(img_path).convert("RGB") | |
| # input_tensor = img_transform(image).unsqueeze(0).to(device) | |
| # # Predict | |
| # with torch.no_grad(): | |
| # output = model(input_tensor) | |
| # pred_mask = torch.sigmoid(output) | |
| # pred_mask = (pred_mask > 0.5).float() | |
| # # Resize mask back to original image size | |
| # mask_resized = TF.resize( | |
| # TF.to_pil_image(pred_mask.squeeze().cpu()), | |
| # size=image.size[::-1], | |
| # interpolation=Image.NEAREST | |
| # ) | |
| # # Save mask to /tmp | |
| # mask_path = os.path.join(TMP_FOLDER, "mask.png") | |
| # mask_resized.save(mask_path) | |
| # print(f"Mask saved to: {mask_path}") | |
| # # Create overlay | |
| # mask_np = np.array(mask_resized) | |
| # overlay = np.array(image).copy() | |
| # overlay[mask_np > 128] = [255, 0, 0] | |
| # overlay_img = Image.fromarray(overlay) | |
| # overlay_path = os.path.join(TMP_FOLDER, "overlay.png") | |
| # overlay_img.save(overlay_path) | |
| # print(f"Overlay saved to: {overlay_path}") | |
| # mask = "/tmp/mask.png" | |
| # overlay = "/tmp/overlay.png" | |
| # except Exception as e: | |
| # error = f"Error during segmentation: {str(e)}" | |
| # print(f"Segmentation error: {e}") | |
| # return render_template("index.html", error=error, orig=orig, mask=mask, overlay=overlay) | |
| # return render_template("index.html", orig=orig, mask=mask, overlay=overlay, error=error) | |
| # if __name__ == "__main__": | |
| # app.run(debug=True) | |
| import os | |
| import torch | |
| import torchvision.transforms as T | |
| import torchvision.transforms.functional as TF | |
| import numpy as np | |
| from PIL import Image | |
| from flask import Flask, render_template, request, send_file, abort | |
app = Flask(__name__)

# Inference device: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Load model (assuming UNet is defined in unet.py)
def load_model(model_path="unet_car_final.pth"):
    """Build the UNet, load its weights and return it ready for inference.

    The network is placed on the module-level ``device`` and switched to
    eval mode.

    Args:
        model_path: Path to a ``state_dict`` checkpoint. Defaults to
            ``unet_car_final.pth``, the file the app has always loaded.

    Returns:
        The ``UNet`` instance in eval mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        Any error from importing ``unet`` or reading the checkpoint is
        printed and re-raised unchanged.
    """
    try:
        # Check before building the network so a missing checkpoint fails
        # fast, without allocating the model on the (possibly GPU) device.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file {model_path} not found")
        from unet import UNet

        model = UNet().to(device)
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. On torch >= 1.13, weights_only=True restricts loading
        # to tensors and is sufficient for a state_dict.
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)
        model.eval()
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
# Load the network once at import time. On failure the error is printed and
# `model` stays None; the request handler checks for None and reports
# "Model not loaded" instead of crashing the app at startup.
model = None
try:
    model = load_model()
except Exception as exc:
    print(f"Model loading failed: {exc}")
# Inference preprocessing:
#   1. resize to a fixed 256x256,
#   2. convert the PIL image to a float tensor,
#   3. normalise each channel with the commonly used ImageNet mean/std.
img_transform = T.Compose(
    [
        T.Resize((256, 256)),
        T.ToTensor(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Scratch directory for the uploaded input image and the generated
# mask/overlay images. It is created up front so later saves cannot fail on a
# missing directory.
TMP_FOLDER = "/tmp"
os.makedirs(name=TMP_FOLDER, exist_ok=True)
# Route to serve files from /tmp.
# Only the images this app itself writes are exposed. /tmp is shared with
# every other process on the host, so serving arbitrary names from it would
# leak unrelated files.
_SERVABLE_TMP_FILES = frozenset({"input.jpg", "mask.png", "overlay.png"})


@app.route("/tmp/<filename>")
def serve_tmp_file(filename):
    """Send ``TMP_FOLDER/<filename>`` if it is one of the app's images.

    Args:
        filename: Last path segment of the request URL (untrusted input).

    Returns:
        The file response. Aborts with 404 for names outside the allowlist
        or when the file is not currently present.
    """
    if filename not in _SERVABLE_TMP_FILES:
        print(f"Refusing to serve non-app file: {filename}")
        abort(404)
    file_path = os.path.join(TMP_FOLDER, filename)
    # isfile rather than exists: a directory must never reach send_file.
    if not os.path.isfile(file_path):
        print(f"File not found: {file_path}")
        abort(404)
    return send_file(file_path)
def _clear_tmp_files(names):
    """Delete each of ``names`` from TMP_FOLDER, best-effort.

    Missing files are skipped. Any other OSError is printed and the loop
    continues, so a stale file never aborts the request.
    """
    for name in names:
        path = os.path.join(TMP_FOLDER, name)
        try:
            os.remove(path)
        except FileNotFoundError:
            continue
        except OSError as e:
            print(f"Error clearing file {path}: {e}")
        else:
            print(f"Cleared file: {path}")


def _render_index(orig, mask, overlay, error):
    """Render index.html with the current page state."""
    return render_template(
        "index.html", orig=orig, mask=mask, overlay=overlay, error=error
    )


def _segment_image(img_path):
    """Segment the image at ``img_path``; write mask.png and overlay.png.

    Both outputs are saved in TMP_FOLDER at the original image resolution.

    Returns:
        ``(mask_url, overlay_url)``: the URLs the page uses to show the results.

    Raises:
        ValueError: If the model failed to load at startup.
        Any PIL/torch error raised while reading the image or running the
        model propagates to the caller.
    """
    if model is None:
        raise ValueError("Model not loaded")

    # `with` closes the file handle; convert() returns an independent image.
    with Image.open(img_path) as src:
        image = src.convert("RGB")

    input_tensor = img_transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)
    # Threshold the sigmoid probabilities at 0.5 -> a {0.0, 1.0} mask.
    pred_mask = (torch.sigmoid(output) > 0.5).float()

    # NOTE(review): assumes the model emits a single logit channel, so
    # squeeze() leaves an (H, W) map -- confirm against unet.py.
    # to_pil_image scales the float mask to {0, 255}. NEAREST keeps it binary
    # while resizing back to the original (height, width).
    mask_img = TF.resize(
        TF.to_pil_image(pred_mask.squeeze().cpu()),
        size=[image.height, image.width],
        interpolation=T.InterpolationMode.NEAREST,
    )
    mask_path = os.path.join(TMP_FOLDER, "mask.png")
    mask_img.save(mask_path)
    print(f"Mask saved to: {mask_path}")

    # Paint foreground pixels pure red on a copy of the input. A dedicated
    # name keeps the array from leaking into the template's `overlay` URL if
    # saving fails.
    overlay_arr = np.array(image)
    overlay_arr[np.array(mask_img) > 128] = [255, 0, 0]
    overlay_path = os.path.join(TMP_FOLDER, "overlay.png")
    Image.fromarray(overlay_arr).save(overlay_path)
    print(f"Overlay saved to: {overlay_path}")

    return "/tmp/mask.png", "/tmp/overlay.png"


@app.route("/", methods=["GET", "POST"])
def index():
    """Serve the single-page upload / segmentation UI.

    * GET deletes the scratch images, so every visit starts from an empty page.
    * POST with an ``image`` file saves it as TMP_FOLDER/input.jpg and drops
      any previous mask/overlay.
    * POST with a ``segment`` form field runs the model on the saved image.
      An upload in the same request is handled first.

    Every outcome renders ``index.html`` with ``orig``, ``mask`` and
    ``overlay`` (URL strings or None) and ``error`` (message or None).
    """
    orig = mask = overlay = error = None
    img_path = os.path.join(TMP_FOLDER, "input.jpg")

    if request.method == "GET":
        # NOTE(review): these paths are fixed and shared by all clients, so
        # one visitor's page load clears another visitor's images.
        _clear_tmp_files(["input.jpg", "mask.png", "overlay.png"])

    if os.path.exists(img_path):
        orig = "/tmp/input.jpg"
        print(f"Found existing image: {img_path}")

    if request.method == "POST":
        # Handle image upload.
        if "image" in request.files:
            file = request.files["image"]
            if file.filename == "":
                error = "No file selected"
                print(error)
                return _render_index(orig, mask, overlay, error)
            try:
                # NOTE(review): the upload is not checked to be an image. A bad
                # file only surfaces later as a segmentation error.
                file.save(img_path)
            except Exception as e:
                error = f"Error uploading image: {str(e)}"
                print(f"Upload error: {e}")
                return _render_index(orig, mask, overlay, error)
            print(f"Image saved to: {img_path}")
            orig = "/tmp/input.jpg"
            # A new input makes any earlier segmentation stale.
            _clear_tmp_files(["mask.png", "overlay.png"])

        # Handle segmentation.
        if "segment" in request.form:
            if not os.path.exists(img_path):
                error = "No image available for segmentation"
                print(f"Segmentation error: Image not found at {img_path}")
                return _render_index(orig, mask, overlay, error)
            try:
                mask, overlay = _segment_image(img_path)
            except Exception as e:
                error = f"Error during segmentation: {str(e)}"
                print(f"Segmentation error: {e}")
                return _render_index(orig, mask, overlay, error)

    return _render_index(orig, mask, overlay, error)
if __name__ == "__main__":
    # Settings can be overridden from the environment. With nothing set this
    # is equivalent to the original app.run(debug=True) on Flask's default
    # host and port.
    # NOTE: debug mode enables the Werkzeug interactive debugger, which runs
    # arbitrary code for anyone who can reach it. Set FLASK_DEBUG=0 before
    # exposing the server beyond localhost.
    _debug = os.environ.get("FLASK_DEBUG", "1").strip().lower() not in ("0", "false", "no")
    _port = os.environ.get("PORT")
    app.run(
        host=os.environ.get("HOST"),
        port=int(_port) if _port else None,
        debug=_debug,
    )