musk12's picture
Update app.py
3cdac71 verified
# import os
# import torch
# import torchvision.transforms as T
# import torchvision.transforms.functional as TF
# import numpy as np
# from PIL import Image
# from flask import Flask, render_template, request, send_file, abort
# app = Flask(__name__)
# device = "cuda" if torch.cuda.is_available() else "cpu"
# # Load model (assuming UNet is defined in unet.py)
# def load_model():
# try:
# from unet import UNet
# model = UNet().to(device)
# model_path = "unet_car_final.pth"
# if not os.path.exists(model_path):
# raise FileNotFoundError(f"Model file {model_path} not found")
# model.load_state_dict(torch.load(model_path, map_location=device))
# model.eval()
# return model
# except Exception as e:
# print(f"Error loading model: {e}")
# raise
# try:
# model = load_model()
# except Exception as e:
# print(f"Model loading failed: {e}")
# model = None
# # Image transforms
# img_transform = T.Compose([
# T.Resize((256, 256)),
# T.ToTensor(),
# T.Normalize(mean=[0.485, 0.456, 0.406],
# std=[0.229, 0.224, 0.225])
# ])
# TMP_FOLDER = "/tmp"
# os.makedirs(TMP_FOLDER, exist_ok=True)
# # Route to serve files from /tmp
# @app.route('/tmp/<filename>')
# def serve_tmp_file(filename):
# file_path = os.path.join(TMP_FOLDER, filename)
# if os.path.exists(file_path):
# return send_file(file_path)
# else:
# print(f"File not found: {file_path}")
# abort(404)
# @app.route("/", methods=["GET", "POST"])
# def index():
# orig = None
# mask = None
# overlay = None
# error = None
# # Check for existing input image
# img_path = os.path.join(TMP_FOLDER, "input.jpg")
# if os.path.exists(img_path):
# orig = "/tmp/input.jpg"
# print(f"Found existing image: {img_path}")
# if request.method == "POST":
# # Handle image upload
# if "image" in request.files:
# file = request.files["image"]
# if file.filename == "":
# error = "No file selected"
# print(error)
# return render_template("index.html", error=error, orig=orig, mask=mask, overlay=overlay)
# try:
# # Save uploaded image to /tmp
# file.save(img_path)
# print(f"Image saved to: {img_path}")
# orig = "/tmp/input.jpg"
# # Clear previous results in /tmp
# for path in [os.path.join(TMP_FOLDER, "mask.png"), os.path.join(TMP_FOLDER, "overlay.png")]:
# if os.path.exists(path):
# os.remove(path)
# print(f"Removed: {path}")
# except Exception as e:
# error = f"Error uploading image: {str(e)}"
# print(f"Upload error: {e}")
# return render_template("index.html", error=error, orig=orig, mask=mask, overlay=overlay)
# # Handle segmentation
# if "segment" in request.form:
# if not os.path.exists(img_path):
# error = "No image available for segmentation"
# print(f"Segmentation error: Image not found at {img_path}")
# return render_template("index.html", error=error, orig=orig, mask=mask, overlay=overlay)
# try:
# if model is None:
# raise ValueError("Model not loaded")
# image = Image.open(img_path).convert("RGB")
# input_tensor = img_transform(image).unsqueeze(0).to(device)
# # Predict
# with torch.no_grad():
# output = model(input_tensor)
# pred_mask = torch.sigmoid(output)
# pred_mask = (pred_mask > 0.5).float()
# # Resize mask back to original image size
# mask_resized = TF.resize(
# TF.to_pil_image(pred_mask.squeeze().cpu()),
# size=image.size[::-1],
# interpolation=Image.NEAREST
# )
# # Save mask to /tmp
# mask_path = os.path.join(TMP_FOLDER, "mask.png")
# mask_resized.save(mask_path)
# print(f"Mask saved to: {mask_path}")
# # Create overlay
# mask_np = np.array(mask_resized)
# overlay = np.array(image).copy()
# overlay[mask_np > 128] = [255, 0, 0]
# overlay_img = Image.fromarray(overlay)
# overlay_path = os.path.join(TMP_FOLDER, "overlay.png")
# overlay_img.save(overlay_path)
# print(f"Overlay saved to: {overlay_path}")
# mask = "/tmp/mask.png"
# overlay = "/tmp/overlay.png"
# except Exception as e:
# error = f"Error during segmentation: {str(e)}"
# print(f"Segmentation error: {e}")
# return render_template("index.html", error=error, orig=orig, mask=mask, overlay=overlay)
# return render_template("index.html", orig=orig, mask=mask, overlay=overlay, error=error)
# if __name__ == "__main__":
# app.run(debug=True)
import os
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import numpy as np
from PIL import Image
from flask import Flask, render_template, request, send_file, abort
app = Flask(__name__)

