# CountEx — Hugging Face Spaces Gradio demo app
# (scraped page status header removed)
| import os | |
| import json | |
| import gradio as gr | |
| import torch | |
| from PIL import Image, ImageDraw | |
| from transformers import GroundingDinoProcessor | |
| from hf_model import CountEX | |
| from utils import post_process_grounded_object_detection, post_process_grounded_object_detection_with_queries | |
| import google.generativeai as genai | |
| from datetime import datetime | |
| import csv | |
| from pathlib import Path | |
| import uuid | |
| import io | |
| # Try to import HEIC support | |
| try: | |
| from pillow_heif import register_heif_opener | |
| register_heif_opener() | |
| HEIC_SUPPORTED = True | |
| except ImportError: | |
| HEIC_SUPPORTED = False | |
| print("Warning: pillow-heif not installed. HEIC images will not be supported.") | |
| # Try to import HuggingFace Hub | |
| try: | |
| from huggingface_hub import HfApi | |
| HF_HUB_AVAILABLE = True | |
| except ImportError: | |
| HF_HUB_AVAILABLE = False | |
| print("Warning: huggingface_hub not installed.") | |
# Global variables for model and processor; populated lazily by load_model()
# and init_hf_api() so importing this module stays cheap.
model = None
processor = None
device = None
hf_api = None
# Data collection directory (local fallback when HF dataset upload is unavailable)
DATA_LOG_DIR = Path("uploaded_data")
DATA_LOG_DIR.mkdir(exist_ok=True)
IMAGES_DIR = DATA_LOG_DIR / "images"
IMAGES_DIR.mkdir(exist_ok=True)
DATA_LOG_FILE = DATA_LOG_DIR / "prompts_log.csv"
# HuggingFace Dataset repo for data collection (overridable via env)
HF_DATASET_REPO = os.environ.get("HF_DATASET_REPO", "BBVisual/CountEx_UserData")
# Initialize CSV log file with headers if it doesn't exist
if not DATA_LOG_FILE.exists():
    with open(DATA_LOG_FILE, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["timestamp", "image_filename", "instruction", "pos_caption", "neg_caption", "count"])
# Image processing constants
MAX_IMAGE_SIZE = 1333  # Max dimension (width or height)
ALLOWED_EXTENSIONS = {".jpg", ".jpeg", ".png"}
gemini_api_key = os.environ.get("GEMINI_API_KEY")
# Configure Gemini
# NOTE(review): if GEMINI_API_KEY is unset this configures genai with None;
# presumably Gemini calls then fail at request time and parse_counting_instruction
# falls back to the raw instruction — confirm desired behavior.
genai.configure(api_key=gemini_api_key)
gemini_model = genai.GenerativeModel("gemini-2.0-flash")
# Few-shot prompt for Gemini: converts a "Count A, not B" instruction into two
# JSON lists (A = objects to include, B = objects to exclude). The template is
# filled with str.format(), hence the doubled braces in the JSON example.
PARSING_PROMPT = """Parse sentences of the form "Count A, not B" into two lists—A (include) and B (exclude)—splitting on "and", "or", and commas, and reattaching shared head nouns (e.g., "red and black beans" → "red beans", "black beans").
Rules:
- Remove from B items that are equivalent to items in A (synonyms/variants/abbreviations/regional terms)
- Keep B items that are more specific than A (for fine-grained exclusion)
- If B is more general than A but shares the head noun, remove B (contradictory)
Case 1 — Different head nouns → Keep B
Example 1: Count green apples and red beans, not yellow screws and white rice → A: ["green apples", "red beans"], B: ["yellow screws", "white rice"]
Example 2: Count black beans, not poker chips or nails → A: ["black beans"], B: ["poker chips", "nails"]
Case 2 — Equivalent items → Remove from B
Example 1: Count fries and TV, not chips and television → A: ["fries", "TV"], B: []
Example 2: Count garbanzo beans and couch, not chickpeas and sofa → A: ["garbanzo beans", "couch"], B: []
Case 3 — B more specific than A → Keep B (for fine-grained exclusion)
Example 1: Count apples and beans, not green apples and black beans → A: ["apples", "beans"], B: ["green apples", "black beans"]
Example 2: Count beans, not white beans or yellow beans → A: ["beans"], B: ["white beans", "yellow beans"]
Example 3: Count people, not women → A: ["people"], B: ["women"]
Case 4 — B more general than A → Remove B (contradictory)
Example 1: Count green apples, not apples → A: ["green apples"], B: []
Example 2: Count red beans and green apples, not beans and apples → A: ["red beans", "green apples"], B: []
User instruction: {instruction}
Respond ONLY with a JSON object in this exact format, no other text:
{{"A": ["item1", "item2"], "B": ["item3"]}}
"""
def init_hf_api():
    """Set up the global HuggingFace API client used for dataset uploads.

    Returns:
        The ``HfApi`` instance on success, or ``None`` when the hub library
        is missing, the write token is not configured, or construction fails.
    """
    global hf_api
    if not HF_HUB_AVAILABLE:
        print("HuggingFace Hub not available")
        return None
    # NOTE(review): "HF_WRITTE_TOKEN" looks like a typo of "HF_WRITE_TOKEN",
    # but live deployments may define the misspelled secret — do not rename
    # without updating the Space configuration.
    token = os.environ.get("HF_WRITTE_TOKEN")
    if not token:
        print("HF_WRITTE_TOKEN not set, data collection disabled")
        return None
    try:
        hf_api = HfApi(token=token)
    except Exception as e:
        print(f"Error initializing HuggingFace API: {e}")
        return None
    print(f"HuggingFace API initialized. Dataset repo: {HF_DATASET_REPO}")
    return hf_api
