Ashrafb commited on
Commit
6a5e9a4
·
verified ·
1 Parent(s): b69bd26

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +2 -128
main.py CHANGED
@@ -1,130 +1,4 @@
1
- from fastapi import FastAPI, File, UploadFile, Form
2
- from fastapi.responses import FileResponse, StreamingResponse
3
- from fastapi.staticfiles import StaticFiles
4
- import io
5
- import torch
6
- import torch.nn as nn
7
- from PIL import Image
8
- import torchvision.transforms as transforms
9
 
10
- app = FastAPI()
11
-
12
- norm_layer = nn.InstanceNorm2d
13
-
14
- class ResidualBlock(nn.Module):
15
- def __init__(self, in_features):
16
- super(ResidualBlock, self).__init__()
17
-
18
- conv_block = [ nn.ReflectionPad2d(1),
19
- nn.Conv2d(in_features, in_features, 3),
20
- norm_layer(in_features),
21
- nn.ReLU(inplace=True),
22
- nn.ReflectionPad2d(1),
23
- nn.Conv2d(in_features, in_features, 3),
24
- norm_layer(in_features)
25
- ]
26
-
27
- self.conv_block = nn.Sequential(*conv_block)
28
-
29
- def forward(self, x):
30
- return x + self.conv_block(x)
31
-
32
-
33
- class Generator(nn.Module):
34
- def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
35
- super(Generator, self).__init__()
36
-
37
- # Initial convolution block
38
- model0 = [ nn.ReflectionPad2d(3),
39
- nn.Conv2d(input_nc, 64, 7),
40
- norm_layer(64),
41
- nn.ReLU(inplace=True) ]
42
- self.model0 = nn.Sequential(*model0)
43
-
44
- # Downsampling
45
- model1 = []
46
- in_features = 64
47
- out_features = in_features*2
48
- for _ in range(2):
49
- model1 += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
50
- norm_layer(out_features),
51
- nn.ReLU(inplace=True) ]
52
- in_features = out_features
53
- out_features = in_features*2
54
- self.model1 = nn.Sequential(*model1)
55
-
56
- model2 = []
57
- # Residual blocks
58
- for _ in range(n_residual_blocks):
59
- model2 += [ResidualBlock(in_features)]
60
- self.model2 = nn.Sequential(*model2)
61
-
62
- # Upsampling
63
- model3 = []
64
- out_features = in_features//2
65
- for _ in range(2):
66
- model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
67
- norm_layer(out_features),
68
- nn.ReLU(inplace=True) ]
69
- in_features = out_features
70
- out_features = in_features//2
71
- self.model3 = nn.Sequential(*model3)
72
-
73
- # Output layer
74
- model4 = [ nn.ReflectionPad2d(3),
75
- nn.Conv2d(64, output_nc, 7)]
76
- if sigmoid:
77
- model4 += [nn.Sigmoid()]
78
-
79
- self.model4 = nn.Sequential(*model4)
80
-
81
- def forward(self, x, cond=None):
82
- out = self.model0(x)
83
- out = self.model1(out)
84
- out = self.model2(out)
85
- out = self.model3(out)
86
- out = self.model4(out)
87
-
88
- return out
89
-
90
- model1 = Generator(3, 1, 3)
91
- model1.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
92
- model1.eval()
93
-
94
- model2 = Generator(3, 1, 3)
95
- model2.load_state_dict(torch.load('model2.pth', map_location=torch.device('cpu')))
96
- model2.eval()
97
-
98
- transform = transforms.Compose([transforms.Resize(256, Image.BICUBIC), transforms.ToTensor()])
99
-
100
- from fastapi.middleware.cors import CORSMiddleware
101
-
102
- app.add_middleware(
103
- CORSMiddleware,
104
- allow_origins=["*"], # Adjust as needed, '*' allows requests from any origin
105
- allow_credentials=True,
106
- allow_methods=["*"],
107
- allow_headers=["*"],
108
- )
109
-
110
- @app.post("/predict")
111
- async def predict(file: UploadFile = File(...), version: str = Form(...)):
112
- contents = await file.read()
113
- input_img = Image.open(io.BytesIO(contents))
114
- transform = transforms.Compose([transforms.Resize(256, Image.BICUBIC), transforms.ToTensor()])
115
- input_img = transform(input_img)
116
- input_img = torch.unsqueeze(input_img, 0)
117
-
118
- drawing = 0
119
- with torch.no_grad():
120
- if version == 'Simple Lines':
121
- drawing = model2(input_img)[0].detach()
122
- else:
123
- drawing = model1(input_img)[0].detach()
124
-
125
- drawing_pil = transforms.ToPILImage()(drawing)
126
- img_byte_array = io.BytesIO()
127
- drawing_pil.save(img_byte_array, format="PNG")
128
- img_byte_array.seek(0) # Reset file pointer to start of the file
129
- return StreamingResponse(io.BytesIO(img_byte_array.getvalue()), media_type="image/png")
130
 
 
 
1
+ import os
 
 
 
 
 
 
 
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
+ exec(os.environ.get('CODE'))