ksangk commited on
Commit
b62c832
·
1 Parent(s): 862fe37

update hf download

Browse files
Files changed (2) hide show
  1. app.py +11 -15
  2. requirements.txt +1 -1
app.py CHANGED
@@ -7,6 +7,7 @@ import copy
7
  from omegaconf import OmegaConf
8
  from torchvision.transforms import v2
9
  from torchvision.transforms.functional import to_pil_image
 
10
 
11
  from chord import ChordModel
12
  from chord.module import make
@@ -26,6 +27,8 @@ EXAMPLES_USECASE_3 = [
26
  ]
27
 
28
  MODEL_OBJ = None
 
 
29
  def load_model(ckpt_path):
30
  print("Loading model from:", ckpt_path)
31
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -43,7 +46,6 @@ def run_model(model, img: Image.Image):
43
  x = v2.Resize(size=(1024, 1024), antialias=True)(image).unsqueeze(0)
44
  with torch.no_grad() as no_grad, torch.autocast(device_type="cuda") as amp:
45
  output = model(x)
46
- output.update({"input": image})
47
  return output
48
 
49
  def relit(model, maps):
@@ -60,12 +62,12 @@ def relit(model, maps):
60
  rgb = model.model.compute_render(maps, camera, pos, light).squeeze(0).permute(0,3,1,2) # GxBxHxWxC -> BxCxHxW
61
  return torch.clamp(rgb_to_srgb(rgb), 0, 1)
62
 
63
- def inference(img, ckpt_path):
64
  global MODEL_OBJ
65
 
66
- if MODEL_OBJ is None or getattr(MODEL_OBJ, "_ckpt", None) != ckpt_path:
67
- MODEL_OBJ = load_model(ckpt_path)
68
- MODEL_OBJ._ckpt = ckpt_path # store path inside object
69
 
70
  if img is None:
71
  return None, None, None, None, None
@@ -85,14 +87,9 @@ def inference(img, ckpt_path):
85
 
86
  with gr.Blocks(title="Chord") as demo:
87
 
88
- gr.Markdown("# **Chord: Chain of Rendering Decomposition for PBR Material Estimation from Generated Texture images**")
89
- ckpt_path = gr.Textbox(
90
- label="Model Checkpoint Path",
91
- value="chord_v1.ckpt",
92
- placeholder="Path to your model checkpoint",
93
- )
94
- gr.Markdown("Upload an image or select an example to estimate PBR channels and render the result under custom lighting.")
95
-
96
  with gr.Row():
97
  with gr.Column():
98
  input_img = gr.Image(type="pil", label="Input Image", height=512)
@@ -132,10 +129,9 @@ with gr.Blocks(title="Chord") as demo:
132
 
133
  run_button.click(
134
  inference,
135
- inputs=[input_img, ckpt_path],
136
  outputs=[basecolor_out, normal_out, roughness_out, metallic_out, render_out]
137
  )
138
 
139
-
140
  if __name__ == "__main__":
141
  demo.launch()
 
7
  from omegaconf import OmegaConf
8
  from torchvision.transforms import v2
9
  from torchvision.transforms.functional import to_pil_image
10
+ from huggingface_hub import hf_hub_download
11
 
12
  from chord import ChordModel
13
  from chord.module import make
 
27
  ]
28
 
29
  MODEL_OBJ = None
30
+ #MODEL_CKPT_PATH = hf_hub_download(repo_id="Ubisoft/ubisoft-laforge-chord", filename="chord_v1.ckpt")
31
+ MODEL_CKPT_PATH = hf_hub_download(repo_id="ksangk/Chord-V1-ckpt", filename="chord_v1.ckpt")
32
  def load_model(ckpt_path):
33
  print("Loading model from:", ckpt_path)
34
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
46
  x = v2.Resize(size=(1024, 1024), antialias=True)(image).unsqueeze(0)
47
  with torch.no_grad() as no_grad, torch.autocast(device_type="cuda") as amp:
48
  output = model(x)
 
49
  return output
50
 
51
  def relit(model, maps):
 
62
  rgb = model.model.compute_render(maps, camera, pos, light).squeeze(0).permute(0,3,1,2) # GxBxHxWxC -> BxCxHxW
63
  return torch.clamp(rgb_to_srgb(rgb), 0, 1)
64
 
65
+ def inference(img):
66
  global MODEL_OBJ
67
 
68
+ if MODEL_OBJ is None or getattr(MODEL_OBJ, "_ckpt", None) != MODEL_CKPT_PATH:
69
+ MODEL_OBJ = load_model(MODEL_CKPT_PATH)
70
+ MODEL_OBJ._ckpt = MODEL_CKPT_PATH # store path inside object
71
 
72
  if img is None:
73
  return None, None, None, None, None
 
87
 
88
  with gr.Blocks(title="Chord") as demo:
89
 
90
+ gr.Markdown("# **Chord: Chain of Rendering Decomposition for PBR Material Estimation from Generated Texture Images**")
91
+ gr.Markdown("Upload an image or select an example to estimate PBR channels.")
92
+
 
 
 
 
 
93
  with gr.Row():
94
  with gr.Column():
95
  input_img = gr.Image(type="pil", label="Input Image", height=512)
 
129
 
130
  run_button.click(
131
  inference,
132
+ inputs=[input_img],
133
  outputs=[basecolor_out, normal_out, roughness_out, metallic_out, render_out]
134
  )
135
 
 
136
  if __name__ == "__main__":
137
  demo.launch()
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- huggingface_hub
2
  diffusers
3
  transformers
4
  typer
 
1
+ huggingface_hub[hf_xet]
2
  diffusers
3
  transformers
4
  typer