Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,3 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import fastai
|
| 2 |
from fastai.vision import *
|
| 3 |
from fastai.utils.mem import *
|
|
@@ -21,6 +27,8 @@ from PIL import Image
|
|
| 21 |
from io import BytesIO
|
| 22 |
import torchvision.transforms as T
|
| 23 |
|
|
|
|
|
|
|
| 24 |
class FeatureLoss(nn.Module):
|
| 25 |
def __init__(self, m_feat, layer_ids, layer_wgts):
|
| 26 |
super().__init__()
|
|
@@ -94,10 +102,33 @@ def predict(img):
|
|
| 94 |
|
| 95 |
return res, output_file
|
| 96 |
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, UploadFile, File
|
| 2 |
+
from fastapi.responses import FileResponse
|
| 3 |
+
from fastapi.responses import HTMLResponse, FileResponse
|
| 4 |
+
from fastapi.staticfiles import StaticFiles
|
| 5 |
+
from typing import Tuple
|
| 6 |
+
import cv2
|
| 7 |
import fastai
|
| 8 |
from fastai.vision import *
|
| 9 |
from fastai.utils.mem import *
|
|
|
|
| 27 |
from io import BytesIO
|
| 28 |
import torchvision.transforms as T
|
| 29 |
|
| 30 |
+
app = FastAPI()
|
| 31 |
+
|
| 32 |
class FeatureLoss(nn.Module):
|
| 33 |
def __init__(self, m_feat, layer_ids, layer_wgts):
|
| 34 |
super().__init__()
|
|
|
|
| 102 |
|
| 103 |
return res, output_file
|
| 104 |
|
| 105 |
+
@app.post("/predict/")
|
| 106 |
+
async def predict(file: UploadFile = File(...)) -> Tuple[str, bytes]:
|
| 107 |
+
contents = await file.read()
|
| 108 |
+
img = cv2.imdecode(np.fromstring(contents, np.uint8), cv2.IMREAD_COLOR)
|
| 109 |
+
img = PIL.Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
| 110 |
+
img = add_margin(img, 250, 250, 250, 250, (255, 255, 255))
|
| 111 |
+
img = np.array(img)
|
| 112 |
+
|
| 113 |
+
h, w = img.shape[:-1]
|
| 114 |
+
cv2.imwrite("test.jpg", img)
|
| 115 |
+
img_test = open_image("test.jpg")
|
| 116 |
+
|
| 117 |
+
p,img_hr,b = learn.predict(img_test)
|
| 118 |
+
|
| 119 |
+
res = (img_hr / img_hr.max()).numpy()
|
| 120 |
+
res = res[0] # take only first channel as result
|
| 121 |
+
res = cv2.resize(res, (w,h))
|
| 122 |
+
|
| 123 |
+
output_file = get_filename()
|
| 124 |
+
|
| 125 |
+
cv2.imwrite(output_file, (res * 255).astype(np.uint8), [cv2.IMWRITE_JPEG_QUALITY, 50])
|
| 126 |
+
|
| 127 |
+
return output_file, res.tobytes()
|
| 128 |
+
|
| 129 |
+
app.mount("/", StaticFiles(directory="static", html=True), name="static")
|
| 130 |
+
|
| 131 |
+
@app.get("/")
|
| 132 |
+
def index() -> FileResponse:
|
| 133 |
+
return FileResponse(path="/app/static/index.html", media_type="text/html")
|
| 134 |
+
|