"""FastAPI service that runs an uploaded image through a pre-trained generator.

Two ``Generator`` networks (3 input channels, 1 output channel, 3 residual
blocks) are loaded at import time from ``model.pth`` and ``model2.pth``.
The ``/predict/`` endpoint runs one of them on an uploaded image and returns
the single-channel result as a JPEG.
"""
import io

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from PIL import Image

# Normalization layer used throughout the network. The generator's
# architecture is instance-normalized, so this must match the checkpoints.
norm_layer = nn.InstanceNorm2d


class ResidualBlock(nn.Module):
    """Residual block with an identity skip connection.

    Applies (reflection pad 1 -> 3x3 conv -> norm -> ReLU -> reflection pad 1
    -> 3x3 conv -> norm) and adds the result to the input. Channel count and
    spatial size are preserved.

    Args:
        in_features: Number of input (and output) channels.
    """

    def __init__(self, in_features):
        super().__init__()
        # Layer order defines the ``conv_block.<i>`` state_dict keys; do not reorder.
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
        )

    def forward(self, x):
        return x + self.conv_block(x)


class Generator(nn.Module):
    """Encoder / residual-bottleneck / decoder image-to-image network.

    The network has five stages, stored as ``model0``..``model4``. Those
    attribute names and the layer order inside each stage define the
    state_dict keys of saved checkpoints.

    * ``model0``: 7x7 convolution from ``input_nc`` to 64 channels.
    * ``model1``: two stride-2 convolutions, 64 -> 128 -> 256 channels.
    * ``model2``: ``n_residual_blocks`` residual blocks at 256 channels.
    * ``model3``: two stride-2 transposed convolutions, 256 -> 128 -> 64.
    * ``model4``: 7x7 convolution to ``output_nc`` channels, plus an optional
      sigmoid.

    Args:
        input_nc: Number of input image channels.
        output_nc: Number of output image channels.
        n_residual_blocks: Number of residual blocks in the bottleneck.
        sigmoid: If True, squash outputs into [0, 1] with ``nn.Sigmoid``.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super().__init__()

        self.model0 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            norm_layer(64),
            nn.ReLU(inplace=True),
        )

        # Downsampling: double the channels at each stride-2 step.
        down = []
        in_features = 64
        for _ in range(2):
            out_features = in_features * 2
            down += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                norm_layer(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
        self.model1 = nn.Sequential(*down)

        self.model2 = nn.Sequential(
            *[ResidualBlock(in_features) for _ in range(n_residual_blocks)]
        )

        # Upsampling: halve the channels at each stride-2 transposed conv.
        up = []
        for _ in range(2):
            out_features = in_features // 2
            up += [
                nn.ConvTranspose2d(
                    in_features, out_features, 3,
                    stride=2, padding=1, output_padding=1,
                ),
                norm_layer(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
        self.model3 = nn.Sequential(*up)

        head = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)]
        if sigmoid:
            head.append(nn.Sigmoid())
        self.model4 = nn.Sequential(*head)

    def forward(self, x, cond=None):
        # ``cond`` is accepted but unused; kept for call-site compatibility.
        out = self.model0(x)
        out = self.model1(out)
        out = self.model2(out)
        out = self.model3(out)
        out = self.model4(out)
        return out


def _load_generator(path):
    """Build ``Generator(3, 1, 3)``, load CPU weights from *path*, set eval mode."""
    model = Generator(3, 1, 3)
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
    model.eval()
    return model


# Load the models once at import time.
model1 = _load_generator('model.pth')
model2 = _load_generator('model2.pth')

# Preprocessing is the same for every request, so build it once. It resizes
# the shorter side to 256 px (bicubic), then maps each RGB channel from
# [0, 1] to [-1, 1].
_PREPROCESS = transforms.Compose([
    transforms.Resize(256, transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

app = FastAPI()


@app.post("/predict/")
def process_image(
    file: UploadFile = File(...),
    version: str = Form(...),
):
    """Run the uploaded image through a generator and stream back a JPEG.

    Declared with ``def`` rather than ``async def``. Inference is CPU-bound
    and blocking, so FastAPI runs it in its worker threadpool instead of
    stalling the event loop.

    Args:
        file: Uploaded image in any format/mode Pillow can decode. It is
            converted to RGB because the models expect 3 input channels.
        version: ``'Simple Lines'`` selects ``model2`` (``model2.pth``).
            Any other value selects ``model1`` (``model.pth``).

    Returns:
        A ``StreamingResponse`` containing the ``image/jpeg`` output.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image;
            500 for any failure during preprocessing, inference or encoding.
    """
    # Decode the upload. Failures here are the client's fault, hence 400.
    # UnidentifiedImageError and truncated-data errors are OSError subclasses.
    try:
        with Image.open(file.file) as uploaded:
            # RGBA/L/P inputs would otherwise break the 3-channel Normalize.
            image = uploaded.convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        raise HTTPException(status_code=400, detail=f"Invalid image: {exc}") from exc

    try:
        input_tensor = _PREPROCESS(image).unsqueeze(0)

        model = model2 if version == 'Simple Lines' else model1
        with torch.no_grad():
            output = model(input_tensor)

        # Take the single batch item as (C, H, W). Unlike squeeze(), this never
        # drops a size-1 spatial dim. The models end in a sigmoid, so the clamp
        # is only a safeguard before ToPILImage scales to 0..255.
        output_img = transforms.ToPILImage()(output[0].cpu().clamp(0, 1))

        buffer = io.BytesIO()
        output_img.save(buffer, format="JPEG")
        buffer.seek(0)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    return StreamingResponse(buffer, media_type="image/jpeg")