# Lld / main.py
# Commit 4a6bc54 by Asartb: "Create main.py" (verified)
import io

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from fastapi import FastAPI, File, UploadFile, HTTPException, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
from PIL import Image
from PIL import UnidentifiedImageError
# Normalization layer class used by ResidualBlock and Generator below.
# Each use constructs it with only a channel count, i.e. with PyTorch's
# default arguments for nn.InstanceNorm2d.
norm_layer = nn.InstanceNorm2d
class ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 conv layers with an identity skip connection.

    Each conv is preceded by a 1-pixel reflection pad and uses stride 1, so the
    branch output has the same shape as the input and can be added to it.

    Args:
        in_features: channel count of both the input and the output.
    """

    def __init__(self, in_features):
        super().__init__()
        # The layer order fixes the state_dict keys (conv_block.0 .. conv_block.6)
        # that saved checkpoints are matched against by load_state_dict.
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.conv_block(x)
        return x + residual
class Generator(nn.Module):
    """Encoder / residual-bottleneck / decoder image-to-image generator.

    Stages (each stored as one ``nn.Sequential`` attribute):
      * ``model0`` - reflection pad + 7x7 conv to 64 channels, norm, ReLU.
      * ``model1`` - two stride-2 3x3 convs, 64 -> 128 -> 256 channels.
      * ``model2`` - ``n_residual_blocks`` ResidualBlocks at 256 channels.
      * ``model3`` - two stride-2 transposed convs, 256 -> 128 -> 64 channels.
      * ``model4`` - reflection pad + 7x7 conv to ``output_nc`` channels,
        optionally followed by a sigmoid.

    The attribute names and the layer order inside each stage define the
    ``state_dict`` keys, so they must stay stable for saved weights to load.

    Args:
        input_nc: channels of the input image.
        output_nc: channels of the generated image.
        n_residual_blocks: number of residual blocks in the bottleneck.
        sigmoid: append ``nn.Sigmoid`` to the output stage when True.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super().__init__()
        base = 64

        # Stem: 7x7 convolution on a reflection-padded input.
        self.model0 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, base, 7),
            norm_layer(base),
            nn.ReLU(inplace=True),
        )

        # Downsampling: each stage halves the resolution and doubles channels.
        down_layers = []
        channels = base
        for _ in range(2):
            widened = channels * 2
            down_layers += [
                nn.Conv2d(channels, widened, 3, stride=2, padding=1),
                norm_layer(widened),
                nn.ReLU(inplace=True),
            ]
            channels = widened
        self.model1 = nn.Sequential(*down_layers)

        # Bottleneck: residual blocks at the widest channel count.
        self.model2 = nn.Sequential(
            *(ResidualBlock(channels) for _ in range(n_residual_blocks))
        )

        # Upsampling: each stage doubles the resolution and halves channels.
        up_layers = []
        for _ in range(2):
            narrowed = channels // 2
            up_layers += [
                nn.ConvTranspose2d(channels, narrowed, 3, stride=2, padding=1, output_padding=1),
                norm_layer(narrowed),
                nn.ReLU(inplace=True),
            ]
            channels = narrowed
        self.model3 = nn.Sequential(*up_layers)

        # Output head back at the stem width.
        head = [nn.ReflectionPad2d(3), nn.Conv2d(base, output_nc, 7)]
        if sigmoid:
            head.append(nn.Sigmoid())
        self.model4 = nn.Sequential(*head)

    def forward(self, x, cond=None):
        """Run ``x`` through all five stages in order; ``cond`` is ignored."""
        for stage in (self.model0, self.model1, self.model2, self.model3, self.model4):
            x = stage(x)
        return x
# Load the two pretrained generators and the FastAPI application.


def _load_generator(weights_path):
    """Build a Generator(3, 1, 3) and load its weights from ``weights_path``.

    The checkpoint is mapped to the CPU, and the model is switched to eval
    mode before being returned.
    """
    net = Generator(3, 1, 3)
    state = torch.load(weights_path, map_location=torch.device('cpu'))
    net.load_state_dict(state)
    net.eval()
    return net


model1 = _load_generator('model.pth')
model2 = _load_generator('model2.pth')

# Initialize FastAPI
app = FastAPI()
# Preprocessing applied to every upload, built once at import time:
# resize so the shorter side is 256 px (bicubic), convert to a [0, 1] float
# tensor, then map each of the three channels to [-1, 1].
_PREPROCESS = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


def _render_line_art(image_file, version):
    """Run one uploaded image through the selected generator and JPEG-encode it.

    Args:
        image_file: binary file-like object holding the uploaded image.
        version: ``'Simple Lines'`` selects ``model2``; any other value selects
            ``model1``.

    Returns:
        An ``io.BytesIO`` rewound to position 0 containing the JPEG output.

    Raises:
        PIL.UnidentifiedImageError: the data is not an image Pillow can decode.
    """
    # convert("RGB") turns RGBA / grayscale / palette / CMYK uploads into the
    # 3 channels that Normalize and the generators' first conv expect.
    image = Image.open(image_file).convert("RGB")
    input_tensor = _PREPROCESS(image).unsqueeze(0)

    model = model2 if version == 'Simple Lines' else model1
    with torch.no_grad():
        output = model(input_tensor)

    # output[0] drops only the batch dimension, leaving (C, H, W) for ToPILImage.
    output_img = transforms.ToPILImage()(output[0].cpu().clamp(0, 1))

    buffer = io.BytesIO()
    output_img.save(buffer, format="JPEG")
    buffer.seek(0)
    return buffer


# Endpoint to process the image
@app.post("/predict/")
async def process_image(
    file: UploadFile = File(...),
    version: str = Form(...)
):
    """Convert an uploaded image with one of the generators and return a JPEG.

    Form fields:
        file: the image to convert.
        version: ``'Simple Lines'`` uses ``model2``; any other value uses ``model1``.

    Returns:
        ``StreamingResponse`` with ``media_type="image/jpeg"``.

    Raises:
        HTTPException(400): the upload could not be decoded as an image.
        HTTPException(500): any other failure while processing.
    """
    try:
        # Decoding and inference are synchronous and CPU-bound; running them in
        # FastAPI's thread pool keeps the event loop free for other requests.
        buffer = await run_in_threadpool(_render_line_art, file.file, version)
    except UnidentifiedImageError as e:
        # Not a decodable image: this is a client error, not a server fault.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return StreamingResponse(buffer, media_type="image/jpeg")