maripau22 commited on
Commit
06d8d2f
·
verified ·
1 Parent(s): 757d498

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -19
app.py CHANGED
@@ -3,64 +3,68 @@ import torch
3
  from torchvision import transforms
4
  from PIL import Image
5
 
6
-
7
  model_paths = {
8
- "All colors model": "unet_generator.pt",
9
- "20 colors model": "20color_generator.pt"
10
  }
11
 
12
-
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
 
 
15
  transform = transforms.Compose([
16
  transforms.Resize((512, 512)),
17
  transforms.ToTensor(),
18
  ])
19
 
 
20
  def load_model(path):
21
  model = torch.jit.load(path, map_location=device)
22
  model.eval()
23
  return model
24
 
25
-
26
  def colorize(image, selected_model):
27
  """
28
- Converts the image to greyscale, it shows it to the user and then it generates the colored version depending on the model selected.
 
29
  """
30
- #in greyscale
31
  gray = image.convert("L")
32
 
33
- # preprocess
34
  gray_tensor = transform(gray).unsqueeze(0).to(device)
35
 
36
- # model selected
37
  model = load_model(model_paths[selected_model])
38
 
39
- #colors the image
40
  with torch.no_grad():
41
  output = model(gray_tensor)
42
 
 
43
  output = output.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
44
  output_image = Image.fromarray((output * 255).astype('uint8'))
45
 
46
- return gray, output_image
47
 
 
48
  gr.Interface(
49
  fn=colorize,
50
  inputs=[
51
- gr.Image(type="pil", label="Input Umage"),
52
- gr.Radio(choices=["All colors", "20 colors"], label="Model")
53
  ],
54
  outputs=[
55
- gr.Image(type="pil", label="Black & White"),
56
- gr.Image(type="pil", label="Colored Image")
57
  ],
58
- title="Image colorization",
59
  description=(
60
- "Upload a color image and choose a model to see it colorized from a greyscale version. "
61
  "The system first converts the input image to black and white, then uses a trained deep learning model "
62
- "to generate a colored version. You can experiment with two models: one trained on a full color palette "
63
  "and another limited to just 20 colors."
64
  )
65
-
66
  ).launch()
 
3
  from torchvision import transforms
4
  from PIL import Image
5
 
6
+ # Path to your exported TorchScript models (.pt)
7
  model_paths = {
8
+ "All colors": "unet_generator.pt",
9
+ "20 colors only": "20color_generator.pt"
10
  }
11
 
12
+ # Check if a GPU is available
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
 
15
+ # Image transformations (resize and convert to tensor)
16
  transform = transforms.Compose([
17
  transforms.Resize((512, 512)),
18
  transforms.ToTensor(),
19
  ])
20
 
21
+ # Function to load the selected model
22
  def load_model(path):
23
  model = torch.jit.load(path, map_location=device)
24
  model.eval()
25
  return model
26
 
27
+ # Main colorization function
28
  def colorize(image, selected_model):
29
  """
30
+ Converts the input image to grayscale, displays it,
31
+ and generates the colorized version using the selected model.
32
  """
33
+ # Convert to grayscale
34
  gray = image.convert("L")
35
 
36
+ # Preprocess for model input
37
  gray_tensor = transform(gray).unsqueeze(0).to(device)
38
 
39
+ # Load the selected model
40
  model = load_model(model_paths[selected_model])
41
 
42
+ # Generate the colorized image
43
  with torch.no_grad():
44
  output = model(gray_tensor)
45
 
46
+ # Process output and convert to PIL image
47
  output = output.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
48
  output_image = Image.fromarray((output * 255).astype('uint8'))
49
 
50
+ return gray, output_image # Return grayscale and colorized images
51
 
52
+ # Create Gradio interface
53
  gr.Interface(
54
  fn=colorize,
55
  inputs=[
56
+ gr.Image(type="pil", label="Input Image"),
57
+ gr.Radio(choices=["All colors", "20 colors only"], label="Model")
58
  ],
59
  outputs=[
60
+ gr.Image(type="pil", label="Grayscale Image"),
61
+ gr.Image(type="pil", label="Colorized Image")
62
  ],
63
+ title="Image Colorization",
64
  description=(
65
+ "Upload a color image and choose a model to see it colorized from a grayscale version. "
66
  "The system first converts the input image to black and white, then uses a trained deep learning model "
67
+ "to generate a colorized version. You can experiment with two models: one trained on a full color palette "
68
  "and another limited to just 20 colors."
69
  )
 
70
  ).launch()