"""FastAPI service exposing a binary-segmentation endpoint backed by a
segmentation_models_pytorch PAN model (ResNeXt50-32x4d encoder)."""

from fastapi import FastAPI, Form, Depends, Request, File, UploadFile
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import segmentation_models_pytorch as smp
import torch
import numpy as np
import cv2
import os
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from io import BytesIO
import traceback
import base64

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Single-channel (binary mask) PAN model; weights are loaded onto DEVICE.
model = smp.PAN(encoder_name="resnext50_32x4d", in_channels=3, classes=1)
model.to(DEVICE).load_state_dict(
    torch.load("./model/pan_resnext50_32x4d_adam_lr001_batch16_epoch_50.ckpt",
               map_location=DEVICE))
# fix: the model was never switched to eval mode, so BatchNorm/Dropout ran
# with training-time behavior during inference.
model.eval()

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Replace with the list of allowed origins for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class CustomDataset(Dataset):
    """Minimal Dataset over a sequence of images, with an optional transform.

    Each item is a dict with a single key ``'image'``.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = {
            'image': self.data[idx],
        }
        if self.transform:
            sample = self.transform(sample)
        return sample


def combine_images(original_image_np, label_image_np):
    """Keep original pixels where the label mask is white (255); black elsewhere.

    :param original_image_np: source image array (H, W[, C]).
    :param label_image_np: mask array; collapsed to grayscale if multi-channel.
    :return: array shaped like ``original_image_np`` with non-mask pixels zeroed.
    """
    # Convert label image to grayscale if it's not already.
    if label_image_np.ndim > 2:
        # fix: np.mean(..., dtype=np.uint8) accumulates the sum in uint8 and
        # overflows; average in float precision, then cast.
        label_image_np = np.mean(label_image_np, axis=2).astype(np.uint8)

    # Create a mask where the label image is white (255).
    mask = label_image_np == 255

    # Start from all-black and copy through only the masked pixels.
    combined_image_np = np.zeros_like(original_image_np)
    combined_image_np[mask] = original_image_np[mask]

    return combined_image_np


@app.get("/")
async def root():
    """Health-check endpoint."""
    return {"message": "Hello World"}


@app.post("/segmentation")
async def segmentation(file: UploadFile = File(...)):
    """Segment the uploaded image and return the masked result as base64 JPEG.

    The uploaded image is inserted at index 0 of a batch padded with images
    from ./images; only the prediction for the uploaded image is returned.
    """
    contents = await file.read()

    # Pad the batch with images from ./images (the uploaded image goes first).
    image_dataset = []
    # fix: the loop variable was named `file`, shadowing the UploadFile parameter.
    for name in os.listdir("./images"):
        image_dataset.append(cv2.resize(cv2.imread('./images/' + name), (160, 544)))

    # fix: normalize mode so RGBA/grayscale uploads still yield a 3-channel array.
    image = Image.open(BytesIO(contents)).convert("RGB")
    open_cv_image = cv2.resize(np.array(image), (160, 544))
    image_dataset.insert(0, open_cv_image)
    # NOTE(review): padding images are BGR (cv2.imread) while the upload is RGB —
    # confirm the model tolerates the channel-order mismatch.

    image_dataset = np.transpose(image_dataset, (0, 3, 1, 2))  # NHWC -> NCHW
    dataset = CustomDataset(image_dataset)
    dataloader = DataLoader(dataset, batch_size=16, shuffle=False, num_workers=0)

    try:
        with torch.no_grad():
            for batch in dataloader:
                temp_image = batch['image'].to(DEVICE).float()
                output = model(temp_image)
                # NOTE(review): thresholds the raw model output at 0.5 — if the
                # model emits logits this should be sigmoid() first; kept as-is.
                mask = (output[0] > 0.5).squeeze().cpu().numpy()
                mask = mask.astype(np.uint8) * 255
                combined_image_np = combine_images(open_cv_image, mask)
                # combined_image_np = cv2.cvtColor(combined_image_np, cv2.COLOR_BGR2RGB)
                result_image = Image.fromarray(combined_image_np)
                buffered = BytesIO()
                result_image.save(buffered, format="JPEG")
                img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
                # fix: the uploaded image is element 0 of the FIRST batch; later
                # batches would overwrite the result with padding images.
                break
    except Exception as e:
        error_message = traceback.format_exc()
        return JSONResponse(status_code=500,
                            content={"error": str(e), "traceback": error_message})
    else:
        return JSONResponse(status_code=200,
                            content={"result": 'good', "image": img_str})


@app.post("/predict")
async def predict():
    """Placeholder endpoint — not yet implemented."""
    return None