def upload_to_hf_dataset(image_bytes, image_filename, data_dict):
    """Push one image plus its JSON metadata to the HF dataset repo.

    Returns:
        True when both uploads succeed, False when the API client is absent
        or any upload step raises.
    """
    global hf_api
    if not hf_api:
        return False
    token = os.environ.get("HF_WRITTE_TOKEN")
    try:
        # Image goes under images/, metadata under metadata/ with the same
        # stem. Per-sample JSON files avoid race conditions between workers.
        uploads = (
            (f"images/{image_filename}", image_bytes),
            (f"metadata/{image_filename.replace('.jpg', '.json')}",
             json.dumps(data_dict, indent=2).encode('utf-8')),
        )
        for repo_path, payload in uploads:
            hf_api.upload_file(
                path_or_fileobj=io.BytesIO(payload),
                path_in_repo=repo_path,
                repo_id=HF_DATASET_REPO,
                repo_type="dataset",
                token=token,
            )
        return True
    except Exception as e:
        print(f"Error uploading to HuggingFace Dataset: {e}")
        return False
def save_uploaded_data(image, instruction, pos_caption, neg_caption, count, points=None):
    """Persist one inference sample (image + prompts + result) for collection.

    Tries the HuggingFace dataset first when the API client is configured;
    otherwise (or on upload failure) writes the sample to local disk as
    JPEG + JSON and appends a CSV row for backward compatibility.
    """
    global hf_api
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    image_filename = f"{stamp}_{str(uuid.uuid4())[:8]}.jpg"
    # Encode the image once; the same bytes serve the remote upload.
    buf = io.BytesIO()
    image.save(buf, format='JPEG', quality=95)
    record = {
        "timestamp": stamp,
        "image_filename": image_filename,
        "instruction": instruction,
        "pos_caption": pos_caption,
        "neg_caption": neg_caption,
        "count": count,
        # Point coordinates are normalized to [0, 1].
        "points": points if points else [],
    }
    # Preferred path: HuggingFace Dataset upload.
    if hf_api:
        try:
            if upload_to_hf_dataset(buf.getvalue(), image_filename, record):
                print(f"Saved to HuggingFace Dataset: {image_filename}")
                return
        except Exception as e:
            print(f"HuggingFace upload failed, falling back to local: {e}")
    # Fallback path: local JPEG + JSON + CSV row.
    try:
        image.save(IMAGES_DIR / image_filename, "JPEG", quality=95)
        meta_dir = DATA_LOG_DIR / "metadata"
        meta_dir.mkdir(exist_ok=True)
        with open(meta_dir / image_filename.replace('.jpg', '.json'), 'w') as f:
            json.dump(record, f, indent=2)
        with open(DATA_LOG_FILE, "a", newline="") as f:
            csv.writer(f).writerow(
                [stamp, image_filename, instruction, pos_caption, neg_caption, count])
        print(f"Saved locally: {image_filename}")
    except Exception as e:
        print(f"Error saving data: {e}")
