akhaliq HF Staff commited on
Commit
30a638e
·
verified ·
1 Parent(s): e3ba680

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -0
app.py CHANGED
@@ -47,8 +47,14 @@ def segment(image: Image.Image, text: str, threshold: float, mask_threshold: flo
47
  return None, "❌ Please upload an image."
48
 
49
  try:
 
50
  inputs = processor(images=image, text=text.strip(), return_tensors="pt").to(device)
51
 
 
 
 
 
 
52
  with torch.no_grad():
53
  outputs = model(**inputs)
54
 
@@ -147,6 +153,20 @@ with gr.Blocks(
147
 
148
  segment_btn = gr.Button("🎯 Segment", variant="primary", size="lg")
149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  # Clear button handler
151
  clear_btn.click(
152
  fn=clear_all,
 
47
  return None, "❌ Please upload an image."
48
 
49
  try:
50
+ # Ensure inputs match model dtype
51
  inputs = processor(images=image, text=text.strip(), return_tensors="pt").to(device)
52
 
53
+ # Convert inputs to match model dtype
54
+ for key in inputs:
55
+ if inputs[key].dtype == torch.float32:
56
+ inputs[key] = inputs[key].to(model.dtype)
57
+
58
  with torch.no_grad():
59
  outputs = model(**inputs)
60
 
 
153
 
154
  segment_btn = gr.Button("🎯 Segment", variant="primary", size="lg")
155
 
156
+ # Add some example prompts
157
+ gr.Examples(
158
+ examples=[
159
+ ["examples/person.jpg", "person"],
160
+ ["examples/car.jpg", "car"],
161
+ ["examples/dog.jpg", "dog"],
162
+ ["examples/building.jpg", "building"],
163
+ ],
164
+ inputs=[image_input, text_input],
165
+ outputs=[image_output, info_output],
166
+ fn=segment_example,
167
+ cache_examples=True,
168
+ )
169
+
170
  # Clear button handler
171
  clear_btn.click(
172
  fn=clear_all,