Ashrafb commited on
Commit
e7a579e
·
verified ·
1 Parent(s): eb260ff

Rename app.py to main.py

Browse files
Files changed (2) hide show
  1. app.py +0 -73
  2. main.py +34 -0
app.py DELETED
@@ -1,73 +0,0 @@
1
- import gradio as gr
2
- import PIL
3
- import cv2
4
- import numpy as np
5
- from src.deoldify import device
6
- from src.deoldify.device_id import DeviceId
7
- from src.deoldify.visualize import *
8
- from src.app_utils import get_model_bin
9
-
10
- device.set(device=DeviceId.CPU)
11
-
12
- def load_model(model_dir, option):
13
- if option.lower() == 'artistic':
14
- model_url = 'https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth'
15
- get_model_bin(model_url, os.path.join(model_dir, "ColorizeArtistic_gen.pth"))
16
- colorizer = get_image_colorizer(artistic=True)
17
- elif option.lower() == 'stable':
18
- model_url = "https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=0"
19
- get_model_bin(model_url, os.path.join(model_dir, "ColorizeStable_gen.pth"))
20
- colorizer = get_image_colorizer(artistic=False)
21
-
22
- return colorizer
23
-
24
- def resize_img(input_img, max_size):
25
- img = input_img.copy()
26
- img_height, img_width = img.shape[0], img.shape[1]
27
-
28
- if max(img_height, img_width) > max_size:
29
- if img_height > img_width:
30
- new_width = img_width * (max_size / img_height)
31
- new_height = max_size
32
- resized_img = cv2.resize(img, (int(new_width), int(new_height)))
33
- return resized_img
34
- elif img_height <= img_width:
35
- new_width = img_height * (max_size / img_width)
36
- new_height = max_size
37
- resized_img = cv2.resize(img, (int(new_width), int(new_height)))
38
- return resized_img
39
-
40
- return img
41
-
42
- def colorize_image(input_image, colorizer, img_size=800):
43
- pil_img = input_image.convert("RGB")
44
- img_rgb = np.array(pil_img)
45
- resized_img_rgb = resize_img(img_rgb, img_size)
46
- resized_pil_img = PIL.Image.fromarray(resized_img_rgb)
47
- output_pil_img = colorizer.plot_transformed_pil_image(resized_pil_img, render_factor=35, compare=False)
48
-
49
- return output_pil_img
50
-
51
- def app(input_image, model='stable'):
52
- # Load models
53
- colorizer = load_model('models/', model)
54
-
55
- # Colorize the image
56
- output_image = colorize_image(input_image, colorizer)
57
-
58
- return output_image
59
-
60
-
61
-
62
- title = "<span style='color: #191970;'>Aiconvert.online</span>"
63
-
64
- gr.Interface(
65
- app,
66
- inputs=[gr.Image(type="pil", label="Input"), gr.Dropdown(["Artistic", "Stable"], label="Model")],
67
- outputs=gr.Image(type="pil", label="Output", show_share_button=False),
68
- title=title,
69
- css="footer{display:none !important;}",
70
- theme=gr.themes.Base(),
71
- enable_queue=True,
72
- allow_flagging=False
73
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Based on: https://github.com/jantic/DeOldify
2
+ import os, re, time
3
+
4
+ os.environ["TORCH_HOME"] = os.path.join(os.getcwd(), ".cache")
5
+ os.environ["XDG_CACHE_HOME"] = os.path.join(os.getcwd(), ".cache")
6
+
7
+ fastapi import FastAPI, File, UploadFile,Form
8
+ from fastapi.responses import FileResponse, StreamingResponse
9
+ from fastapi.staticfiles import StaticFiles
10
+ from src.deoldify import device
11
+ from src.deoldify.device_id import DeviceId
12
+ from src.app_utils import get_model_bin
13
+ from colorize_image_function import colorize_image
14
+
15
+ app = FastAPI()
16
+
17
+ device.set(device=DeviceId.CPU)
18
+ model_dir = 'models/'
19
+ colorizer = load_model(model_dir, "Artistic")
20
+
21
+ @app.post("/upload/")
22
+ async def upload_file(file: UploadFile = File(...)):
23
+ contents = await file.read()
24
+ img_input = PIL.Image.open(BytesIO(contents)).convert("RGB")
25
+ img_output = colorize_image(img_input)
26
+ img_output_bytes = io.BytesIO()
27
+ img_output.save(img_output_bytes, format="JPEG")
28
+ return img_output_bytes.getvalue()
29
+
30
+ app.mount("/", StaticFiles(directory="static", html=True), name="static")
31
+
32
+ @app.get("/")
33
+ def index() -> FileResponse:
34
+ return FileResponse(path="/app/static/index.html", media_type="text/html")