def validate_image(image):
    """
    Validate an uploaded image's file extension.

    Fix: the module registers HEIC decoding via pillow-heif, but the old
    validator still rejected ``.heic``/``.heif`` paths; they are now accepted
    whenever HEIC support is actually available.

    Args:
        image: A filesystem path (str) or an already-loaded image object.

    Returns:
        (is_valid, error_message): ``error_message`` is None when valid.
    """
    if image is None:
        return False, "Error: Please upload an image."
    # Only string inputs carry a path we can inspect; loaded images pass through.
    if isinstance(image, str):
        allowed = set(ALLOWED_EXTENSIONS)
        if HEIC_SUPPORTED:
            allowed |= {".heic", ".heif"}
        ext = os.path.splitext(image)[1].lower()
        # An empty extension is deliberately allowed (some upload widgets
        # produce temp files without a suffix).
        if ext and ext not in allowed:
            supported = "JPG, PNG and HEIC" if HEIC_SUPPORTED else "JPG and PNG"
            return False, f"Error: Unsupported format '{ext}'. Only {supported} are supported."
    return True, None
def preprocess_image(image, max_size=None):
    """
    Normalize an uploaded image: load from path, flatten transparency onto a
    white background, convert to RGB, and downscale so the longest side is at
    most ``max_size`` pixels.

    Fix: ``LA`` (grayscale + alpha) images were previously pasted with
    ``mask=None``, silently discarding their alpha channel; they are now
    converted to RGBA first so transparency is composited like RGBA/P inputs.

    Args:
        image: A PIL image or a filesystem path.
        max_size: Optional maximum dimension override; defaults to the
            module-level MAX_IMAGE_SIZE (backward compatible).

    Returns:
        An RGB PIL.Image no larger than ``max_size`` on its longest side.
    """
    # Handle file path input
    if isinstance(image, str):
        image = Image.open(image)
    # Convert to RGB, compositing any alpha channel onto white.
    if image.mode != "RGB":
        if image.mode in ("RGBA", "LA", "P"):
            if image.mode != "RGBA":
                image = image.convert("RGBA")
            background = Image.new("RGB", image.size, (255, 255, 255))
            background.paste(image, mask=image.split()[-1])
            image = background
        else:
            image = image.convert("RGB")
    if max_size is None:
        max_size = MAX_IMAGE_SIZE
    # Resize if image is too large, preserving aspect ratio.
    width, height = image.size
    if max(width, height) > max_size:
        scale = max_size / max(width, height)
        new_width = int(width * scale)
        new_height = int(height * scale)
        image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
        print(f"Resized image from {width}x{height} to {new_width}x{new_height}")
    return image
def parse_counting_instruction(instruction: str) -> tuple[str, str]:
    """
    Turn a natural-language counting request into (positive, negative)
    caption strings using Gemini 2.0 Flash.

    Falls back to the raw instruction as the positive caption (with "None."
    as the negative) when the LLM call or JSON parsing fails.
    """
    try:
        raw = gemini_model.generate_content(
            PARSING_PROMPT.format(instruction=instruction)
        ).text.strip()
        # Strip a ```json ... ``` fence if the model wrapped its answer.
        if raw.startswith("```"):
            raw = raw.split("```")[1]
            if raw.startswith("json"):
                raw = raw[4:]
        raw = raw.strip()
        parsed = json.loads(raw)
        include = parsed.get("A", [])
        exclude = parsed.get("B", [])
        # Join items with " and " and terminate with a period, matching the
        # caption format the detector expects.
        positive = (" and ".join(include) + ".") if include else ""
        negative = (" and ".join(exclude) + ".") if exclude else "None."
        return positive, negative
    except Exception as e:
        print(f"Error parsing instruction: {e}")
        return instruction.strip() + ".", "None."
