CountEx / app.py
yifehuang97's picture
Update app.py
985fb95 verified
import os
import json
import gradio as gr
import torch
from PIL import Image, ImageDraw
from transformers import GroundingDinoProcessor
from hf_model import CountEX
from utils import post_process_grounded_object_detection, post_process_grounded_object_detection_with_queries
import google.generativeai as genai
from datetime import datetime
import csv
from pathlib import Path
import uuid
import io
# Optional HEIC/HEIF decoding: registering the pillow-heif opener lets PIL's
# Image.open read those files. HEIC_SUPPORTED records whether that succeeded.
# NOTE(review): HEIC_SUPPORTED is not read anywhere in the visible code, and
# validate_image only accepts ALLOWED_EXTENSIONS (.jpg/.jpeg/.png), so .heic
# paths are still rejected there even when this import succeeds.
try:
    from pillow_heif import register_heif_opener
    register_heif_opener()
    HEIC_SUPPORTED = True
except ImportError:
    HEIC_SUPPORTED = False
    print("Warning: pillow-heif not installed. HEIC images will not be supported.")
# Optional HuggingFace Hub client; init_hf_api() checks HF_HUB_AVAILABLE before
# constructing HfApi for dataset uploads.
try:
    from huggingface_hub import HfApi
    HF_HUB_AVAILABLE = True
except ImportError:
    HF_HUB_AVAILABLE = False
    print("Warning: huggingface_hub not installed.")
# Lazily-populated singletons: load_model() fills model/processor/device and
# init_hf_api() fills hf_api.
model = None
processor = None
device = None
hf_api = None

# Local data-collection layout (fallback when the HF dataset upload is unavailable).
DATA_LOG_DIR = Path("uploaded_data")
IMAGES_DIR = DATA_LOG_DIR / "images"
DATA_LOG_FILE = DATA_LOG_DIR / "prompts_log.csv"
DATA_LOG_DIR.mkdir(exist_ok=True)
IMAGES_DIR.mkdir(exist_ok=True)

# HuggingFace dataset repo that collected samples are uploaded to (env-overridable).
HF_DATASET_REPO = os.environ.get("HF_DATASET_REPO", "BBVisual/CountEx_UserData")

# Create the CSV log with its header row on first run only.
if not DATA_LOG_FILE.exists():
    with open(DATA_LOG_FILE, "w", newline="") as log_file:
        csv.writer(log_file).writerow(
            ["timestamp", "image_filename", "instruction", "pos_caption", "neg_caption", "count"]
        )

# Image-input limits.
MAX_IMAGE_SIZE = 1333  # longest allowed side (width or height); larger uploads are downscaled
ALLOWED_EXTENSIONS = {".jpg", ".jpeg", ".png"}

