zhouyik committed
Commit 5814bce · verified · 1 Parent(s): 7d2576b

Upload ./app.py with huggingface_hub

Files changed (1): app.py (+28 −13)
app.py CHANGED
@@ -1,5 +1,6 @@
 # Modified from https://huggingface.co/spaces/PolyU-ChenLab/UniPixel/blob/main/app.py
-
+import os
+from pathlib import Path
 import random
 import re
 import colorsys
@@ -11,13 +12,22 @@ import imageio.v3 as iio
 
 import torch
 from torchvision.transforms.functional import to_pil_image
+from huggingface_hub import hf_hub_download
 
 import spaces
 import gradio as gr
 
+GRADIO_TMP = os.path.join(os.path.dirname(__file__), ".gradio_tmp")
+Path(GRADIO_TMP).mkdir(parents=True, exist_ok=True)
+
+os.environ["GRADIO_TEMP_DIR"] = GRADIO_TMP
+os.environ["TMPDIR"] = GRADIO_TMP
+os.environ["TEMP"] = GRADIO_TMP
+os.environ["TMP"] = GRADIO_TMP
+
 from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
-from .sam2 import VQ_SAM2, VQ_SAM2Config, SAM2Config
-from .visualizer import sample_color, draw_mask
+from sam2 import VQ_SAM2, VQ_SAM2Config, SAM2Config
+from visualizer import sample_color, draw_mask
 
 class DirectResize:
     def __init__(self, target_length: int) -> None:
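Note: this hunk does two things. The relative imports (from .sam2, from .visualizer) become absolute, which is what lets app.py run as a top-level script rather than as a package module. And all temporary files are routed into a writable folder beside the script: Gradio reads GRADIO_TEMP_DIR when it stages uploads, and Python's tempfile module falls back on TMPDIR/TEMP/TMP. A minimal, self-contained sketch of the same redirection pattern (the .gradio_tmp name is taken from the diff):

import os
from pathlib import Path

# Route Gradio's staged uploads and tempfile output into a folder next to
# this script, which is writable on a Hugging Face Space.
TMP_DIR = Path(__file__).parent / ".gradio_tmp"
TMP_DIR.mkdir(parents=True, exist_ok=True)
for var in ("GRADIO_TEMP_DIR", "TMPDIR", "TEMP", "TMP"):
    os.environ[var] = str(TMP_DIR)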
@@ -85,10 +95,8 @@ function init() {
         document.querySelector('main').style.maxWidth = '1536px'
     }
     document.getElementById('query_1').addEventListener('keydown', function f1(e) { if (e.key === 'Enter') { document.getElementById('submit_1').click() } })
-    document.getElementById('query_2').addEventListener('keydown', function f2(e) { if (e.key === 'Enter') { document.getElementById('submit_2').click() } })
-    document.getElementById('query_3').addEventListener('keydown', function f3(e) { if (e.key === 'Enter') { document.getElementById('submit_3').click() } })
-    document.getElementById('query_4').addEventListener('keydown', function f4(e) { if (e.key === 'Enter') { document.getElementById('submit_4').click() } })
 }
+window.addEventListener('load', init);
 """
 
 device = torch.device('cuda')
@@ -100,10 +108,12 @@ model = Qwen3VLForConditionalGeneration.from_pretrained(
 processor = AutoProcessor.from_pretrained(MODEL)
 
 # build vq-sam2 model
+sam2_ckpt_local = hf_hub_download(repo_id=MODEL, filename="sam2.1_hiera_large.pt")
+mask_tokenizer_local = hf_hub_download(repo_id=MODEL, filename="mask_tokenizer_256x2.pth")
 CODEBOOK_SIZE = 256
 CODEBOOK_DEPTH = 2
 sam2_config = SAM2Config(
-    ckpt_path=MODEL+"/sam2.1_hiera_large.pt",
+    ckpt_path=sam2_ckpt_local,
 )
 vq_sam2_config = VQ_SAM2Config(
     sam2_config=sam2_config,
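Note: the switch to hf_hub_download suggests MODEL is a Hub repo id rather than a local directory, so the old MODEL+"/sam2.1_hiera_large.pt" concatenation never named a real file on disk. hf_hub_download fetches a single file from the repo (reusing the local cache on later calls) and returns the absolute path of the downloaded copy. A sketch of the call, with "user/repo" standing in for the actual repo id:

from huggingface_hub import hf_hub_download
import torch

# Download one file from the Hub (cached after the first call) and get
# back a local filesystem path that torch.load can open.
ckpt = hf_hub_download(repo_id="user/repo", filename="sam2.1_hiera_large.pt")  # placeholder repo id
state = torch.load(ckpt, map_location="cpu")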
@@ -113,7 +123,7 @@ vq_sam2_config = VQ_SAM2Config(
     latent_dim=256,
 )
 vq_sam2 = VQ_SAM2(vq_sam2_config).cuda().eval()
-state = torch.load(MODEL+"/mask_tokenizer_256x2.pth", map_location="cpu")
+state = torch.load(mask_tokenizer_local, map_location="cpu")
 vq_sam2.load_state_dict(state)
 sam2_image_processor = DirectResize(1024)
 
@@ -126,22 +136,23 @@ color_map_light = {
 }
 
 def enable_btns():
-    return (gr.Button(interactive=True), ) * 4
+    return (gr.update(interactive=True), ) * 4
 
 
 def disable_btns():
-    return (gr.Button(interactive=False), ) * 4
+    return (gr.update(interactive=False), ) * 4
 
 
 def reset_seg():
-    return 16, gr.Button(interactive=False)
+    return 16, gr.update(interactive=False)
 
 
 def reset_reg():
-    return 1, gr.Button(interactive=False)
+    return 1, gr.update(interactive=False)
 
 @spaces.GPU
 def infer_seg(media, query):
+    print("=======>>>enter infer seg")
     global model
 
     if not media:
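Note: the four button helpers now return gr.update(...) instead of constructing fresh gr.Button objects. gr.update emits a property patch that Gradio applies to the component already rendered in the UI, which is the idiomatic way to toggle interactivity from an event handler. A minimal self-contained example of the pattern (the component and handler names are illustrative):

import gradio as gr

def lock_button():
    # Patch the existing button's properties rather than replacing it.
    return gr.update(interactive=False)

with gr.Blocks() as demo:
    btn = gr.Button('Run')
    btn.click(lock_button, outputs=[btn])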
@@ -152,7 +163,7 @@ def infer_seg(media, query):
         gr.Warning('Please provide a text prompt.')
         return None, None, None
 
-    image = Image.open(path).convert('RGB')
+    image = Image.open(media).convert('RGB')
     ori_width, ori_height = image.size
     messages = [
         {
@@ -190,6 +201,9 @@ def infer_seg(media, query):
     output_text = processor.batch_decode(
         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
     )[0]
+
+    print("========>>>>output_text", output_text)
+    exit(0)
 
     quant_ids = extract_mt_token_ids_v1(output_text)
     if len(quant_ids) % CODEBOOK_DEPTH != 0:
@@ -284,6 +298,7 @@ def build_demo():
         # with gr.Tab('Mask Understanding'):
         #     pass
 
+    return demo
 
 if __name__ == '__main__':
     demo = build_demo()
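Note: the added return demo is what makes the entry point usable; without it build_demo() appears to fall off the end and return None, so demo = build_demo() would bind None. The presumed usage looks like the sketch below (the launch() call is an assumption, since the diff context ends at the assignment):

if __name__ == '__main__':
    demo = build_demo()
    demo.launch()  # assumed: the standard way to start a Blocks app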
 