degbo committed on
Commit
222734d
·
verified ·
1 Parent(s): 053509f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +167 -4
app.py CHANGED
@@ -1,7 +1,170 @@
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
1
+ import os
2
+ import sys
3
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
4
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
5
+
6
+
7
  import gradio as gr
8
+ import numpy as np
9
+ import torch
10
+ from marigold import MarigoldIIDOutput, MarigoldIIDPipeline
11
+ from src.util.image_util import read_img_from_file, img_hwc2chw, img_linear2srgb, is_hdr
12
+ from marigold.util.image_util import float2int
13
+ from src.util.seeding import seed_all
14
+ import logging
15
+
16
# --- Runtime configuration --------------------------------------------------
# Root directory holding one sub-directory per model checkpoint.
checkpoint_base = "cvpr/checkpoints"

# Fixed seed so repeated demo runs produce identical outputs.
seed = 1234
seed_all(seed)

# Prefer the GPU when present; otherwise fall back to CPU with a warning,
# since diffusion inference on CPU is very slow.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _use_cuda else "cpu")
if not _use_cuda:
    logging.warning("CUDA is not available. Running on CPU will be slow.")
24
+
25
def get_available_models(base_dir=None):
    """Discover checkpoint variants available on disk.

    Scans ``base_dir`` for the two-level layout ``<model_name>/<variant>``
    (directories only) and reports each pair as ``"model_name/variant"``.

    Args:
        base_dir: Root checkpoints directory. Defaults to the module-level
            ``checkpoint_base`` when omitted. The sentinel is used instead of
            an eager default so the constant is resolved at call time, not at
            import time.

    Returns:
        Sorted list of ``"model_name/variant"`` strings. Returns an empty
        list when the root directory does not exist (robustness fix: the
        previous version raised ``FileNotFoundError`` from ``os.listdir``,
        masking the clearer "No models found" error raised by the caller).
    """
    if base_dir is None:
        base_dir = checkpoint_base
    if not os.path.isdir(base_dir):
        # Missing root == "no models"; let the caller produce the error.
        return []
    models = []
    for model_name in os.listdir(base_dir):
        model_path = os.path.join(base_dir, model_name)
        if not os.path.isdir(model_path):
            continue  # skip stray files in the checkpoints root
        for variant in os.listdir(model_path):
            if os.path.isdir(os.path.join(model_path, variant)):
                models.append(f"{model_name}/{variant}")
    return sorted(models)
34
+
35
# Discover checkpoints once at startup; the demo cannot run without them.
available_models = get_available_models()
if not available_models:
    raise RuntimeError(f"No models found in {checkpoint_base}")

# Cache of instantiated pipelines keyed by "model_name/variant", so switching
# back to a previously used model does not reload it from disk.
loaded_models = {}