# Gemini client used to split natural-language instructions into include/exclude lists.
gemini_api_key = os.environ.get("GEMINI_API_KEY")
genai.configure(api_key=gemini_api_key)
gemini_model = genai.GenerativeModel("gemini-2.0-flash")
# Few-shot prompt sent to Gemini by parse_counting_instruction(). It is filled with
# str.format(instruction=...), which is why the literal braces of the JSON example
# are doubled ({{ }}). The reply is expected to be {"A": [...], "B": [...]} where
# A lists objects to count and B lists objects to exclude.
PARSING_PROMPT = """Parse sentences of the form "Count A, not B" into two lists—A (include) and B (exclude)—splitting on "and", "or", and commas, and reattaching shared head nouns (e.g., "red and black beans" → "red beans", "black beans").
Rules:
- Remove from B items that are equivalent to items in A (synonyms/variants/abbreviations/regional terms)
- Keep B items that are more specific than A (for fine-grained exclusion)
- If B is more general than A but shares the head noun, remove B (contradictory)
Case 1 — Different head nouns → Keep B
Example 1: Count green apples and red beans, not yellow screws and white rice → A: ["green apples", "red beans"], B: ["yellow screws", "white rice"]
Example 2: Count black beans, not poker chips or nails → A: ["black beans"], B: ["poker chips", "nails"]
Case 2 — Equivalent items → Remove from B
Example 1: Count fries and TV, not chips and television → A: ["fries", "TV"], B: []
Example 2: Count garbanzo beans and couch, not chickpeas and sofa → A: ["garbanzo beans", "couch"], B: []
Case 3 — B more specific than A → Keep B (for fine-grained exclusion)
Example 1: Count apples and beans, not green apples and black beans → A: ["apples", "beans"], B: ["green apples", "black beans"]
Example 2: Count beans, not white beans or yellow beans → A: ["beans"], B: ["white beans", "yellow beans"]
Example 3: Count people, not women → A: ["people"], B: ["women"]
Case 4 — B more general than A → Remove B (contradictory)
Example 1: Count green apples, not apples → A: ["green apples"], B: []
Example 2: Count red beans and green apples, not beans and apples → A: ["red beans", "green apples"], B: []
User instruction: {instruction}
Respond ONLY with a JSON object in this exact format, no other text:
{{"A": ["item1", "item2"], "B": ["item3"]}}
"""
def init_hf_api():
    """Create the module-level ``hf_api`` client used for data collection.

    Reads the write token from the ``HF_WRITTE_TOKEN`` environment variable (name as
    deployed). Returns the new client, or None when huggingface_hub is not installed,
    the token is unset/empty, or constructing the client raises (errors are printed).
    """
    global hf_api
    if not HF_HUB_AVAILABLE:
        print("HuggingFace Hub not available")
        return None
    try:
        token = os.environ.get("HF_WRITTE_TOKEN")
        if token:
            hf_api = HfApi(token=token)
            print(f"HuggingFace API initialized. Dataset repo: {HF_DATASET_REPO}")
            return hf_api
        print("HF_WRITTE_TOKEN not set, data collection disabled")
        return None
    except Exception as exc:
        print(f"Error initializing HuggingFace API: {exc}")
        return None
def upload_to_hf_dataset(image_bytes, image_filename, data_dict):
    """Upload an image and its JSON metadata to the ``HF_DATASET_REPO`` dataset repo.

    Args:
        image_bytes: encoded image file contents, stored at ``images/<image_filename>``.
        image_filename: file name of the image inside the repo; its extension is
            swapped for ``.json`` to name the metadata file under ``metadata/``.
        data_dict: JSON-serializable metadata for the sample.

    Returns:
        True if both uploads succeeded; False if no ``hf_api`` client is configured
        or any step failed (the error is printed, not raised).
    """
    if not hf_api:
        return False
    try:
        hf_token = os.environ.get("HF_WRITTE_TOKEN")
        # Serialize before uploading anything so unserializable metadata cannot
        # leave an orphan image behind in the repo.
        json_bytes = json.dumps(data_dict, indent=2).encode("utf-8")
        # One metadata file per sample (instead of a shared CSV) so concurrent
        # requests never read-modify-write the same file in the repo.
        json_filename = os.path.splitext(image_filename)[0] + ".json"
        uploads = (
            (io.BytesIO(image_bytes), f"images/{image_filename}"),
            (io.BytesIO(json_bytes), f"metadata/{json_filename}"),
        )
        for payload, path_in_repo in uploads:
            hf_api.upload_file(
                path_or_fileobj=payload,
                path_in_repo=path_in_repo,
                repo_id=HF_DATASET_REPO,
                repo_type="dataset",
                token=hf_token,
            )
        return True
    except Exception as e:
        print(f"Error uploading to HuggingFace Dataset: {e}")
        return False