def load_model():
    """Initialize the global CountEX model and GroundingDINO processor.

    Rebinds the module globals (model, processor, device) and also returns
    them for convenience. Called lazily on first inference and at startup.
    """
    global model, processor, device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Weights are gated; HF_TOKEN must grant read access to the repo.
    model = (
        CountEX.from_pretrained(
            "yifehuang97/CountEx_KC_aug_v3_12140136_v2",
            token=os.environ.get("HF_TOKEN"),
        )
        .to(torch.bfloat16)
        .to(device)
    )
    model.eval()
    processor = GroundingDinoProcessor.from_pretrained("fushh7/llmdet_swin_tiny_hf")
    return model, processor, device
| import numpy as np | |
def discriminative_point_suppression(
    points,
    neg_points,
    pos_queries,
    neg_queries,
    image_size,
    pixel_threshold=5,
    similarity_threshold=0.3,
):
    """Discriminative Point Suppression (DPS).

    Drops positive detections that are both spatially near (< pixel_threshold
    pixels) and semantically similar (cosine similarity > similarity_threshold)
    to some negative detection.

    Args:
        points: positive points, normalized (x, y) in [0, 1].
        neg_points: negative points, normalized (x, y) in [0, 1].
        pos_queries: (N, D) query embeddings for the positive points.
        neg_queries: (M, D) query embeddings for the negative points.
        image_size: (width, height) in pixels.
        pixel_threshold: max pixel distance to count as "close".
        similarity_threshold: min cosine similarity to count as "similar".

    Returns:
        (kept_points, kept_indices, info): surviving points, their indices
        into the input list, and per-point diagnostics (empty dict when
        either input list is empty).
    """
    if not neg_points or not points:
        return points, list(range(len(points))), {}
    w, h = image_size
    scale = np.array([w, h])
    pos_px = np.asarray(points) * scale
    neg_px = np.asarray(neg_points) * scale
    # Pairwise pixel distances (N, M); find each positive's nearest negative.
    dists = np.linalg.norm(pos_px[:, None, :] - neg_px[None, :, :], axis=-1)
    nearest_idx = dists.argmin(axis=1)
    nearest_dist = dists.min(axis=1)
    is_close = nearest_dist < pixel_threshold

    # Cosine similarity between each positive query and its matched negative.
    def _unit(q):
        return q / (np.linalg.norm(q, axis=-1, keepdims=True) + 1e-8)

    sim = (_unit(pos_queries) * _unit(neg_queries)[nearest_idx]).sum(axis=-1)
    is_similar = sim > similarity_threshold
    suppress = is_close & is_similar
    keep = ~suppress
    kept_points = np.asarray(points)[keep].tolist()
    kept_indices = np.where(keep)[0].tolist()
    info = {
        "nearest_neg_idx": nearest_idx.tolist(),
        "nearest_neg_dist": nearest_dist.tolist(),
        "query_similarity": sim.tolist(),
        "spatially_close": is_close.tolist(),
        "semantically_similar": is_similar.tolist(),
        "suppressed_indices": np.where(suppress)[0].tolist(),
    }
    return kept_points, kept_indices, info
