PubAccount commited on
Commit
dec68f2
ยท
verified ยท
1 Parent(s): 33a707b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +317 -10
app.py CHANGED
@@ -1,14 +1,321 @@
1
- import gradio as gr
2
- import spaces
 
 
 
 
 
3
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- zero = torch.Tensor([0]).cuda()
6
- print(zero.device) # <-- 'cpu' ๐Ÿค”
 
 
 
7
 
8
- @spaces.GPU
9
- def greet(n):
10
- print(zero.device) # <-- 'cuda:0' ๐Ÿค—
11
- return f"Hello {zero + n} Tensor"
 
 
 
 
 
 
 
12
 
13
- demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
14
- demo.launch()
 
1
+ """
2
+ app.py โ€” HeightAdaptor Hugging Face Spaces App
3
+ Backbone : stable-diffusion-v1-5/stable-diffusion-v1-5
4
+ Adaptor : UEXdo/HeightAdaptor-weight
5
+ """
6
+
7
+ import os, io
8
  import torch
9
+ import numpy as np
10
+ import matplotlib; matplotlib.use("Agg")
11
+ import matplotlib.pyplot as plt
12
+ from PIL import Image
13
+ from torch.nn import functional as F
14
+ from diffusers import StableDiffusionPipeline
15
+ from huggingface_hub import snapshot_download
16
+ from peft import PeftModel
17
+ import gradio as gr
18
+
19
+ # โ”€โ”€ ZeroGPU compatibility๏ผˆๆ—  spaces ๅบ“ๆ—ถ่‡ชๅŠจ้™็บง๏ผ‰โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
20
+ try:
21
+ import spaces
22
+ except ImportError:
23
+ class spaces:
24
+ @staticmethod
25
+ def GPU(duration=120):
26
+ return lambda fn: fn
27
+
28
+ from networks.semantic_head import SemanticHead
29
+ from networks.height_head import HeightHead
30
+ from networks.decoder import Decoder
31
+
32
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
33
+ # ๅธธ้‡ & ้…็ฝฎ
34
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
35
+ RGB_LATENT_SCALE = 0.18215
36
+
37
+ # ้€š่ฟ‡็Žฏๅขƒๅ˜้‡ๅฏ่ฆ†็›–๏ผŒๅฆๅˆ™ไฝฟ็”จ้ป˜่ฎค HF Repo ID
38
+ SD_MODEL_ID = os.environ.get("SD_MODEL_ID", "stable-diffusion-v1-5/stable-diffusion-v1-5")
39
+ ADAPTOR_REPO = os.environ.get("ADAPTOR_MODEL_ID", "UEXdo/HeightAdaptor-weight")
40
+
41
+ DATASET_CFG = {
42
+ "OpenDC": {"classes_num": 8},
43
+ "US3D": {"classes_num": 6},
44
+ }
45
+
46
+ LABEL_COLORS = {
47
+ "OpenDC": {
48
+ 0: (50,125,0), 1: (255,0,0), 2: (0,255,0), 3: (255,0,0),
49
+ 4: (255,255,0), 5: (255,255,255), 6: (0,255,255), 7: (0,0,0),
50
+ },
51
+ "US3D": {
52
+ 0: (0,0,0), 1: (0,0,0), 2: (255,0,0),
53
+ 3: (0,255,0), 4: (0,0,255), 5: (255,255,0),
54
+ },
55
+ }
56
+
57
+ TASK_PROMPTS = {
58
+ "Height Estimation": "Image to height map",
59
+ "Semantic Segmentation": "Image to semantic segmentation",
60
+ }
61
+
62
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
63
+ # ๅฏๅŠจๆ—ถไธ‹่ฝฝ Adaptor ๆƒ้‡๏ผˆ็ผ“ๅญ˜ๅˆฐๆœฌๅœฐ๏ผŒๅŽ็ปญๆ— ้œ€้‡ๅคไธ‹่ฝฝ๏ผ‰
64
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
65
+ print(f"๐Ÿ“ฆ Downloading adaptor weights from {ADAPTOR_REPO} ...")
66
+ ADAPTOR_DIR = snapshot_download(repo_id=ADAPTOR_REPO)
67
+ print(f"โœ… Weights cached at: {ADAPTOR_DIR}")
68
+
69
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
70
+ # ๆจกๅž‹็ฎก็†๏ผˆไธป่ฟ›็จ‹็ปดๆŠค CPU ๆจกๅž‹๏ผŒGPU ๅญ่ฟ›็จ‹ copy-on-use๏ผ‰
71
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
72
+ _model = None
73
+ _model_key = None # (dataset_name, h_type)
74
+
75
+
76
+ def build_model(dataset_name: str, h_type: str) -> StableDiffusionPipeline:
77
+ """ไปŽ HF Hub ๆ‹‰ๅ–ๅŸบ็ก€ๆจกๅž‹๏ผŒๅ ๅŠ  LoRA + ไธ‰ไธช่‡ชๅฎšไน‰ Head๏ผŒ่ฟ”ๅ›ž CPU ๆจกๅž‹ใ€‚"""
78
+ classes_num = DATASET_CFG[dataset_name]["classes_num"]
79
+ print(f"๐Ÿ”ง Building model โ€” dataset={dataset_name}, h_type={h_type}")
80
+
81
+ # 1. ๅŠ ่ฝฝ SD v1.5 ๅŸบ็ก€ Pipeline
82
+ pipe = StableDiffusionPipeline.from_pretrained(
83
+ SD_MODEL_ID,
84
+ torch_dtype=torch.float32,
85
+ safety_checker=None,
86
+ requires_safety_checker=False,
87
+ )
88
+
89
+ # 2. ็”จ PEFT ๆŠŠ LoRA ๆƒ้‡ๆณจๅ…ฅ UNet
90
+ pipe.unet = PeftModel.from_pretrained(
91
+ pipe.unet,
92
+ os.path.join(ADAPTOR_DIR, "lora"),
93
+ )
94
+
95
+ # 3. ๅŠ ่ฝฝ Decoder
96
+ pipe.decoder = Decoder(in_channel=320)
97
+ pipe.decoder.load_state_dict(
98
+ torch.load(os.path.join(ADAPTOR_DIR, "decoder.pth"), map_location="cpu"))
99
+ pipe.decoder.eval()
100
+
101
+ # 4. ๅŠ ่ฝฝ HeightHead
102
+ pipe.height_head = HeightHead(in_channels=192, h_type=h_type)
103
+ pipe.height_head.load_state_dict(
104
+ torch.load(os.path.join(ADAPTOR_DIR, "height_head.pth"), map_location="cpu"))
105
+ pipe.height_head.eval()
106
+
107
+ # 5. ๅŠ ่ฝฝ SemanticHead๏ผˆ็ฑปๅˆซๆ•ฐ็”ฑ dataset ๅ†ณๅฎš๏ผ‰
108
+ pipe.semantic_head = SemanticHead(in_channels=192, num_classes=classes_num)
109
+ pipe.semantic_head.load_state_dict(
110
+ torch.load(os.path.join(ADAPTOR_DIR, "semantic_head.pth"), map_location="cpu"))
111
+ pipe.semantic_head.eval()
112
+
113
+ print("โœ… Model ready (on CPU).")
114
+ return pipe
115
+
116
+
117
+ def reload_model(dataset_name: str, h_type: str) -> str:
118
+ """
119
+ ๅœจไธป่ฟ›็จ‹ไธญ้‡ๅปบๆจกๅž‹๏ผŒไพ› Gradio ๆŒ‰้’ฎ่ฐƒ็”จใ€‚
120
+ ๆณจๆ„๏ผšๆญคๅ‡ฝๆ•ฐ **ไธๅŠ ** @spaces.GPU๏ผŒ็›ดๆŽฅ่ฟ่กŒๅœจไธป่ฟ›็จ‹๏ผŒ
121
+ ๅ…จๅฑ€ _model ๆ›ดๆ–ฐๅŽ๏ผŒไธ‹ไธ€ๆฌก @spaces.GPU ่ฐƒ็”จไผš fork ๅˆฐๆ–ฐๆจกๅž‹ใ€‚
122
+ """
123
+ global _model, _model_key
124
+ key = (dataset_name, h_type)
125
+ if _model is not None and _model_key == key:
126
+ return f"โœ… Already loaded โ€” **{dataset_name}** / **{h_type}**"
127
+ _model = build_model(dataset_name, h_type)
128
+ _model_key = key
129
+ return f"โœ… Model loaded โ€” **{dataset_name}** / **{h_type}**"
130
+
131
+
132
+ # ๅฏๅŠจๆ—ถ้ข„ๅŠ ่ฝฝ้ป˜่ฎคๆจกๅž‹๏ผˆOpenDC / ER๏ผ‰
133
+ reload_model("OpenDC", "ER")
134
+
135
+
136
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
137
+ # VAE / UNet forward๏ผˆ็งป้™คไบ† DistributedDataParallel ๅˆ†ๆ”ฏ๏ผŒ
138
+ # Spaces ๅ•ๅกๅœบๆ™ฏไธ้œ€่ฆ๏ผ‰
139
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
140
+ def _vae_encode(pipe, x: torch.Tensor):
141
+ """้€š่ฟ‡ VAE Encoder ๅ‰ๅ‘๏ผŒ่ฟ”ๅ›ž (ๆœ€็ปˆ็‰นๅพ, ไธญ้—ด็‰นๅพๅˆ—่กจ)ใ€‚"""
142
+ enc = pipe.vae.encoder
143
+ x = enc.conv_in(x)
144
+ feats = []
145
+ for blk in enc.down_blocks:
146
+ x = blk(x)
147
+ feats.append(x)
148
+ x = enc.mid_block(x)
149
+ x = enc.conv_norm_out(x)
150
+ x = enc.conv_act(x)
151
+ x = enc.conv_out(x)
152
+ return x, feats[:-1] # ไธŽๅŽŸๅง‹ไปฃ็ ไธ€่‡ด๏ผŒไธขๅผƒๆœ€ๅŽไธ€ๅฑ‚็‰นๅพ
153
+
154
+
155
+ def _unet_forward(unet, sample, timestep, enc_hs):
156
+ t_emb = unet.get_time_embed(sample=sample, timestep=timestep)
157
+ emb = unet.time_embedding(t_emb)
158
+ enc_hs = unet.process_encoder_hidden_states(
159
+ encoder_hidden_states=enc_hs, added_cond_kwargs=None)
160
+
161
+ x = unet.conv_in(sample)
162
+ skips = (x,)
163
+ for blk in unet.down_blocks:
164
+ x, res = blk(hidden_states=x, temb=emb, encoder_hidden_states=enc_hs)
165
+ skips += res
166
+
167
+ x = unet.mid_block(x, emb, encoder_hidden_states=enc_hs)
168
+
169
+ for blk in unet.up_blocks:
170
+ res = skips[-len(blk.resnets):]
171
+ skips = skips[:-len(blk.resnets)]
172
+ x = blk(hidden_states=x, temb=emb,
173
+ res_hidden_states_tuple=res, encoder_hidden_states=enc_hs)
174
+ return x
175
+
176
+
177
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
178
+ # GPU ๆŽจ็†๏ผˆ็”จ @spaces.GPU ่ฃ…้ฅฐ๏ผŒ็”ณ่ฏทๆœ€ๅคš 120s GPU๏ผ‰
179
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
180
+ @spaces.GPU(duration=120)
181
+ @torch.no_grad()
182
+ def run_inference(
183
+ image: Image.Image,
184
+ task: str,
185
+ dataset_name: str,
186
+ h_type: str,
187
+ mode_type: str,
188
+ ):
189
+ if image is None:
190
+ return None, "โš ๏ธ Please upload an image first."
191
+ if _model is None:
192
+ return None, "โš ๏ธ Model not loaded โ€” click **Load / Reload Model**."
193
+
194
+ device = "cuda"
195
+ pipe = _model
196
+ pipe.to(device) # ZeroGPU ๅญ่ฟ›็จ‹ๆ‹ฟๅˆฐ CPU ๅ‰ฏๆœฌๅŽ็งปๅˆฐ GPU
197
+
198
+ try:
199
+ # โ”€โ”€ 1. ๆ–‡ๆœฌ็ผ–็  โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
200
+ tokens = pipe.tokenizer(
201
+ TASK_PROMPTS[task], padding="max_length", truncation=True,
202
+ max_length=pipe.tokenizer.model_max_length, return_tensors="pt")
203
+ text_emb = pipe.text_encoder(tokens.input_ids.to(device))[0].float()
204
+ # text_emb: [1, 77, 768] (SD v1.5 ็š„ text dim ไธบ 768)
205
+
206
+ # โ”€โ”€ 2. ๅ›พๅƒ้ข„ๅค„็† โ†’ [1, 3, 512, 512] โˆˆ [-1, 1] โ”€โ”€โ”€โ”€โ”€โ”€
207
+ img = image.convert("RGB").resize((512, 512), Image.BILINEAR)
208
+ arr = np.array(img, dtype=np.float32).transpose(2, 0, 1)
209
+ norm = (torch.from_numpy(arr) / 255.0 * 2.0 - 1.0).unsqueeze(0).to(device)
210
+
211
+ # โ”€โ”€ 3. VAE ็ผ–็  โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
212
+ h, h_list = _vae_encode(pipe, norm)
213
+ moments = pipe.vae.quant_conv(h)
214
+ mean, lv = torch.chunk(moments, 2, dim=1)
215
+ latents = (mean + torch.exp(0.5 * lv) * torch.randn_like(mean)) * RGB_LATENT_SCALE
216
+
217
+ # โ”€โ”€ 4. UNet + ่‡ชๅฎšไน‰ Decoder โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
218
+ ts = torch.ones([latents.shape[0]], device=device) * 999
219
+ unet_o = _unet_forward(pipe.unet, latents, ts, text_emb)
220
+ dec_o = pipe.decoder(unet_o, res_list=h_list[::-1])
221
+
222
+ # โ”€โ”€ 5. ไปปๅŠก Head โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
223
+ h_out = pipe.height_head(dec_o)
224
+ s_out = pipe.semantic_head(dec_o)
225
+
226
+ # โ”€โ”€ 6. ๅŽๅค„็† & ๅฏ่ง†ๅŒ– โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
227
+ if mode_type == "Height Map":
228
+ pred = F.interpolate(h_out[0].cpu(), (512, 512),
229
+ mode="bilinear", align_corners=False)
230
+ pred = ((pred + 1.0) / 2.0).clamp(0, 1).squeeze().numpy()
231
+
232
+ fig, ax = plt.subplots(figsize=(6, 5), tight_layout=True)
233
+ im = ax.imshow(pred, cmap="plasma")
234
+ fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
235
+ ax.set_title("Predicted Height Map"); ax.axis("off")
236
+ buf = io.BytesIO()
237
+ fig.savefig(buf, format="png", dpi=150)
238
+ plt.close(fig); buf.seek(0)
239
+ out_img = Image.open(buf).copy()
240
+ info = (f"Normalized range: [{pred.min():.4f}, {pred.max():.4f}]\n"
241
+ "(0 โ‰ˆ 0 m, 1 โ‰ˆ 50 m before denormalization)")
242
+
243
+ else: # Semantic Map
244
+ pred = F.interpolate(s_out, (512, 512), mode="bilinear", align_corners=False)
245
+ argmax = torch.argmax(pred, dim=1).squeeze().cpu().numpy()
246
+ canvas = np.zeros((512, 512, 3), dtype=np.uint8)
247
+ for lbl, col in LABEL_COLORS[dataset_name].items():
248
+ canvas[argmax == lbl] = col
249
+ out_img = Image.fromarray(canvas)
250
+ info = f"Detected class indices: {np.unique(argmax).tolist()}"
251
+
252
+ return out_img, info
253
+
254
+ finally:
255
+ # ZeroGPU ๅญ่ฟ›็จ‹็ป“ๆŸๅŽ GPU ๅ†…ๅญ˜่‡ชๅŠจ้‡Šๆ”พ๏ผŒ
256
+ # ่ฟ™้‡Œๆ˜พๅผ็งปๅ›ž CPU ๅชๆ˜ฏ้ขๅค–ไฟ้™ฉ
257
+ pipe.to("cpu")
258
+ torch.cuda.empty_cache()
259
+
260
+
261
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
262
+ # Gradio UI
263
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
264
+ with gr.Blocks(title="HeightAdaptor") as demo:
265
+ gr.Markdown("""
266
+ # ๐Ÿ™๏ธ HeightAdaptor
267
+ **Remote Sensing Image โ†’ Height Map / Semantic Segmentation**
268
+
269
+ Backbone: `stable-diffusion-v1-5` + LoRA adaptor (`UEXdo/HeightAdaptor-weight`) + ่‡ชๅฎšไน‰ Task Heads
270
+ """)
271
+
272
+ with gr.Row():
273
+ # โ”€โ”€ ๅทฆๆ ๏ผš่พ“ๅ…ฅ & ้…็ฝฎ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
274
+ with gr.Column(scale=1):
275
+ inp_img = gr.Image(type="pil", label="๐Ÿ“ท Input RGB Image")
276
+
277
+ with gr.Group():
278
+ gr.Markdown("#### โš™๏ธ Model Config")
279
+ dataset_radio = gr.Radio(
280
+ ["OpenDC", "US3D"], value="OpenDC", label="Dataset")
281
+ h_type_radio = gr.Radio(
282
+ ["ER", "DR"], value="ER", label="Height Type (h_type)")
283
+ load_btn = gr.Button("๐Ÿ”„ Load / Reload Model", variant="secondary")
284
+ load_info = gr.Markdown("โœ… Default model active (OpenDC / ER)")
285
+
286
+ with gr.Group():
287
+ gr.Markdown("#### ๐ŸŽฏ Inference Config")
288
+ task_radio = gr.Radio(
289
+ ["Height Estimation", "Semantic Segmentation"],
290
+ value="Height Estimation", label="Task")
291
+ mode_radio = gr.Radio(
292
+ ["Height Map", "Semantic Map"],
293
+ value="Height Map", label="Output Mode")
294
+
295
+ run_btn = gr.Button("๐Ÿš€ Run Inference", variant="primary", size="lg")
296
+
297
+ # โ”€โ”€ ๅณๆ ๏ผš่พ“ๅ‡บ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
298
+ with gr.Column(scale=1):
299
+ out_img = gr.Image(type="pil", label="๐Ÿ“Š Output")
300
+ out_info = gr.Textbox(label="โ„น๏ธ Info", interactive=False, lines=3)
301
 
302
+ gr.Markdown("""
303
+ ---
304
+ > โš ๏ธ **ๅˆ‡ๆข Dataset / Height Type ๅŽ๏ผŒ่ฏทๅ…ˆ็‚นๅ‡ป Load / Reload Model ๅ†ๆŽจ็†ใ€‚**
305
+ > ๅ›พๅƒไผš่‡ชๅŠจ็ผฉๆ”พ่‡ณ 512 ร— 512๏ผŒGPU ๆŽจ็†็บฆ้œ€ 10โ€“30 ็ง’ใ€‚
306
+ """)
307
 
308
+ # โ”€โ”€ ไบ‹ไปถ็ป‘ๅฎš โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
309
+ load_btn.click(
310
+ fn=reload_model,
311
+ inputs=[dataset_radio, h_type_radio],
312
+ outputs=[load_info],
313
+ )
314
+ run_btn.click(
315
+ fn=run_inference,
316
+ inputs=[inp_img, task_radio, dataset_radio, h_type_radio, mode_radio],
317
+ outputs=[out_img, out_info],
318
+ )
319
 
320
+ if __name__ == "__main__":
321
+ demo.launch()