def save_uploaded_data(image, instruction, pos_caption, neg_caption, count, points=None):
    """
    Save uploaded image and prompt data for collection.

    Tries the HuggingFace dataset first (``upload_to_hf_dataset``). If that is not
    configured or fails, falls back to local storage:
    - the JPEG under ``IMAGES_DIR``,
    - a per-sample JSON under ``DATA_LOG_DIR / "metadata"``,
    - a row appended to ``DATA_LOG_FILE``.

    Args:
        image: PIL image to store (encoded once as JPEG, quality 95).
        instruction: raw user instruction, or the manual-mode summary string.
        pos_caption: positive caption that was sent to the model.
        neg_caption: negative caption that was sent to the model.
        count: number of points reported to the user.
        points: predicted [x, y] points, stored as given. The callers in this file
            pass coordinates normalized to 0-1.

    Best effort: every failure is printed and swallowed, so data collection can
    never fail the inference request that triggered it.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    unique_id = str(uuid.uuid4())[:8]
    image_filename = f"{timestamp}_{unique_id}.jpg"
    json_filename = os.path.splitext(image_filename)[0] + ".json"

    # Encode once; the same bytes go to the Hub or to local disk.
    try:
        img_buffer = io.BytesIO()
        image.save(img_buffer, format="JPEG", quality=95)
        img_bytes = img_buffer.getvalue()
    except Exception as e:
        print(f"Error saving data: {e}")
        return

    data_dict = {
        "timestamp": timestamp,
        "image_filename": image_filename,
        "instruction": instruction,
        "pos_caption": pos_caption,
        "neg_caption": neg_caption,
        "count": count,
        "points": points if points else [],  # normalized coordinates (0-1)
    }

    # Try HuggingFace Dataset first.
    if hf_api:
        try:
            if upload_to_hf_dataset(img_bytes, image_filename, data_dict):
                print(f"Saved to HuggingFace Dataset: {image_filename}")
                return
        except Exception as e:
            print(f"HuggingFace upload failed, falling back to local: {e}")

    # Fallback to local storage.
    try:
        (IMAGES_DIR / image_filename).write_bytes(img_bytes)
        metadata_dir = DATA_LOG_DIR / "metadata"
        metadata_dir.mkdir(exist_ok=True)
        with open(metadata_dir / json_filename, "w", encoding="utf-8") as f:
            json.dump(data_dict, f, indent=2)
        # Also append to CSV for backward compatibility. Use explicit UTF-8 so
        # non-ASCII instructions do not depend on the platform locale.
        with open(DATA_LOG_FILE, "a", newline="", encoding="utf-8") as f:
            csv.writer(f).writerow(
                [timestamp, image_filename, instruction, pos_caption, neg_caption, count]
            )
        print(f"Saved locally: {image_filename}")
    except Exception as e:
        print(f"Error saving data: {e}")
def validate_image(image):
    """
    Check that an upload is present and, for file paths, has an accepted extension.

    Only string paths are inspected: their extension is compared case-insensitively
    against ALLOWED_EXTENSIONS. A path without an extension, or any non-string
    object, is accepted.
    Returns (is_valid, error_message); error_message is None when valid.
    """
    if image is None:
        return False, "Error: Please upload an image."
    if not isinstance(image, str):
        return True, None
    suffix = os.path.splitext(image)[1].lower()
    if not suffix or suffix in ALLOWED_EXTENSIONS:
        return True, None
    return False, f"Error: Unsupported format '{suffix}'. Only JPG and PNG are supported."
def preprocess_image(image):
    """
    Preprocess uploaded image: convert it to RGB and downscale it if needed.

    Args:
        image: a file path (str) or an already-open PIL image.

    Returns:
        An RGB PIL image whose longest side is at most ``MAX_IMAGE_SIZE``.
        - Images with transparency (RGBA, LA, PA) and palette images are
          composited onto a white background using their alpha channel.
        - Every other non-RGB mode is converted directly.
        - Oversized images are downscaled with LANCZOS, preserving aspect ratio.
    """
    if isinstance(image, str):
        # Load the pixels and detach from the file so the handle is closed now
        # (multi-frame formats would otherwise keep it open).
        with Image.open(image) as opened:
            image = opened.copy()

    if image.mode != "RGB":
        if image.mode in ("RGBA", "LA", "PA", "P"):
            # Going through RGBA gives every transparent mode a real alpha band;
            # opaque palette images get alpha 255, so pasting is a plain copy.
            rgba = image.convert("RGBA")
            background = Image.new("RGB", rgba.size, (255, 255, 255))
            background.paste(rgba, mask=rgba.getchannel("A"))
            image = background
        else:
            image = image.convert("RGB")

    width, height = image.size
    longest = max(width, height)
    if longest > MAX_IMAGE_SIZE:
        scale = MAX_IMAGE_SIZE / longest
        # round() hits MAX_IMAGE_SIZE exactly on the long side; max(1, ...) keeps
        # very thin images from collapsing to a zero-pixel side.
        new_width = max(1, round(width * scale))
        new_height = max(1, round(height * scale))
        image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
        print(f"Resized image from {width}x{height} to {new_width}x{new_height}")
    return image
def _extract_json_object(response_text):
    """Return the span from the first "{" to the last "}" of a model reply.

    Tolerates Markdown code fences (```json ... ```) and stray prose around the
    object. Raises ValueError when no such span exists.
    """
    start = response_text.find("{")
    end = response_text.rfind("}")
    if start == -1 or end < start:
        raise ValueError(f"No JSON object found in model response: {response_text!r}")
    return response_text[start:end + 1]


def _as_item_list(value):
    """Normalize a parsed "A"/"B" entry to a list of stripped, non-empty strings.

    A bare string is treated as a single item; otherwise joining it would
    iterate its characters. Raises ValueError for any other non-list value.
    """
    if value is None:
        return []
    if isinstance(value, str):
        value = [value]
    if not isinstance(value, list):
        raise ValueError(f"Expected a list of strings, got {type(value).__name__}")
    return [text for text in (str(item).strip() for item in value) if text]


def _fallback_captions(instruction):
    """Captions used when parsing fails: the raw instruction as positive, no exclusions."""
    text = (instruction or "").strip().rstrip(".").rstrip()
    return text + ".", "None."


def parse_counting_instruction(instruction: str) -> tuple[str, str]:
    """
    Parse natural language counting instruction using Gemini 2.0 Flash.

    Sends ``PARSING_PROMPT`` to ``gemini_model`` and converts the returned
    {"A": [...], "B": [...]} object into GroundingDINO-style captions. Items are
    joined with " and " and terminated with ".".

    Returns:
        (pos_caption, neg_caption). neg_caption is "None." when nothing is excluded.
        On any API/parse failure, or when no objects to count are extracted, falls
        back to (instruction + ".", "None.") without doubling a trailing period.
    """
    try:
        prompt = PARSING_PROMPT.format(instruction=instruction)
        response = gemini_model.generate_content(prompt)
        result = json.loads(_extract_json_object(response.text.strip()))
        pos_items = _as_item_list(result.get("A", []))
        neg_items = _as_item_list(result.get("B", []))
    except Exception as e:
        # Best effort: the external API or its output may fail in many ways;
        # degrade to the raw instruction instead of failing the request.
        print(f"Error parsing instruction: {e}")
        return _fallback_captions(instruction)

    if not pos_items:
        # An empty positive caption would give the detector nothing to look for.
        print("Error parsing instruction: no objects to count were extracted")
        return _fallback_captions(instruction)

    pos_caption = " and ".join(pos_items) + "."
    neg_caption = " and ".join(neg_items) + "." if neg_items else "None."
    return pos_caption, neg_caption
def load_model():
    """Load the CountEX weights and GroundingDINO processor into module globals.

    Chooses CUDA when available, otherwise CPU. Casts the model to bfloat16, moves it
    to that device and puts it in eval mode. The weights are downloaded with the
    optional ``HF_TOKEN`` environment variable.
    Returns (model, processor, device).
    """
    global model, processor, device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    checkpoint = "yifehuang97/CountEx_KC_aug_v3_12140136_v2"
    model = CountEX.from_pretrained(checkpoint, token=os.environ.get("HF_TOKEN"))
    model = model.to(torch.bfloat16).to(device)
    model.eval()

    processor_checkpoint = "fushh7/llmdet_swin_tiny_hf"
    processor = GroundingDinoProcessor.from_pretrained(processor_checkpoint)
    return model, processor, device
import numpy as np
def discriminative_point_suppression(
    points,
    neg_points,
    pos_queries,
    neg_queries,
    image_size,
    pixel_threshold=5,
    similarity_threshold=0.3,
):
    """Discriminative Point Suppression (DPS).

    Each positive point is paired with its nearest negative point, with distances
    measured in pixels after scaling the normalized coordinates by
    ``image_size = (width, height)``. The positive point is dropped when both
    conditions hold:
    - that distance is below ``pixel_threshold``;
    - the cosine similarity between the positive query embedding and the matched
      negative one exceeds ``similarity_threshold``.

    Returns:
        (kept_points, kept_indices, info). ``info`` holds per-positive diagnostics
        plus the suppressed indices. When either point list is empty, everything is
        kept and ``info`` is {}.
    """
    if not neg_points or not points:
        return points, list(range(len(points))), {}

    width, height = image_size
    pixel_scale = np.array([width, height])
    pos_xy = np.array(points) * pixel_scale
    neg_xy = np.array(neg_points) * pixel_scale

    # Pairwise (num_pos, num_neg) pixel distances.
    offsets = pos_xy[:, np.newaxis, :] - neg_xy[np.newaxis, :, :]
    distances = np.linalg.norm(offsets, axis=-1)
    nearest_idx = np.argmin(distances, axis=1)
    nearest_dist = np.min(distances, axis=1)
    close = nearest_dist < pixel_threshold

    def _unit(vectors):
        # Row-wise L2 normalization; the epsilon guards against zero vectors.
        return vectors / (np.linalg.norm(vectors, axis=-1, keepdims=True) + 1e-8)

    pos_unit = _unit(pos_queries)
    neg_unit = _unit(neg_queries)
    similarity = np.sum(pos_unit * neg_unit[nearest_idx], axis=-1)
    similar = similarity > similarity_threshold

    suppress = np.logical_and(close, similar)
    keep = np.logical_not(suppress)

    kept_points = np.array(points)[keep].tolist()
    kept_indices = np.flatnonzero(keep).tolist()
    info = {
        "nearest_neg_idx": nearest_idx.tolist(),
        "nearest_neg_dist": nearest_dist.tolist(),
        "query_similarity": similarity.tolist(),
        "spatially_close": close.tolist(),
        "semantically_similar": similar.tolist(),
        "suppressed_indices": np.flatnonzero(suppress).tolist(),
    }
    return kept_points, kept_indices, info
# Detection / suppression hyper-parameters shared by both counting modes.
_NEG_BOX_THRESHOLD = 0.5  # score threshold for detections of the negative (exclusion) caption
_DPS_PIXEL_THRESHOLD = 5  # pixel radius within which a negative point may suppress a positive one
_DPS_SIMILARITY_THRESHOLD = 0.3  # minimum query cosine similarity required for suppression


def _normalize_caption(caption, default=""):
    """Return ``caption`` stripped and terminated by a period.

    Blank captions (None, whitespace only, or a lone ".") become ``default``.
    """
    text = (caption or "").strip()
    if text in ("", "."):
        return default
    return text if text.endswith(".") else text + "."


def _prepare_image(image):
    """Validate the upload, lazily load the model, and return the preprocessed RGB image.

    Raises:
        gr.Error: with the message from ``validate_image`` when the upload is rejected.
    """
    is_valid, error_msg = validate_image(image)
    if not is_valid:
        raise gr.Error(error_msg)
    if model is None:
        load_model()
    return preprocess_image(image)


def _build_model_inputs(image, pos_caption, neg_caption):
    """Encode ``image`` with each caption and merge both encodings into one kwargs mapping.

    The negative encoding's tensors are attached under ``neg_*`` keys.
    NOTE: the model is always called with use_neg=True. When nothing is excluded,
    the placeholder caption "None." is encoded as the negative prompt.
    """
    pos_inputs = processor(images=image, text=pos_caption, return_tensors="pt", padding=True)
    pos_inputs = pos_inputs.to(device)
    # load_model() casts the weights to bfloat16, so pixel inputs must match.
    pos_inputs["pixel_values"] = pos_inputs["pixel_values"].to(torch.bfloat16)

    neg_inputs = processor(images=image, text=neg_caption, return_tensors="pt", padding=True)
    neg_inputs = {key: value.to(device) for key, value in neg_inputs.items()}
    neg_inputs["pixel_values"] = neg_inputs["pixel_values"].to(torch.bfloat16)

    for key in ("token_type_ids", "attention_mask", "pixel_mask", "pixel_values", "input_ids"):
        pos_inputs[f"neg_{key}"] = neg_inputs[key]
    pos_inputs["use_neg"] = True
    return pos_inputs


def _last_layer_queries(queries):
    """Drop a leading batch dim, take the last entry along the next axis, and return a float (1, ...) tensor.

    NOTE(review): presumably ``queries`` is stacked per decoder layer, so this keeps
    the final layer's query embeddings. Confirm against the CountEX model.
    """
    queries = queries.squeeze(0).float()
    return queries[-1].squeeze(0).unsqueeze(0)


def _detect_points(outputs, queries, box_threshold):
    """Threshold one branch's predictions into points.

    Returns:
        (points, query_embeddings). Each point is the first two coordinates of a kept
        box, and query_embeddings is the result's "queries" as a NumPy array.
    """
    results = post_process_grounded_object_detection_with_queries(
        outputs, queries, box_threshold=box_threshold
    )[0]
    points = [box.tolist()[:2] for box in results["boxes"]]
    return points, results["queries"].cpu().numpy()


def _predict_points(image, pos_caption, neg_caption, box_threshold):
    """Run CountEX on ``image`` and return normalized [x, y] points after DPS.

    A ``box_threshold`` of 0 or less falls back to ``model.box_threshold``.
    """
    model_inputs = _build_model_inputs(image, pos_caption, neg_caption)
    with torch.no_grad():
        outputs = model(**model_inputs)
    # The post-processing helper reads "pred_points"/"pred_logits".
    outputs["pred_points"] = outputs["pred_boxes"][:, :, :2]
    outputs["pred_logits"] = outputs["logits"]

    threshold = box_threshold if box_threshold > 0 else model.box_threshold
    pos_queries = _last_layer_queries(outputs["pos_queries"])
    neg_queries = _last_layer_queries(outputs["neg_queries"])
    points, pos_queries_np = _detect_points(outputs, pos_queries, threshold)

    if "neg_pred_boxes" not in outputs or "neg_logits" not in outputs:
        return points

    # Re-run post-processing on the negative branch by swapping its tensors in.
    neg_outputs = outputs.copy()
    neg_outputs["pred_boxes"] = outputs["neg_pred_boxes"]
    neg_outputs["logits"] = outputs["neg_logits"]
    neg_outputs["pred_points"] = outputs["neg_pred_boxes"][:, :, :2]
    neg_outputs["pred_logits"] = outputs["neg_logits"]
    neg_points, neg_queries_np = _detect_points(neg_outputs, neg_queries, _NEG_BOX_THRESHOLD)

    if not neg_points or len(neg_queries_np) == 0:
        return points
    kept_points, _, _ = discriminative_point_suppression(
        points,
        neg_points,
        pos_queries_np,
        neg_queries_np,
        image_size=image.size,
        pixel_threshold=_DPS_PIXEL_THRESHOLD,
        similarity_threshold=_DPS_SIMILARITY_THRESHOLD,
    )
    return kept_points


def _draw_points(image, points, point_radius, point_color):
    """Return a copy of ``image`` with a filled circle at each normalized [x, y] point."""
    canvas = image.copy()
    draw = ImageDraw.Draw(canvas)
    img_w, img_h = image.size
    for x_norm, y_norm in points:
        x, y = x_norm * img_w, y_norm * img_h
        draw.ellipse(
            [x - point_radius, y - point_radius, x + point_radius, y + point_radius],
            fill=point_color,
        )
    return canvas


def _render_and_log(image, points, point_radius, point_color,
                    instruction, pos_caption, neg_caption, parsed_info):
    """Draw the points, record the request via ``save_uploaded_data``, and build the UI outputs."""
    img_draw = _draw_points(image, points, point_radius, point_color)
    count = len(points)
    save_uploaded_data(image, instruction, pos_caption, neg_caption, count, points)
    return img_draw, f"Count: {count}", parsed_info


def count_objects(image, instruction, box_threshold, point_radius, point_color):
    """Main inference function for counting objects (natural-language mode).

    The instruction is split into positive/negative captions by
    ``parse_counting_instruction``.

    Returns:
        (annotated image, "Count: N", "Positive: ...\\nNegative: ...").

    Raises:
        gr.Error: for an invalid upload, a blank instruction, or when no positive
            caption could be derived.
    """
    image = _prepare_image(image)
    if not instruction or not instruction.strip():
        raise gr.Error("Error: Please enter a counting instruction.")

    pos_caption, neg_caption = parse_counting_instruction(instruction)
    pos_caption = _normalize_caption(pos_caption)
    neg_caption = _normalize_caption(neg_caption, default="None.")
    if not pos_caption:
        raise gr.Error("Error: Could not determine which objects to count. Please rephrase the instruction.")
    parsed_info = f"Positive: {pos_caption}\nNegative: {neg_caption}"

    points = _predict_points(image, pos_caption, neg_caption, box_threshold)
    return _render_and_log(image, points, point_radius, point_color,
                           instruction, pos_caption, neg_caption, parsed_info)


def count_objects_manual(image, pos_caption, neg_caption, box_threshold, point_radius, point_color):
    """Manual mode: directly use provided positive and negative captions.

    Captions are stripped and given a trailing period. A blank negative becomes
    "None.".

    Returns:
        The same triple as ``count_objects``.

    Raises:
        gr.Error: for an invalid upload or a blank positive prompt.
    """
    image = _prepare_image(image)
    pos_caption = _normalize_caption(pos_caption)
    neg_caption = _normalize_caption(neg_caption, default="None.")
    if not pos_caption:
        raise gr.Error("Error: Please enter a positive prompt.")
    parsed_info = f"Positive: {pos_caption}\nNegative: {neg_caption}"

    points = _predict_points(image, pos_caption, neg_caption, box_threshold)
    instruction = f"[MANUAL] pos: {pos_caption} | neg: {neg_caption}"
    return _render_and_log(image, points, point_radius, point_color,
                           instruction, pos_caption, neg_caption, parsed_info)
def create_demo():
    """Build and return (without launching) the CountEx Gradio Blocks UI.

    Layout:
    - Left column: image upload, a "Natural Language" / "Manual Input" tab pair,
      the submit button, and advanced settings.
    - Right column: the annotated result image, the count text, and the parsed
      captions.

    The selected tab is mirrored into a ``gr.State`` so a single submit button can
    dispatch to either ``count_objects`` or ``count_objects_manual``.
    """
    with gr.Blocks(title="CountEx: Discriminative Visual Counting") as demo:
        gr.Markdown("""
