Asartb commited on
Commit
4a6bc54
·
verified ·
1 Parent(s): 84aed80

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +137 -0
main.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Form
2
+ from fastapi.responses import StreamingResponse
3
+ from PIL import Image
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchvision.transforms as transforms
7
+ import io
8
+
9
# Normalization layer used by all generator stages.
# InstanceNorm2d (per-sample, per-channel) is the usual choice for
# image-to-image translation generators of this kind.
norm_layer = nn.InstanceNorm2d
11
+
12
class ResidualBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 convolutions plus a skip.

    Channel count and spatial size are preserved (in_features -> in_features),
    so the identity skip can be added directly to the conv output.
    """

    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        layers = []
        # Two identical conv stages; only the first is followed by a ReLU.
        for with_relu in (True, False):
            layers.append(nn.ReflectionPad2d(1))
            layers.append(nn.Conv2d(in_features, in_features, 3))
            layers.append(norm_layer(in_features))
            if with_relu:
                layers.append(nn.ReLU(inplace=True))
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        # Identity skip connection around the conv stack.
        return self.conv_block(x) + x
28
+
29
class Generator(nn.Module):
    """Encoder / residual-bottleneck / decoder generator.

    Pipeline: 7x7 stem to 64 channels -> two stride-2 downsampling convs
    (channels double each time) -> `n_residual_blocks` residual blocks at
    constant width -> two transposed-conv upsampling stages (channels halve
    each time) -> 7x7 projection to `output_nc` channels, optionally
    squashed to [0, 1] by a final Sigmoid.

    Submodule names (model0..model4) match the original checkpoint layout,
    so state_dict keys are unchanged.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super(Generator, self).__init__()

        # Stem: reflection-padded 7x7 conv to 64 channels.
        self.model0 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            norm_layer(64),
            nn.ReLU(inplace=True),
        )

        # Downsampling: two stride-2 convs, doubling channels each step.
        down = []
        channels = 64
        for _ in range(2):
            down += [
                nn.Conv2d(channels, channels * 2, 3, stride=2, padding=1),
                norm_layer(channels * 2),
                nn.ReLU(inplace=True),
            ]
            channels *= 2
        self.model1 = nn.Sequential(*down)

        # Bottleneck: residual blocks at constant channel width.
        self.model2 = nn.Sequential(
            *(ResidualBlock(channels) for _ in range(n_residual_blocks))
        )

        # Upsampling: two transposed convs, halving channels each step.
        up = []
        for _ in range(2):
            up += [
                nn.ConvTranspose2d(channels, channels // 2, 3,
                                   stride=2, padding=1, output_padding=1),
                norm_layer(channels // 2),
                nn.ReLU(inplace=True),
            ]
            channels //= 2
        self.model3 = nn.Sequential(*up)

        # Head: reflection-padded 7x7 projection to output_nc channels.
        head = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)]
        if sigmoid:
            head.append(nn.Sigmoid())
        self.model4 = nn.Sequential(*head)

    def forward(self, x, cond=None):
        # `cond` is accepted for interface compatibility but never used.
        out = x
        for stage in (self.model0, self.model1, self.model2,
                      self.model3, self.model4):
            out = stage(out)
        return out
86
+
87
# Load the two generator variants (3-channel input -> 1-channel output,
# 3 residual blocks each) with CPU-mapped pretrained weights.
# NOTE(review): torch.load unpickles arbitrary objects — only load
# checkpoint files from a trusted source (consider weights_only=True on
# torch versions that support it).
model1 = Generator(3, 1, 3)
model1.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
model1.eval()  # inference mode: fixes norm-layer behavior

# Second variant — served when the client requests 'Simple Lines'.
model2 = Generator(3, 1, 3)
model2.load_state_dict(torch.load('model2.pth', map_location=torch.device('cpu')))
model2.eval()

# Initialize FastAPI
app = FastAPI()
98
+
99
# Endpoint to process the image
@app.post("/predict/")
async def process_image(
    file: UploadFile = File(...),
    version: str = Form(...)
):
    """Run an uploaded image through a line-drawing generator.

    Parameters:
        file: uploaded image in any format/mode PIL can open.
        version: 'Simple Lines' selects model2; any other value, model1.

    Returns:
        StreamingResponse carrying the JPEG-encoded single-channel result.

    Raises:
        HTTPException(500) with the underlying error message on failure.
    """
    try:
        # Force 3 channels: the models take 3-channel input and the
        # Normalize below expects exactly 3 channels, so grayscale,
        # palette or RGBA uploads would otherwise fail.
        image = Image.open(file.file).convert("RGB")

        # Resize shorter side to 256 and map pixel values to [-1, 1].
        # InterpolationMode replaces the deprecated positional PIL
        # constant (Image.BICUBIC), which recent torchvision rejects.
        transform = transforms.Compose([
            transforms.Resize(
                256, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # Add a leading batch dimension for the model.
        input_tensor = transform(image).unsqueeze(0)

        # Inference only — skip autograd bookkeeping.
        with torch.no_grad():
            if version == 'Simple Lines':
                output = model2(input_tensor)
            else:
                output = model1(input_tensor)

        # Sigmoid output is already in [0, 1]; clamp defensively before
        # converting the squeezed single-channel tensor to a PIL image.
        output_img = transforms.ToPILImage()(output.squeeze().cpu().clamp(0, 1))

        # Encode to JPEG entirely in memory and stream it back.
        buffer = io.BytesIO()
        output_img.save(buffer, format="JPEG")
        buffer.seek(0)

        return StreamingResponse(buffer, media_type="image/jpeg")

    except Exception as e:
        # Surface the failure reason to the client as a 500.
        raise HTTPException(status_code=500, detail=str(e))