# Inference device string: "cuda" when torch reports a usable CUDA device, else "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Load model (assuming UNet is defined in unet.py)
def load_model(model_path="unet_car_final.pth"):
    """Build the UNet, load its weights from *model_path*, and return it in eval mode.

    ``UNet`` is imported from the project-local ``unet`` module and the network is
    placed on the module-level ``device``.

    Args:
        model_path: Path to a checkpoint holding the model's ``state_dict``.
            Defaults to ``"unet_car_final.pth"``, resolved against the current
            working directory.

    Returns:
        The loaded model, switched to eval mode.

    Raises:
        FileNotFoundError: if *model_path* is not an existing file.
        Any error from importing ``unet``, building the model, or loading the state
        dict is printed and re-raised unchanged.
    """
    try:
        # Check the checkpoint first so a missing file fails fast, before importing
        # the project module and allocating the network on the device.
        if not os.path.isfile(model_path):
            raise FileNotFoundError(f"Model file {model_path} not found")
        from unet import UNet
        model = UNet().to(device)
        # map_location loads the saved tensors straight onto the target device.
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
# Load the network once, at import time. A failure is printed instead of raised so the
# Flask app still starts; code that needs the network must check ``model is None``.
model = None
try:
    model = load_model()
except Exception as e:
    print(f"Model loading failed: {e}")