def count_objects(image, instruction, box_threshold, point_radius, point_color):
    """Run the full natural-language counting pipeline on one image.

    Steps: validate and preprocess the image, parse the instruction into
    positive/negative captions with Gemini, run CountEX with both captions,
    post-process detections, apply Discriminative Point Suppression, draw the
    surviving points, and log the sample for data collection.

    Args:
        image: filepath (from the Gradio widget) or a PIL image.
        instruction: natural-language request, e.g. "Count apples, not green apples".
        box_threshold: detection confidence threshold; values <= 0 fall back
            to the model's built-in default.
        point_radius: radius in pixels of the drawn dots.
        point_color: fill color name for the drawn dots.

    Returns:
        (annotated PIL image, "Count: N" string, parsed-captions string).

    Raises:
        gr.Error: when the uploaded image fails format validation.
    """
    global model, processor, device
    # Validate image format
    is_valid, error_msg = validate_image(image)
    if not is_valid:
        raise gr.Error(error_msg)
    # Lazy model load on first request.
    if model is None:
        load_model()
    # Preprocess image
    image = preprocess_image(image)
    # Parse instruction using Gemini
    pos_caption, neg_caption = parse_counting_instruction(instruction)
    parsed_info = f"Positive: {pos_caption}\nNegative: {neg_caption}"
    # Process positive caption
    pos_inputs = processor(
        images=image,
        text=pos_caption,
        return_tensors="pt",
        padding=True
    )
    pos_inputs = pos_inputs.to(device)
    # Model weights are bfloat16, so pixel inputs must match.
    pos_inputs['pixel_values'] = pos_inputs['pixel_values'].to(torch.bfloat16)
    # Process negative caption
    use_neg = bool(neg_caption and neg_caption.strip() and neg_caption != '.' and neg_caption != 'None.')
    if not use_neg:
        neg_caption = "None."
    neg_inputs = processor(
        images=image,
        text=neg_caption,
        return_tensors="pt",
        padding=True
    )
    neg_inputs = {k: v.to(device) for k, v in neg_inputs.items()}
    neg_inputs['pixel_values'] = neg_inputs['pixel_values'].to(torch.bfloat16)
    # Feed the negative branch through prefixed keys on the same call.
    pos_inputs['neg_token_type_ids'] = neg_inputs['token_type_ids']
    pos_inputs['neg_attention_mask'] = neg_inputs['attention_mask']
    pos_inputs['neg_pixel_mask'] = neg_inputs['pixel_mask']
    pos_inputs['neg_pixel_values'] = neg_inputs['pixel_values']
    pos_inputs['neg_input_ids'] = neg_inputs['input_ids']
    # NOTE(review): use_neg is computed above but the flag passed to the model
    # is always True — the "None." caption appears to encode "no exclusions".
    # Confirm this is intentional before relying on use_neg elsewhere.
    pos_inputs['use_neg'] = True
    # Run inference
    with torch.no_grad():
        outputs = model(**pos_inputs)
    # Box centers (first two box coords) serve as point predictions.
    outputs["pred_points"] = outputs["pred_boxes"][:, :, :2]
    outputs["pred_logits"] = outputs["logits"]
    threshold = box_threshold if box_threshold > 0 else model.box_threshold
    # NOTE(review): queries presumably stack decoder layers; [-1] takes the
    # last layer before re-adding a batch dim — confirm against CountEX.
    pos_queries = outputs["pos_queries"].squeeze(0).float()
    neg_queries = outputs["neg_queries"].squeeze(0).float()
    pos_queries = pos_queries[-1].squeeze(0)
    neg_queries = neg_queries[-1].squeeze(0)
    pos_queries = pos_queries.unsqueeze(0)
    neg_queries = neg_queries.unsqueeze(0)
    results = post_process_grounded_object_detection_with_queries(outputs, pos_queries, box_threshold=threshold)[0]
    boxes = results["boxes"]
    boxes = [box.tolist() for box in boxes]
    points = [[box[0], box[1]] for box in boxes]
    # Run the same post-processing on the negative branch (fixed 0.5 threshold).
    neg_points = []
    neg_results = None
    if "neg_pred_boxes" in outputs and "neg_logits" in outputs:
        neg_outputs = outputs.copy()
        neg_outputs["pred_boxes"] = outputs["neg_pred_boxes"]
        neg_outputs["logits"] = outputs["neg_logits"]
        neg_outputs["pred_points"] = outputs["neg_pred_boxes"][:, :, :2]
        neg_outputs["pred_logits"] = outputs["neg_logits"]
        neg_results = post_process_grounded_object_detection_with_queries(neg_outputs, neg_queries, box_threshold=0.5)[0]
        neg_boxes = neg_results["boxes"]
        neg_boxes = [box.tolist() for box in neg_boxes]
        neg_points = [[box[0], box[1]] for box in neg_boxes]
    pos_queries_np = results["queries"].cpu().numpy()
    neg_queries_np = neg_results["queries"].cpu().numpy() if neg_results else np.array([])
    img_size = image.size
    # Suppress positives that coincide with confident negatives (DPS).
    if len(neg_points) > 0 and len(neg_queries_np) > 0:
        filtered_points, kept_indices, suppression_info = discriminative_point_suppression(
            points,
            neg_points,
            pos_queries_np,
            neg_queries_np,
            image_size=img_size,
            pixel_threshold=5,
            similarity_threshold=0.3,
        )
        filtered_boxes = [boxes[i] for i in kept_indices]
    else:
        filtered_points = points
        filtered_boxes = boxes
    points = filtered_points
    boxes = filtered_boxes
    # Visualize results: one filled dot per surviving detection.
    img_w, img_h = image.size
    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)
    for point in points:
        # Points are normalized [0, 1]; scale to pixel coordinates.
        x = point[0] * img_w
        y = point[1] * img_h
        draw.ellipse(
            [x - point_radius, y - point_radius, x + point_radius, y + point_radius],
            fill=point_color
        )
    count = len(points)
    # Save uploaded data for collection
    save_uploaded_data(image, instruction, pos_caption, neg_caption, count, points)
    return img_draw, f"Count: {count}", parsed_info
