"""Flask web service that runs a small U-Net to segment nails in uploaded images.

At import time the module builds the network, loads its trained weights from
``model_path`` and switches it to evaluation mode.  Two routes are exposed:

* ``GET /``         -- renders ``index.html``.
* ``POST /process`` -- accepts a multipart upload in the ``image`` field and
  returns the original image plus the predicted binary mask, both as
  base64-encoded PNGs.
"""

import base64
import io
import os

import numpy as np
import torch
import torch.nn as nn
from flask import Flask, jsonify, render_template, request
from PIL import Image
from torchvision import transforms

# ===========================
# CONFIGURATION
# ===========================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
img_size = 128  # same as used during training
model_path = "model/nail_segmentation_unet.pt"

app = Flask(__name__)


# ===========================
# MODEL DEFINITION (MATCHES TRAINING)
# ===========================
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolutions, each followed by BatchNorm and ReLU.

    ``padding=1`` keeps the spatial size unchanged; only the channel count
    goes from ``in_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # bias=False: the BatchNorm that follows each convolution has its own
        # learnable shift, so a convolution bias would be redundant.
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply the conv-BN-ReLU stack to ``x`` and return the result."""
        return self.conv(x)


class UNet(nn.Module):
    """U-Net with three down-sampling stages (32/64/128/256 channels).

    The input is halved three times by 2x2 max-pooling and restored by three
    stride-2 transposed convolutions, with a skip connection at every level.
    Because the skip tensors are concatenated with the up-sampled ones, the
    input height and width must be divisible by 8.

    The output is passed through a sigmoid, so every value lies in (0, 1).
    """

    def __init__(self, in_ch=3, out_ch=1):
        super().__init__()
        # Encoder
        self.dconv_down1 = DoubleConv(in_ch, 32)
        self.dconv_down2 = DoubleConv(32, 64)
        self.dconv_down3 = DoubleConv(64, 128)
        self.dconv_down4 = DoubleConv(128, 256)
        self.maxpool = nn.MaxPool2d(2)

        # Decoder: each transposed convolution doubles the spatial size and
        # halves the channel count.
        self.up3 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.up2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.up1 = nn.ConvTranspose2d(64, 32, 2, stride=2)

        # Input channels are doubled by the skip-connection concatenation.
        self.dconv_up3 = DoubleConv(256, 128)
        self.dconv_up2 = DoubleConv(128, 64)
        self.dconv_up1 = DoubleConv(64, 32)

        self.conv_last = nn.Conv2d(32, out_ch, 1)

    def forward(self, x):
        """Return per-pixel probabilities with the same spatial size as ``x``."""
        # Encoder path; keep the pre-pooling activations for the skips.
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)
        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)
        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)

        # Bottleneck
        x = self.dconv_down4(x)

        # Decoder path: up-sample, concatenate the matching skip, refine.
        x = self.up3(x)
        x = torch.cat([x, conv3], dim=1)
        x = self.dconv_up3(x)

        x = self.up2(x)
        x = torch.cat([x, conv2], dim=1)
        x = self.dconv_up2(x)

        x = self.up1(x)
        x = torch.cat([x, conv1], dim=1)
        x = self.dconv_up1(x)

        x = self.conv_last(x)
        x = torch.sigmoid(x)
        return x


# ===========================
# LOAD TRAINED MODEL
# ===========================
model = UNet().to(device)
# SECURITY NOTE: torch.load unpickles the checkpoint, which can execute
# arbitrary code.  Only load checkpoint files from a trusted source.  If the
# installed PyTorch supports it, consider passing weights_only=True.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()  # inference mode: BatchNorm uses its running statistics

# ===========================
# IMAGE TRANSFORM
# ===========================
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    # match training normalization
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])


# ===========================
# UTILITY FUNCTIONS
# ===========================
def encode_image(pil_img):
    """Return ``pil_img`` serialized as PNG and base64-encoded (UTF-8 str)."""
    buffer = io.BytesIO()
    pil_img.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def _predict_mask(image_pil):
    """Run the model on ``image_pil`` and return a binary mask image.

    The mask has the same size as ``image_pil``; pixels are 255 where the
    predicted probability exceeds 0.5 and 0 elsewhere.
    """
    input_img_tensor = transform(image_pil).unsqueeze(0).to(device)

    with torch.no_grad():
        pred_mask = model(input_img_tensor)[0]

    # Convert mask tensor to binary mask
    mask_np = pred_mask.squeeze().cpu().numpy()
    mask_binary = (mask_np > 0.5).astype(np.uint8) * 255

    # Resize mask to original image size; NEAREST keeps the mask strictly
    # binary (no interpolated grey values).
    return Image.fromarray(mask_binary).resize(image_pil.size, Image.NEAREST)


# ===========================
# ROUTES
# ===========================
@app.route("/", methods=["GET"])
def index():
    """Serve the upload page."""
    return render_template("index.html")


@app.route("/process", methods=["POST"])
def process_image():
    """Segment the uploaded image and return it together with its mask.

    Expects a multipart form upload with the file in the ``image`` field.

    Returns:
        200 with ``{"original_image": <b64 PNG>, "mask_image": <b64 PNG>}``;
        400 with ``{"error": ...}`` when the file is missing or is not a
        readable image;
        500 with ``{"error": ...}`` on any other failure.
    """
    if "image" not in request.files:
        return jsonify({"error": "No file part"}), 400

    file = request.files["image"]
    if not file.filename:
        return jsonify({"error": "No selected file"}), 400

    # A file that cannot be decoded is the client's fault, so report it as
    # 400 rather than 500.  Pillow signals unreadable/truncated images with
    # OSError (UnidentifiedImageError is a subclass).
    try:
        image_pil = Image.open(file.stream).convert("RGB")
    except OSError:
        return jsonify({"error": "Invalid or unsupported image file"}), 400

    try:
        mask_pil = _predict_mask(image_pil)

        # Encode images for frontend display
        original_b64 = encode_image(image_pil)
        mask_b64 = encode_image(mask_pil)
    except Exception:
        # Top-level boundary: keep the details in the server log instead of
        # leaking internal exception text to the client.
        app.logger.exception("Failed to process uploaded image")
        return jsonify(
            {"error": "An error occurred while processing the image"}
        ), 500

    return jsonify({
        "original_image": original_b64,
        "mask_image": mask_b64,
    })


# ===========================
# RUN APP
# ===========================
if __name__ == "__main__":
    # The Werkzeug debugger allows arbitrary code execution from the browser,
    # so debug mode is opt-in: set FLASK_DEBUG=1 for local development.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {
        "1", "true", "yes", "on",
    }
    app.run(debug=debug_enabled)