# Preprocessing applied to each input image before inference:
#   - resize to a fixed 256x256 input,
#   - convert the PIL image to a float tensor,
#   - normalize each RGB channel with the given mean/std (the widely used ImageNet
#     statistics; presumably what the network was trained with -- confirm).
img_transform = T.Compose(
    [
        T.Resize((256, 256)),
        T.ToTensor(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Directory where the uploaded image and the generated mask/overlay are stored.
TMP_FOLDER = "/tmp"
os.makedirs(TMP_FOLDER, exist_ok=True)
# Route to serve the app's working files from /tmp.
# Only the fixed file names the index page writes are exposed. Anything else in the
# shared /tmp directory (other processes' temp files, or names such as "..") gets a 404.
_SERVABLE_TMP_FILES = frozenset({"input.jpg", "mask.png", "overlay.png"})


@app.route('/tmp/<filename>')
def serve_tmp_file(filename):
    """Send ``TMP_FOLDER/<filename>`` if it is one of the app's working files.

    Args:
        filename: Last URL path segment. It is untrusted client input.

    Returns:
        The ``send_file`` response for the requested file.

    Raises:
        werkzeug NotFound (via ``abort(404)``): if *filename* is not an allowed name
        or the file does not currently exist as a regular file.
    """
    file_path = os.path.join(TMP_FOLDER, filename)
    # isfile (not exists) so a directory can never reach send_file.
    if filename in _SERVABLE_TMP_FILES and os.path.isfile(file_path):
        return send_file(file_path)
    print(f"File not found: {file_path}")
    abort(404)
def _render_index(orig, mask, overlay, error):
    """Render index.html with the given image URLs (or None) and error message (or None)."""
    return render_template("index.html", orig=orig, mask=mask, overlay=overlay, error=error)


def _segment_image(img_path):
    """Segment the image at *img_path* and write ``mask.png``/``overlay.png`` to TMP_FOLDER.

    Args:
        img_path: Path of the uploaded input image.

    Returns:
        ``("/tmp/mask.png", "/tmp/overlay.png")``. These are returned only after both
        files have been written.

    Raises:
        ValueError: if the global ``model`` failed to load.
        Any error from PIL, torch, numpy, or the filesystem propagates to the caller.
    """
    if model is None:
        raise ValueError("Model not loaded")

    # The context manager closes the source file; convert() returns an independent image.
    with Image.open(img_path) as src:
        image = src.convert("RGB")

    input_tensor = img_transform(image).unsqueeze(0).to(device)  # batch of one
    with torch.no_grad():
        output = model(input_tensor)
        # Sigmoid + 0.5 threshold gives a {0., 1.} mask at the network's resolution.
        pred_mask = (torch.sigmoid(output) > 0.5).float()

    # Resize the mask back to the original image size. PIL's ``size`` is
    # (width, height) while TF.resize expects (height, width). Nearest-neighbour
    # interpolation introduces no intermediate values, so the mask stays binary.
    mask_resized = TF.resize(
        TF.to_pil_image(pred_mask.squeeze().cpu()),
        size=image.size[::-1],
        interpolation=T.InterpolationMode.NEAREST,
    )
    mask_path = os.path.join(TMP_FOLDER, "mask.png")
    mask_resized.save(mask_path)
    print(f"Mask saved to: {mask_path}")

    # Paint masked pixels pure red on a copy of the RGB image. to_pil_image scaled the
    # float mask by 255, hence the > 128 test.
    overlay_arr = np.array(image)  # np.array copies; ``image`` itself is not modified
    overlay_arr[np.array(mask_resized) > 128] = [255, 0, 0]
    overlay_path = os.path.join(TMP_FOLDER, "overlay.png")
    Image.fromarray(overlay_arr).save(overlay_path)
    print(f"Overlay saved to: {overlay_path}")

    return "/tmp/mask.png", "/tmp/overlay.png"


@app.route("/", methods=["GET", "POST"])
def index():
    """Render the upload/segmentation page.

    GET deletes any previous ``input.jpg``/``mask.png``/``overlay.png`` from
    TMP_FOLDER. POST handles up to two actions, in this order:

    * an ``image`` file in ``request.files`` is saved as ``input.jpg``, and any stale
      mask/overlay is deleted;
    * a ``segment`` field in ``request.form`` runs the model on ``input.jpg``.

    Every outcome renders ``index.html`` with the ``orig``/``mask``/``overlay`` URLs
    (or None) and an ``error`` message (or None).

    NOTE(review): all clients share the same fixed file names in TMP_FOLDER, so
    concurrent users overwrite (and, on GET, delete) each other's files.
    """
    orig = None
    mask = None
    overlay = None
    error = None
    img_path = os.path.join(TMP_FOLDER, "input.jpg")

    if request.method == "GET":
        # A fresh page load starts from a clean slate: delete every working file.
        for filename in ["input.jpg", "mask.png", "overlay.png"]:
            file_path = os.path.join(TMP_FOLDER, filename)
            if os.path.exists(file_path):
                try:
                    os.remove(file_path)
                    print(f"Cleared file: {file_path}")
                except OSError as e:
                    print(f"Error clearing file {file_path}: {e}")

    # On POST this picks up an image uploaded by an earlier request. On GET it only
    # matches if the deletion above failed.
    if os.path.exists(img_path):
        orig = "/tmp/input.jpg"
        print(f"Found existing image: {img_path}")

    if request.method != "POST":
        return _render_index(orig, mask, overlay, error)

    # Handle image upload
    if "image" in request.files:
        file = request.files["image"]
        # ``not`` rejects a missing (None) filename as well as an empty one.
        if not file.filename:
            error = "No file selected"
            print(error)
            return _render_index(orig, mask, overlay, error)
        try:
            file.save(img_path)
            print(f"Image saved to: {img_path}")
            orig = "/tmp/input.jpg"
            # A new input invalidates any previous results.
            for path in [
                os.path.join(TMP_FOLDER, "mask.png"),
                os.path.join(TMP_FOLDER, "overlay.png"),
            ]:
                if os.path.exists(path):
                    os.remove(path)
                    print(f"Removed: {path}")
        except Exception as e:
            error = f"Error uploading image: {e}"
            print(f"Upload error: {e}")
            return _render_index(orig, mask, overlay, error)

    # Handle segmentation
    if "segment" in request.form:
        if not os.path.exists(img_path):
            error = "No image available for segmentation"
            print(f"Segmentation error: Image not found at {img_path}")
            return _render_index(orig, mask, overlay, error)
        try:
            # The URLs are bound only after both files were written. On failure they
            # stay None, so no intermediate object can reach the template.
            mask, overlay = _segment_image(img_path)
        except Exception as e:
            error = f"Error during segmentation: {e}"
            print(f"Segmentation error: {e}")
            return _render_index(orig, mask, overlay, error)

    return _render_index(orig, mask, overlay, error)
if __name__ == "__main__":
    # Passing a hard-coded debug=True to app.run overrides FLASK_DEBUG, so debug mode
    # could never be switched off. Honour FLASK_DEBUG instead, using the same falsy
    # values Flask recognises, and default to on for local development. Never expose
    # the Werkzeug interactive debugger publicly: it can execute arbitrary code.
    flask_debug = os.environ.get("FLASK_DEBUG", "1")
    app.run(debug=bool(flask_debug) and flask_debug.lower() not in {"0", "false", "no"})