def count_objects_manual(image, pos_caption, neg_caption, box_threshold, point_radius, point_color):
    """Manual mode: directly use provided positive and negative captions.

    Identical pipeline to count_objects() except the Gemini parsing step is
    skipped — the user supplies the captions, which are normalized to end
    with a period ("None." means no exclusions).

    Args:
        image: filepath (from the Gradio widget) or a PIL image.
        pos_caption: caption describing objects to count.
        neg_caption: caption describing objects to exclude (may be empty).
        box_threshold: detection confidence threshold; <= 0 uses the model default.
        point_radius: radius in pixels of the drawn dots.
        point_color: fill color name for the drawn dots.

    Returns:
        (annotated PIL image, "Count: N" string, parsed-captions string).

    Raises:
        gr.Error: when the uploaded image fails format validation.
    """
    global model, processor, device
    # Validate image format
    is_valid, error_msg = validate_image(image)
    if not is_valid:
        raise gr.Error(error_msg)
    # Lazy model load on first request.
    if model is None:
        load_model()
    # Preprocess image
    image = preprocess_image(image)
    # Normalize captions to the period-terminated format the model expects.
    if pos_caption and not pos_caption.endswith('.'):
        pos_caption = pos_caption + '.'
    if neg_caption and not neg_caption.endswith('.'):
        neg_caption = neg_caption + '.'
    if not neg_caption or neg_caption.strip() == '':
        neg_caption = "None."
    parsed_info = f"Positive: {pos_caption}\nNegative: {neg_caption}"
    pos_inputs = processor(
        images=image,
        text=pos_caption,
        return_tensors="pt",
        padding=True
    )
    pos_inputs = pos_inputs.to(device)
    # Model weights are bfloat16, so pixel inputs must match.
    pos_inputs['pixel_values'] = pos_inputs['pixel_values'].to(torch.bfloat16)
    use_neg = bool(neg_caption and neg_caption.strip() and neg_caption != '.' and neg_caption != 'None.')
    if not use_neg:
        neg_caption = "None."
    neg_inputs = processor(
        images=image,
        text=neg_caption,
        return_tensors="pt",
        padding=True
    )
    neg_inputs = {k: v.to(device) for k, v in neg_inputs.items()}
    neg_inputs['pixel_values'] = neg_inputs['pixel_values'].to(torch.bfloat16)
    # Feed the negative branch through prefixed keys on the same call.
    pos_inputs['neg_token_type_ids'] = neg_inputs['token_type_ids']
    pos_inputs['neg_attention_mask'] = neg_inputs['attention_mask']
    pos_inputs['neg_pixel_mask'] = neg_inputs['pixel_mask']
    pos_inputs['neg_pixel_values'] = neg_inputs['pixel_values']
    pos_inputs['neg_input_ids'] = neg_inputs['input_ids']
    # NOTE(review): use_neg is computed but the flag is always True — the
    # "None." caption appears to encode "no exclusions"; confirm intent.
    pos_inputs['use_neg'] = True
    with torch.no_grad():
        outputs = model(**pos_inputs)
    # Box centers (first two box coords) serve as point predictions.
    outputs["pred_points"] = outputs["pred_boxes"][:, :, :2]
    outputs["pred_logits"] = outputs["logits"]
    threshold = box_threshold if box_threshold > 0 else model.box_threshold
    # NOTE(review): queries presumably stack decoder layers; [-1] takes the
    # last layer before re-adding a batch dim — confirm against CountEX.
    pos_queries = outputs["pos_queries"].squeeze(0).float()
    neg_queries = outputs["neg_queries"].squeeze(0).float()
    pos_queries = pos_queries[-1].squeeze(0)
    neg_queries = neg_queries[-1].squeeze(0)
    pos_queries = pos_queries.unsqueeze(0)
    neg_queries = neg_queries.unsqueeze(0)
    results = post_process_grounded_object_detection_with_queries(outputs, pos_queries, box_threshold=threshold)[0]
    boxes = results["boxes"]
    boxes = [box.tolist() for box in boxes]
    points = [[box[0], box[1]] for box in boxes]
    # Run the same post-processing on the negative branch (fixed 0.5 threshold).
    neg_points = []
    neg_results = None
    if "neg_pred_boxes" in outputs and "neg_logits" in outputs:
        neg_outputs = outputs.copy()
        neg_outputs["pred_boxes"] = outputs["neg_pred_boxes"]
        neg_outputs["logits"] = outputs["neg_logits"]
        neg_outputs["pred_points"] = outputs["neg_pred_boxes"][:, :, :2]
        neg_outputs["pred_logits"] = outputs["neg_logits"]
        neg_results = post_process_grounded_object_detection_with_queries(neg_outputs, neg_queries, box_threshold=0.5)[0]
        neg_boxes = neg_results["boxes"]
        neg_boxes = [box.tolist() for box in neg_boxes]
        neg_points = [[box[0], box[1]] for box in neg_boxes]
    pos_queries_np = results["queries"].cpu().numpy()
    neg_queries_np = neg_results["queries"].cpu().numpy() if neg_results else np.array([])
    img_size = image.size
    # Suppress positives that coincide with confident negatives (DPS).
    if len(neg_points) > 0 and len(neg_queries_np) > 0:
        filtered_points, kept_indices, suppression_info = discriminative_point_suppression(
            points,
            neg_points,
            pos_queries_np,
            neg_queries_np,
            image_size=img_size,
            pixel_threshold=5,
            similarity_threshold=0.3,
        )
        filtered_boxes = [boxes[i] for i in kept_indices]
    else:
        filtered_points = points
        filtered_boxes = boxes
    points = filtered_points
    boxes = filtered_boxes
    # Draw one filled dot per surviving detection.
    img_w, img_h = image.size
    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)
    for point in points:
        # Points are normalized [0, 1]; scale to pixel coordinates.
        x = point[0] * img_w
        y = point[1] * img_h
        draw.ellipse(
            [x - point_radius, y - point_radius, x + point_radius, y + point_radius],
            fill=point_color
        )
    count = len(points)
    # Tag manual-mode samples so they are distinguishable in the collected data.
    instruction = f"[MANUAL] pos: {pos_caption} | neg: {neg_caption}"
    save_uploaded_data(image, instruction, pos_caption, neg_caption, count, points)
    return img_draw, f"Count: {count}", parsed_info