# CountEx: Fine-Grained Counting via Exemplars and Exclusion
Count specific objects in images using text prompts with exclusion capability.
""")
        # Starts as "natural_language" to match the first tab; updated by the
        # tab .select() handlers registered below.
        current_mode = gr.State(value="natural_language")
        with gr.Row():
            with gr.Column(scale=1):
                # type="filepath": handlers receive a path string, which
                # validate_image checks by extension before preprocessing.
                input_image = gr.Image(type="filepath", label="Input Image (JPG, PNG only)")
                with gr.Tabs() as input_tabs:
                    with gr.TabItem("Natural Language", id=0) as tab_nl:
                        instruction = gr.Textbox(
                            label="Counting Instruction",
                            placeholder="e.g., Count apples, not green apples",
                            value="Count apples, not green apples",
                            lines=2
                        )
                        gr.Markdown("""
**Examples:**
- "Count apples, not green apples"
- "Count red and black beans, exclude white beans"
- "Count people, not women"
""")
                    with gr.TabItem("Manual Input", id=1) as tab_manual:
                        pos_caption = gr.Textbox(
                            label="Positive Prompt (objects to count)",
                            placeholder="e.g., apple",
                            value="apple."
                        )
                        # "None." is the placeholder the counting functions use
                        # when there is nothing to exclude.
                        neg_caption = gr.Textbox(
                            label="Negative Prompt (objects to exclude)",
                            placeholder="e.g., green apple",
                            value="None."
                        )
                submit_btn = gr.Button("Count Objects", variant="primary", size="lg")
                with gr.Accordion("Advanced Settings", open=False):
                    # A value of 0 makes the counting functions fall back to
                    # model.box_threshold.
                    box_threshold = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.42,
                        step=0.01,
                        label="Threshold"
                    )
                    # Radius in pixels of the dot drawn at each counted point.
                    point_radius = gr.Slider(
                        minimum=1,
                        maximum=20,
                        value=5,
                        step=1,
                        label="Point Radius"
                    )
                    point_color = gr.Dropdown(
                        choices=["blue", "red", "green", "yellow", "cyan", "magenta", "white", "orange"],
                        value="blue",
                        label="Point Color"
                    )
            with gr.Column(scale=1):
                output_image = gr.Image(type="pil", label="Result")
                count_output = gr.Textbox(label="Count Result")
                parsed_output = gr.Textbox(label="Parsed Captions", lines=2)
        gr.Markdown("### Examples (Natural Language)")
        # Examples always run natural-language mode with the default slider
        # values (threshold 0.42, radius 5, blue), not the current settings.
        gr.Examples(
            examples=[
                ["examples/apples.png", "Count apples, not green apples"],
                ["examples/apples.png", "Count apples, exclude red apples"],
                ["examples/apple.jpg", "Count green apples"],
                ["examples/apple.jpg", "Count apples, exclude green apples"],
                ["examples/apple.jpg", "Count apples, exclude red apples"],
                ["examples/blue_straw_peach.png", "Count blueberries"],
                ["examples/blue_straw_peach.png", "Count leaf"],
                ["examples/blue_straw_peach.png", "Count blueberries and cherry"],
                ["examples/blue_straw_peach.png", "Count blueberries and cherry and strawberry"],
                ["examples/black_beans.jpg", "Count black beans and soy beans"],
                ["examples/black_beans.jpg", "Count beans"],
                ["examples/black_beans.jpg", "Count pig"],
                ["examples/candy.jpg", "Count brown coffee candy, exclude black coffee candy"],
                ["examples/candy.jpg", "Count candy"],
                ["examples/candy.jpg", "Count brown coffee candy and black coffee candy"],
                ["examples/candy.jpg", "Count sausage"],
                ["examples/strawberry.jpg", "Count blueberries and strawberry"],
                ["examples/strawberry.jpg", "Count book"],
                ["examples/strawberry2.jpg", "Count blueberries, exclude strawberry"],
                ["examples/women.jpg", "Count people, not women"],
                ["examples/women.jpg", "Count people, not man"],
                ["examples/boat-1.jpg", "Count boats, exclude blue boats"],
                ["examples/boat-1.jpg", "Count boats, exclude red boats"],
            ],
            inputs=[input_image, instruction],
            outputs=[output_image, count_output, parsed_output],
            fn=lambda img, instr: count_objects(img, instr, 0.42, 5, "blue"),
            cache_examples=False,
        )

        # Tab-selection handlers keep current_mode in sync with the visible tab.
        def set_mode_nl():
            return "natural_language"

        def set_mode_manual():
            return "manual"

        tab_nl.select(fn=set_mode_nl, outputs=[current_mode])
        tab_manual.select(fn=set_mode_manual, outputs=[current_mode])

        # Both tabs' inputs are always passed; only those of the active mode are used.
        def handle_submit(mode, image, instr, pos_cap, neg_cap, threshold, radius, color):
            if mode == "natural_language":
                return count_objects(image, instr, threshold, radius, color)
            else:
                return count_objects_manual(image, pos_cap, neg_cap, threshold, radius, color)

        submit_btn.click(
            fn=handle_submit,
            inputs=[current_mode, input_image, instruction, pos_caption, neg_caption,
                    box_threshold, point_radius, point_color],
            outputs=[output_image, count_output, parsed_output]
        )
    return demo
if __name__ == "__main__":
    # Startup order: data-upload client first (optional), then model weights, then the UI.
    startup_steps = (
        ("Initializing HuggingFace API...", init_hf_api),
        ("Loading model...", load_model),
    )
    for banner, step in startup_steps:
        print(banner)
        step()
    print("Model loaded!")

    # Kept as a module-level name so tooling that looks for `demo` can find the app.
    demo = create_demo()
    demo.launch()