# Ved Gupta
# initial commit
# 77b60c4
# raw
# history blame
# 1.58 kB
import warnings
warnings.filterwarnings("ignore")
import os
import sys
import glob
import time
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb
import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
from utility import *
from model import *
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Location of the serialized colorization checkpoint, relative to the CWD.
model_path = "model/ImageColorizationModel.pth"

# Stays None only if no checkpoint is available after the download attempt.
model = None

# First run: the checkpoint is not on disk yet, so fetch it.
if not os.path.exists(model_path):
    print("Model not find")
    download_from_drive()
    print("Model Downloaded")

# Load the checkpoint whether it was already present or has just been
# downloaded. Previously the download branch never loaded it, leaving
# `model` as None and breaking every prediction on a fresh install.
# NOTE(review): assumes download_from_drive() writes to `model_path` --
# confirm; the existence re-check keeps the old behavior if it does not.
if os.path.exists(model_path):
    model = load_model(model_class=MainModel, file_path=model_path)
    print("Model Loaded")
def predict_and_return_image(image):
    """Colorize one uploaded image using the module-level ``model``.

    Args:
        image: The value delivered by the Gradio image input, or ``None``
            when the form is submitted without an image.

    Returns:
        The first element of the result of ``lab_to_rgb`` applied to the
        model's ``L`` input and its predicted ``fake_color``, or ``None``
        when no image was supplied.
    """
    # An empty image input arrives as None; return None (empty output)
    # instead of crashing inside the preprocessing helper.
    if image is None:
        return None

    # NOTE(review): presumably builds the L/ab tensor dict the model expects
    # -- confirm against utility.create_lab_tensors.
    data = create_lab_tensors(image)

    # Switch the generator to evaluation mode before inference.
    model.net_G.eval()

    # Inference only: no autograd bookkeeping is needed for the forward pass.
    with torch.no_grad():
        model.setup_input(data)
        model.forward()

    fake_color = model.fake_color.detach()
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)

    # Only the first converted image is returned.
    return fake_imgs[0]
# Text shown at the top of the Gradio page.
title = "Black&White to Color image"
description = "Transforming Black & White Image in to colored image. Upload a black and white image to see it colorized by our deep learning model."

# NOTE(review): `gr` is not imported in the visible import lines; presumably
# it arrives through one of the star imports (utility / model) -- confirm.
demo = gr.Interface(
    fn=predict_and_return_image,
    title=title,
    description=description,
    inputs=[gr.Image(label="Gray Scale Image")],
    outputs=[gr.Image(label="Predicted Colored Image")],
)

# Start the web UI with sharing and debug mode switched on.
demo.launch(share=True, debug=True)