def create_demo():
    """Build and return the Gradio Blocks UI.

    Layout: left column holds the image upload, two input tabs (natural
    language vs. manual captions), the submit button, and advanced settings;
    right column shows the annotated image, count, and parsed captions.
    A gr.State tracks which tab is active so one submit handler can dispatch
    to the matching inference function.
    """
    with gr.Blocks(title="CountEx: Discriminative Visual Counting") as demo:
        gr.Markdown("""
        # CountEx: Fine-Grained Counting via Exemplars and Exclusion
        Count specific objects in images using text prompts with exclusion capability.
        """)
        # Tracks which input tab is active ("natural_language" or "manual").
        current_mode = gr.State(value="natural_language")
        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(type="filepath", label="Input Image (JPG, PNG only)")
                with gr.Tabs() as input_tabs:
                    with gr.TabItem("Natural Language", id=0) as tab_nl:
                        instruction = gr.Textbox(
                            label="Counting Instruction",
                            placeholder="e.g., Count apples, not green apples",
                            value="Count apples, not green apples",
                            lines=2
                        )
                        gr.Markdown("""
                        **Examples:**
                        - "Count apples, not green apples"
                        - "Count red and black beans, exclude white beans"
                        - "Count people, not women"
                        """)
                    with gr.TabItem("Manual Input", id=1) as tab_manual:
                        pos_caption = gr.Textbox(
                            label="Positive Prompt (objects to count)",
                            placeholder="e.g., apple",
                            value="apple."
                        )
                        neg_caption = gr.Textbox(
                            label="Negative Prompt (objects to exclude)",
                            placeholder="e.g., green apple",
                            value="None."
                        )
                submit_btn = gr.Button("Count Objects", variant="primary", size="lg")
                with gr.Accordion("Advanced Settings", open=False):
                    box_threshold = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.42,
                        step=0.01,
                        label="Threshold"
                    )
                    point_radius = gr.Slider(
                        minimum=1,
                        maximum=20,
                        value=5,
                        step=1,
                        label="Point Radius"
                    )
                    point_color = gr.Dropdown(
                        choices=["blue", "red", "green", "yellow", "cyan", "magenta", "white", "orange"],
                        value="blue",
                        label="Point Color"
                    )
            with gr.Column(scale=1):
                output_image = gr.Image(type="pil", label="Result")
                count_output = gr.Textbox(label="Count Result")
                parsed_output = gr.Textbox(label="Parsed Captions", lines=2)
        gr.Markdown("### Examples (Natural Language)")
        gr.Examples(
            examples=[
                ["examples/apples.png", "Count apples, not green apples"],
                ["examples/apples.png", "Count apples, exclude red apples"],
                ["examples/apple.jpg", "Count green apples"],
                ["examples/apple.jpg", "Count apples, exclude green apples"],
                ["examples/apple.jpg", "Count apples, exclude red apples"],
                ["examples/blue_straw_peach.png", "Count blueberries"],
                ["examples/blue_straw_peach.png", "Count leaf"],
                ["examples/blue_straw_peach.png", "Count blueberries and cherry"],
                ["examples/blue_straw_peach.png", "Count blueberries and cherry and strawberry"],
                ["examples/black_beans.jpg", "Count black beans and soy beans"],
                ["examples/black_beans.jpg", "Count beans"],
                ["examples/black_beans.jpg", "Count pig"],
                ["examples/candy.jpg", "Count brown coffee candy, exclude black coffee candy"],
                ["examples/candy.jpg", "Count candy"],
                ["examples/candy.jpg", "Count brown coffee candy and black coffee candy"],
                ["examples/candy.jpg", "Count sausage"],
                ["examples/strawberry.jpg", "Count blueberries and strawberry"],
                ["examples/strawberry.jpg", "Count book"],
                ["examples/strawberry2.jpg", "Count blueberries, exclude strawberry"],
                ["examples/women.jpg", "Count people, not women"],
                ["examples/women.jpg", "Count people, not man"],
                ["examples/boat-1.jpg", "Count boats, exclude blue boats"],
                ["examples/boat-1.jpg", "Count boats, exclude red boats"],
            ],
            inputs=[input_image, instruction],
            outputs=[output_image, count_output, parsed_output],
            fn=lambda img, instr: count_objects(img, instr, 0.42, 5, "blue"),
            cache_examples=False,
        )
        # Keep current_mode in sync with the selected tab.
        def set_mode_nl():
            return "natural_language"
        def set_mode_manual():
            return "manual"
        tab_nl.select(fn=set_mode_nl, outputs=[current_mode])
        tab_manual.select(fn=set_mode_manual, outputs=[current_mode])
        # Single submit handler dispatching on the active tab.
        def handle_submit(mode, image, instr, pos_cap, neg_cap, threshold, radius, color):
            if mode == "natural_language":
                return count_objects(image, instr, threshold, radius, color)
            else:
                return count_objects_manual(image, pos_cap, neg_cap, threshold, radius, color)
        submit_btn.click(
            fn=handle_submit,
            inputs=[current_mode, input_image, instruction, pos_caption, neg_caption,
                    box_threshold, point_radius, point_color],
            outputs=[output_image, count_output, parsed_output]
        )
    return demo
# Script entry point: warm everything up before serving so the first user
# interaction doesn't pay the startup cost.
if __name__ == "__main__":
    # Initialize HuggingFace API
    print("Initializing HuggingFace API...")
    init_hf_api()
    # Load model at startup
    print("Loading model...")
    load_model()
    print("Model loaded!")
    # Create and launch demo
    demo = create_demo()
    demo.launch()