# Intrinsic channels a user may request from the prompt-conditioned
# ("rgbx") models.
prompts = [
    "Albedo (diffuse basecolor)",
    "Camera-space Normal",
    "Roughness",
    "Metallicness",
    "Irradiance (diffuse lighting)",
]
42
+
43
def get_demo():
    """Build and return the Gradio Blocks UI for the decomposition demo.

    Returns:
        gr.Blocks: the assembled interface; the caller is responsible for
        queueing and launching it.
    """

    def callback(photo, inference_step, selected_model, selected_prompt, processing_res):
        """Run one inference pass on ``photo``.

        Args:
            photo: Path to the uploaded image (gr.Image with type="filepath").
            inference_step: Number of denoising steps.
            selected_model: "model_name/variant" key under ``checkpoint_base``.
            selected_prompt: Intrinsic-channel prompt (used by rgbx models only).
            processing_res: Processing resolution; 0 presumably means native
                resolution — confirm against the pipeline implementation.

        Returns:
            Tuple of (png_path, npy_path, generated_image ndarray).
        """
        checkpoint_path = os.path.join(checkpoint_base, selected_model)
        # rgbx checkpoints are prompt-conditioned; other models ignore the prompt.
        if "rgbx" in selected_model:
            mode = "rgbx"
            prompt = selected_prompt
        else:
            mode = "other"
            prompt = None

        # Lazily load and cache the pipeline for the selected checkpoint.
        if selected_model not in loaded_models:
            pipe = MarigoldIIDPipeline.from_pretrained(
                checkpoint_path, variant=None, torch_dtype=torch.float32
            ).to(device)
            pipe.mode = mode
            loaded_models[selected_model] = pipe
        else:
            pipe = loaded_models[selected_model]

        # Re-seed per call so repeated runs are deterministic.
        generator = torch.Generator(device=device)
        generator.manual_seed(seed)

        # --- Input preparation ---
        img = read_img_from_file(photo)
        if len(img.shape) == 3:
            img = img_hwc2chw(img)  # HWC -> CHW, as the pipeline expects
        if is_hdr(photo):
            img = img_linear2srgb(img)  # tone-map HDR input to sRGB
        if img.shape[0] == 4:
            img = img[:3, :, :]  # drop alpha channel
        rgb_float = torch.from_numpy(img).float()
        input_image = float2int(rgb_float).unsqueeze(0)  # add batch dimension

        if "rgbx" in selected_model:
            pipe.prompt = prompt

        pipe_out: MarigoldIIDOutput = pipe(
            input_image,
            denoising_steps=inference_step,
            ensemble_size=1,
            processing_res=processing_res,
            match_input_res=1,
            batch_size=0,
            show_progress_bar=False,
            resample_method="bilinear",
            generator=generator,
        )
        target_pred = pipe_out["albedo"].array
        # BUG FIX: `prompt` is None for every non-rgbx model, and
        # `"..." in None` raises TypeError. Guard before the substring test.
        if prompt is not None and ("Metallicness" in prompt or "Roughness" in prompt):
            # Scalar maps are single-channel; tile to 3 channels for display.
            target_pred = np.repeat(target_pred[0:1, :], 3, axis=0)
        generated_image = target_pred.transpose(1, 2, 0)  # CHW -> HWC
        if generated_image.dtype != np.uint8:
            generated_image = np.clip(generated_image, 0, 1)
            generated_image = (generated_image * 255).astype(np.uint8)

        # BUG FIX: ensure the output directory exists before saving into it.
        os.makedirs("tmp", exist_ok=True)
        npy_path = os.path.join("tmp", "target_pred.npy")
        np.save(npy_path, target_pred)

        from PIL import Image  # local import keeps module import-time light

        png_path = os.path.join("tmp", "target_pred.png")
        Image.fromarray(generated_image).save(png_path)

        return png_path, npy_path, generated_image

    block = gr.Blocks()
    with block:
        with gr.Row():
            gr.Markdown("## OSU albedo demo")
        with gr.Row():
            # Input side
            with gr.Column():
                gr.Markdown("### Given Image")
                photo = gr.Image(label="Photo", type="filepath")

                gr.Markdown("### Parameters")
                run_button = gr.Button(value="Run")
                with gr.Accordion("Advanced options", open=False):
                    inference_step = gr.Slider(
                        label="Inference Step",
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=4,
                    )
                    processing_res = gr.Number(
                        value=0,
                        label="Processing Resolution (processing_res)",
                        precision=0,
                    )

                gr.Markdown("### Select Model")
                model_selector = gr.Dropdown(
                    label="Checkpoint",
                    choices=available_models,
                    value=available_models[0],
                )

                gr.Markdown("### Select Prompt (only for rgbx models)")
                prompt_selector = gr.Dropdown(
                    label="Prompts",
                    choices=prompts,
                    value=prompts[0],
                )

            # Output side
            with gr.Column():
                gr.Markdown("### Output Gallery")
                result_image = gr.Image(label="Output Image", interactive=False)
                result_png = gr.File(label="Download Generated Image (.png)")
                result_npy = gr.File(label="Download Target Albedo (.npy)")

        inputs = [
            photo,
            inference_step,
            model_selector,
            prompt_selector,
            processing_res,
        ]
        outputs = [result_png, result_npy, result_image]
        run_button.click(fn=callback, inputs=inputs, outputs=outputs, queue=True)

    return block
165
 
 
 
166
 
167
# Script entry point: build the UI and serve it.
if __name__ == "__main__":
    demo = get_demo()
    # Allow at most one pending request in the queue at a time.
    demo.queue(max_size=1)
    # share=True requests a public share link from Gradio in addition to
    # the local server.
    demo.launch(share=True)