Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -36,9 +36,10 @@ SAM_ENCODER_VERSION = "vit_h"
|
|
| 36 |
SAM_CHECKPOINT_PATH = "./sam_vit_h_4b8939.pth"
|
| 37 |
|
| 38 |
# Building GroundingDINO inference model
|
| 39 |
-
groundingdino_model = load_model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, device="
|
| 40 |
# Building SAM Model and SAM Predictor
|
| 41 |
sam = build_sam(checkpoint=SAM_CHECKPOINT_PATH)
|
|
|
|
| 42 |
sam_predictor = SamPredictor(sam)
|
| 43 |
|
| 44 |
def transform_image(image_pil):
|
|
@@ -128,7 +129,7 @@ def get_mask(image, label):
|
|
| 128 |
sam_predictor.set_image(image)
|
| 129 |
|
| 130 |
transformed_boxes = sam_predictor.transform.apply_boxes_torch(
|
| 131 |
-
boxes_filt, image.shape[:2])
|
| 132 |
|
| 133 |
masks, _, _ = sam_predictor.predict_torch(
|
| 134 |
point_coords=None,
|
|
@@ -359,7 +360,7 @@ with gr.Blocks() as demo:
|
|
| 359 |
text_prompt = gr.Textbox(label="Label")
|
| 360 |
|
| 361 |
with gr.Column(scale=1):
|
| 362 |
-
baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", height=
|
| 363 |
with gr.Accordion("Advanced Option", open=True):
|
| 364 |
seed = gr.Slider(label="Seed", minimum=-1, maximum=999999999, step=1, value=666)
|
| 365 |
gr.Markdown("### Guidelines")
|
|
|
|
| 36 |
SAM_CHECKPOINT_PATH = "./sam_vit_h_4b8939.pth"
|
| 37 |
|
| 38 |
# Building GroundingDINO inference model
|
| 39 |
+
groundingdino_model = load_model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, device="cuda")
|
| 40 |
# Building SAM Model and SAM Predictor
|
| 41 |
sam = build_sam(checkpoint=SAM_CHECKPOINT_PATH)
|
| 42 |
+
sam.to(device="cuda")
|
| 43 |
sam_predictor = SamPredictor(sam)
|
| 44 |
|
| 45 |
def transform_image(image_pil):
|
|
|
|
| 129 |
sam_predictor.set_image(image)
|
| 130 |
|
| 131 |
transformed_boxes = sam_predictor.transform.apply_boxes_torch(
|
| 132 |
+
boxes_filt, image.shape[:2]).to("cuda")
|
| 133 |
|
| 134 |
masks, _, _ = sam_predictor.predict_torch(
|
| 135 |
point_coords=None,
|
|
|
|
| 360 |
text_prompt = gr.Textbox(label="Label")
|
| 361 |
|
| 362 |
with gr.Column(scale=1):
|
| 363 |
+
baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", height=800, columns=1)
|
| 364 |
with gr.Accordion("Advanced Option", open=True):
|
| 365 |
seed = gr.Slider(label="Seed", minimum=-1, maximum=999999999, step=1, value=666)
|
| 366 |
gr.Markdown("### Guidelines")
|