"""Gradio demo for CountEx: text-prompted object counting with exclusion.

This part of the module holds the imports, module-level configuration,
the Gemini prompt used to split an instruction into include/exclude lists,
the data-collection helpers (HuggingFace Dataset upload with a local
fallback) and the image validation / preprocessing helpers.
"""

import os
import json
import gradio as gr
import torch
from PIL import Image, ImageDraw
from transformers import GroundingDinoProcessor
from hf_model import CountEX
from utils import post_process_grounded_object_detection, post_process_grounded_object_detection_with_queries
import google.generativeai as genai
from datetime import datetime
import csv
from pathlib import Path
import uuid
import io

# Try to import HEIC support
try:
    from pillow_heif import register_heif_opener
    register_heif_opener()
    HEIC_SUPPORTED = True
except ImportError:
    HEIC_SUPPORTED = False
    print("Warning: pillow-heif not installed. HEIC images will not be supported.")

# Try to import HuggingFace Hub
try:
    from huggingface_hub import HfApi
    HF_HUB_AVAILABLE = True
except ImportError:
    HF_HUB_AVAILABLE = False
    print("Warning: huggingface_hub not installed.")

# Global variables for model and processor (populated by load_model / init_hf_api)
model = None
processor = None
device = None
hf_api = None

# Data collection directory (local fallback)
DATA_LOG_DIR = Path("uploaded_data")
DATA_LOG_DIR.mkdir(exist_ok=True)
IMAGES_DIR = DATA_LOG_DIR / "images"
IMAGES_DIR.mkdir(exist_ok=True)
DATA_LOG_FILE = DATA_LOG_DIR / "prompts_log.csv"

# HuggingFace Dataset repo for data collection
HF_DATASET_REPO = os.environ.get("HF_DATASET_REPO", "BBVisual/CountEx_UserData")

# Initialize CSV log file with headers if it doesn't exist
if not DATA_LOG_FILE.exists():
    with open(DATA_LOG_FILE, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["timestamp", "image_filename", "instruction", "pos_caption", "neg_caption", "count"])

# Image processing constants
MAX_IMAGE_SIZE = 1333  # Max dimension (width or height)
ALLOWED_EXTENSIONS = {".jpg", ".jpeg", ".png"}

gemini_api_key = os.environ.get("GEMINI_API_KEY")

# Configure Gemini
genai.configure(api_key=gemini_api_key)
gemini_model = genai.GenerativeModel("gemini-2.0-flash")

# NOTE: literal braces in the JSON example are doubled because the template
# is filled in with str.format(instruction=...).
PARSING_PROMPT = """Parse sentences of the form "Count A, not B" into two lists—A (include) and B (exclude)—splitting on "and", "or", and 
commas, and reattaching shared head nouns (e.g., "red and black beans" → "red beans", "black beans").

Rules:
- Remove from B items that are equivalent to items in A (synonyms/variants/abbreviations/regional terms)
- Keep B items that are more specific than A (for fine-grained exclusion)
- If B is more general than A but shares the head noun, remove B (contradictory)

Case 1 — Different head nouns → Keep B
Example 1: Count green apples and red beans, not yellow screws and white rice → A: ["green apples", "red beans"], B: ["yellow screws", "white rice"]
Example 2: Count black beans, not poker chips or nails → A: ["black beans"], B: ["poker chips", "nails"]

Case 2 — Equivalent items → Remove from B
Example 1: Count fries and TV, not chips and television → A: ["fries", "TV"], B: []
Example 2: Count garbanzo beans and couch, not chickpeas and sofa → A: ["garbanzo beans", "couch"], B: []

Case 3 — B more specific than A → Keep B (for fine-grained exclusion)
Example 1: Count apples and beans, not green apples and black beans → A: ["apples", "beans"], B: ["green apples", "black beans"]
Example 2: Count beans, not white beans or yellow beans → A: ["beans"], B: ["white beans", "yellow beans"]
Example 3: Count people, not women → A: ["people"], B: ["women"]

Case 4 — B more general than A → Remove B (contradictory)
Example 1: Count green apples, not apples → A: ["green apples"], B: []
Example 2: Count red beans and green apples, not beans and apples → A: ["red beans", "green apples"], B: []

User instruction: {instruction}

Respond ONLY with a JSON object in this exact format, no other text:
{{"A": ["item1", "item2"], "B": ["item3"]}}
"""


def init_hf_api():
    """Initialize the HuggingFace API client used for dataset upload.

    Reads the write token from the ``HF_WRITTE_TOKEN`` environment variable
    (spelling kept as-is for compatibility with existing deployments) and
    stores the client in the module-level ``hf_api``.

    Returns:
        The ``HfApi`` instance, or ``None`` when the hub package is missing,
        the token is not set, or initialization fails.
    """
    global hf_api
    if not HF_HUB_AVAILABLE:
        print("HuggingFace Hub not available")
        return None
    try:
        hf_token = os.environ.get("HF_WRITTE_TOKEN")
        if not hf_token:
            print("HF_WRITTE_TOKEN not set, data collection disabled")
            return None
        hf_api = HfApi(token=hf_token)
        print(f"HuggingFace API initialized. Dataset repo: {HF_DATASET_REPO}")
        return hf_api
    except Exception as e:
        print(f"Error initializing HuggingFace API: {e}")
        return None


def upload_to_hf_dataset(image_bytes, image_filename, data_dict):
    """Upload an image and its metadata to the HuggingFace Dataset repo.

    Args:
        image_bytes: Encoded image content to store under ``images/``.
        image_filename: File name for the image; the metadata file reuses it
            with a ``.json`` suffix under ``metadata/``.
        data_dict: JSON-serializable metadata for this image.

    Returns:
        ``True`` when both uploads succeed, ``False`` when the API client is
        not initialized or any upload raises.
    """
    if not hf_api:
        return False
    try:
        hf_token = os.environ.get("HF_WRITTE_TOKEN")

        # Upload image
        hf_api.upload_file(
            path_or_fileobj=io.BytesIO(image_bytes),
            path_in_repo=f"images/{image_filename}",
            repo_id=HF_DATASET_REPO,
            repo_type="dataset",
            token=hf_token
        )

        # Upload metadata as individual JSON file (avoids race conditions)
        json_filename = Path(image_filename).with_suffix(".json").name
        json_content = json.dumps(data_dict, indent=2)
        hf_api.upload_file(
            path_or_fileobj=io.BytesIO(json_content.encode('utf-8')),
            path_in_repo=f"metadata/{json_filename}",
            repo_id=HF_DATASET_REPO,
            repo_type="dataset",
            token=hf_token
        )
        return True
    except Exception as e:
        print(f"Error uploading to HuggingFace Dataset: {e}")
        return False


def save_uploaded_data(image, instruction, pos_caption, neg_caption, count, points=None):
    """Save an uploaded image and its prompt data for collection.

    Tries the HuggingFace Dataset first and falls back to local storage
    (image file, per-image JSON, and a row appended to the CSV log).
    This is best-effort: failures are printed and never raised, so data
    collection cannot break the counting request that triggered it.

    Args:
        image: PIL image to store; it is encoded as JPEG.
        instruction: The user's instruction text.
        pos_caption: Positive caption that was sent to the model.
        neg_caption: Negative caption that was sent to the model.
        count: Number of counted objects.
        points: Optional list of predicted points; stored as ``[]`` if empty.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    unique_id = str(uuid.uuid4())[:8]
    image_filename = f"{timestamp}_{unique_id}.jpg"
    json_filename = Path(image_filename).with_suffix(".json").name

    # Prepare image bytes once; they are reused for upload and local fallback.
    try:
        img_buffer = io.BytesIO()
        image.save(img_buffer, format='JPEG', quality=95)
        img_bytes = img_buffer.getvalue()
    except Exception as e:
        print(f"Error saving data: {e}")
        return

    # Data as dict (for JSON)
    data_dict = {
        "timestamp": timestamp,
        "image_filename": image_filename,
        "instruction": instruction,
        "pos_caption": pos_caption,
        "neg_caption": neg_caption,
        "count": count,
        "points": points or []  # normalized coordinates (0-1)
    }

    # Try HuggingFace Dataset first (upload_to_hf_dataset reports its own errors)
    if hf_api and upload_to_hf_dataset(img_bytes, image_filename, data_dict):
        print(f"Saved to HuggingFace Dataset: {image_filename}")
        return

    # Fallback to local storage
    try:
        (IMAGES_DIR / image_filename).write_bytes(img_bytes)

        # Also save as JSON locally
        metadata_dir = DATA_LOG_DIR / "metadata"
        metadata_dir.mkdir(exist_ok=True)
        with open(metadata_dir / json_filename, "w", encoding="utf-8") as f:
            json.dump(data_dict, f, indent=2)

        # Also append to CSV for backward compatibility.
        # Explicit UTF-8 so non-ASCII instructions do not depend on the locale.
        with open(DATA_LOG_FILE, "a", newline="", encoding="utf-8") as f:
            csv.writer(f).writerow(
                [timestamp, image_filename, instruction, pos_caption, neg_caption, count]
            )
        print(f"Saved locally: {image_filename}")
    except Exception as e:
        print(f"Error saving data: {e}")


def validate_image(image):
    """Validate the uploaded image.

    Only file-path inputs are checked against ``ALLOWED_EXTENSIONS``; a path
    without an extension and non-string inputs are accepted as-is.

    Returns:
        ``(is_valid, error_message)`` where ``error_message`` is ``None``
        when the image is accepted.
    """
    if image is None:
        return False, "Error: Please upload an image."

    # Get file extension
    if isinstance(image, str):
        ext = os.path.splitext(image)[1].lower()
        if ext and ext not in ALLOWED_EXTENSIONS:
            return False, f"Error: Unsupported format '{ext}'. Only JPG and PNG are supported."

    return True, None


def preprocess_image(image):
    """Convert an uploaded image to RGB and downscale it if needed.

    Args:
        image: A file path or a PIL image.

    Returns:
        An RGB PIL image whose longer side is at most ``MAX_IMAGE_SIZE``.
    """
    # Handle file path input
    if isinstance(image, str):
        image = Image.open(image)

    # Convert to RGB (handles RGBA, P mode, etc.)
    if image.mode != "RGB":
        if image.mode in ("RGBA", "LA", "P"):
            # Composite onto white using the alpha band. Converting every
            # alpha-capable mode to RGBA first means "LA" transparency is
            # honored too instead of being silently dropped.
            rgba = image.convert("RGBA")
            background = Image.new("RGB", rgba.size, (255, 255, 255))
            background.paste(rgba, mask=rgba.split()[-1])
            image = background
        else:
            image = image.convert("RGB")

    # Resize if image is too large
    width, height = image.size
    if max(width, height) > MAX_IMAGE_SIZE:
        scale = MAX_IMAGE_SIZE / max(width, height)
        # Clamp to 1 px so extreme aspect ratios never yield a zero dimension.
        new_width = max(1, int(width * scale))
        new_height = max(1, int(height * scale))
        image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
        print(f"Resized image from {width}x{height} to {new_width}x{new_height}")

    return image


def _clean_items(items) -> list[str]:
    """Return the non-empty, stripped string items of a parsed A/B list."""
    if isinstance(items, str):
        # A bare string must not be iterated character by character.
        items = [items]
    stripped = (str(item).strip() for item in items or [])
    return [text for text in stripped if text]


def parse_counting_instruction(instruction: str) -> tuple[str, str]:
    """Parse a natural language counting instruction using Gemini 2.0 Flash.

    Args:
        instruction: Free-form instruction such as "Count A, not B".

    Returns:
        ``(pos_caption, neg_caption)``. Items are joined with " and " and
        terminated with a period. ``neg_caption`` is "None." when nothing is
        excluded. If the model call or the JSON parsing fails, or no positive
        item is returned, the instruction itself is used as the positive
        caption.
    """
    # Used whenever no usable positive items come back from the model.
    fallback_caption = (instruction or "").strip().rstrip(".") + "."

    try:
        prompt = PARSING_PROMPT.format(instruction=instruction)
        response = gemini_model.generate_content(prompt)
        response_text = response.text.strip()

        # Clean up response - remove markdown code blocks if present
        if response_text.startswith("```"):
            response_text = response_text.split("```")[1]
            if response_text.startswith("json"):
                response_text = response_text[4:]
            response_text = response_text.strip()

        result = json.loads(response_text)

        # Convert lists to caption strings
        pos_items = _clean_items(result.get("A", []))
        neg_items = _clean_items(result.get("B", []))
    except Exception as e:
        print(f"Error parsing instruction: {e}")
        return fallback_caption, "None."

    # Join items with " and " and add period
    pos_caption = " and ".join(pos_items) + "." if pos_items else fallback_caption
    neg_caption = " and ".join(neg_items) + "." if neg_items else "None."
    return pos_caption, neg_caption

import numpy as np

# Caption sent for the negative branch when the user excludes nothing.
_NO_NEGATIVE_CAPTION = "None."
# Fixed score threshold applied to the negative-branch detections.
_NEG_BOX_THRESHOLD = 0.5
# Suppression settings passed to discriminative_point_suppression.
_DPS_PIXEL_THRESHOLD = 5
_DPS_SIMILARITY_THRESHOLD = 0.3


def load_model():
    """Load the CountEX model and its processor into the module globals.

    The model is cast to bfloat16, moved to CUDA when available (CPU
    otherwise) and put in eval mode.

    Returns:
        ``(model, processor, device)``.
    """
    global model, processor, device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_id = "yifehuang97/CountEx_KC_aug_v3_12140136_v2"
    model = CountEX.from_pretrained(model_id, token=os.environ.get("HF_TOKEN"))
    model = model.to(torch.bfloat16)
    model = model.to(device)
    model.eval()

    processor_id = "fushh7/llmdet_swin_tiny_hf"
    processor = GroundingDinoProcessor.from_pretrained(processor_id)

    return model, processor, device


def discriminative_point_suppression(
    points,
    neg_points,
    pos_queries,
    neg_queries,
    image_size,
    pixel_threshold=5,
    similarity_threshold=0.3,
):
    """Discriminative Point Suppression (DPS).

    A positive point is dropped when its nearest negative point is closer
    than ``pixel_threshold`` pixels AND the cosine similarity between the
    positive query and that negative's query exceeds
    ``similarity_threshold``.

    Args:
        points: Positive points as ``[x, y]`` pairs; they are scaled by
            ``image_size`` before distances are measured.
        neg_points: Negative points in the same format.
        pos_queries: One query vector per positive point (array-like).
        neg_queries: One query vector per negative point (array-like).
        image_size: ``(width, height)`` used to scale points to pixels.
        pixel_threshold: Maximum pixel distance for "spatially close".
        similarity_threshold: Minimum cosine similarity for
            "semantically similar".

    Returns:
        ``(filtered_points, filtered_indices, suppression_info)``. When
        either point list is empty, the input points are returned unchanged
        with all indices and an empty info dict.
    """
    if not neg_points or not points:
        return points, list(range(len(points))), {}

    width, height = image_size
    pixel_scale = np.array([width, height])
    points_norm = np.array(points)
    points_px = points_norm * pixel_scale
    neg_points_px = np.array(neg_points) * pixel_scale

    # Pairwise distances, shape (num_points, num_neg_points).
    spatial_dist = np.linalg.norm(
        points_px[:, None, :] - neg_points_px[None, :, :], axis=-1
    )
    nearest_neg_idx = spatial_dist.argmin(axis=1)
    nearest_neg_dist = spatial_dist.min(axis=1)
    spatially_close = nearest_neg_dist < pixel_threshold

    # L2-normalize so the dot product below is a cosine similarity; the
    # epsilon guards against division by zero for all-zero queries.
    pos_q = np.asarray(pos_queries)
    neg_q = np.asarray(neg_queries)
    pos_q = pos_q / (np.linalg.norm(pos_q, axis=-1, keepdims=True) + 1e-8)
    neg_q = neg_q / (np.linalg.norm(neg_q, axis=-1, keepdims=True) + 1e-8)
    query_sim = (pos_q * neg_q[nearest_neg_idx]).sum(axis=-1)
    semantically_similar = query_sim > similarity_threshold

    should_suppress = spatially_close & semantically_similar
    keep_mask = ~should_suppress

    suppression_info = {
        "nearest_neg_idx": nearest_neg_idx.tolist(),
        "nearest_neg_dist": nearest_neg_dist.tolist(),
        "query_similarity": query_sim.tolist(),
        "spatially_close": spatially_close.tolist(),
        "semantically_similar": semantically_similar.tolist(),
        "suppressed_indices": np.where(should_suppress)[0].tolist(),
    }
    return (
        points_norm[keep_mask].tolist(),
        np.where(keep_mask)[0].tolist(),
        suppression_info,
    )


def _normalize_caption(caption, default=""):
    """Strip a caption and ensure a trailing period; blank or "." gives default."""
    text = (caption or "").strip()
    if not text or text == ".":
        return default
    return text if text.endswith(".") else text + "."


def _prepare_image(image):
    """Validate the upload, make sure the model is loaded, return the RGB image."""
    is_valid, error_msg = validate_image(image)
    if not is_valid:
        raise gr.Error(error_msg)

    if model is None:
        load_model()

    return preprocess_image(image)


def _detect_points(image, pos_caption, neg_caption, box_threshold):
    """Run the model on both captions and return the points kept after DPS."""
    pos_inputs = processor(
        images=image, text=pos_caption, return_tensors="pt", padding=True
    )
    pos_inputs = pos_inputs.to(device)
    pos_inputs["pixel_values"] = pos_inputs["pixel_values"].to(torch.bfloat16)

    neg_inputs = processor(
        images=image, text=neg_caption, return_tensors="pt", padding=True
    )
    neg_inputs = {key: value.to(device) for key, value in neg_inputs.items()}
    neg_inputs["pixel_values"] = neg_inputs["pixel_values"].to(torch.bfloat16)

    # The negative tensors ride along in the positive batch under a
    # "neg_" prefix.
    for key in ("token_type_ids", "attention_mask", "pixel_mask", "pixel_values", "input_ids"):
        pos_inputs[f"neg_{key}"] = neg_inputs[key]
    # NOTE(review): the flag is always True, even when the negative caption
    # is the "None." placeholder - confirm this is what the model expects.
    pos_inputs["use_neg"] = True

    # Run inference
    with torch.no_grad():
        outputs = model(**pos_inputs)

    # The first two box coordinates are used as the point location.
    outputs["pred_points"] = outputs["pred_boxes"][:, :, :2]
    outputs["pred_logits"] = outputs["logits"]

    threshold = box_threshold if box_threshold > 0 else model.box_threshold

    # Keep only the last entry along the leading axis and restore a batch
    # dimension. NOTE(review): presumably the last decoder layer - confirm.
    pos_queries = outputs["pos_queries"].squeeze(0).float()
    neg_queries = outputs["neg_queries"].squeeze(0).float()
    pos_queries = pos_queries[-1].squeeze(0).unsqueeze(0)
    neg_queries = neg_queries[-1].squeeze(0).unsqueeze(0)

    results = post_process_grounded_object_detection_with_queries(
        outputs, pos_queries, box_threshold=threshold
    )[0]
    boxes = [box.tolist() for box in results["boxes"]]
    points = [[box[0], box[1]] for box in boxes]

    neg_points = []
    neg_results = None
    if "neg_pred_boxes" in outputs and "neg_logits" in outputs:
        # Re-run the same post-processing on the negative branch by
        # aliasing its outputs to the keys the helper reads.
        neg_outputs = outputs.copy()
        neg_outputs["pred_boxes"] = outputs["neg_pred_boxes"]
        neg_outputs["logits"] = outputs["neg_logits"]
        neg_outputs["pred_points"] = outputs["neg_pred_boxes"][:, :, :2]
        neg_outputs["pred_logits"] = outputs["neg_logits"]

        neg_results = post_process_grounded_object_detection_with_queries(
            neg_outputs, neg_queries, box_threshold=_NEG_BOX_THRESHOLD
        )[0]
        neg_boxes = [box.tolist() for box in neg_results["boxes"]]
        neg_points = [[box[0], box[1]] for box in neg_boxes]

    pos_queries_np = results["queries"].cpu().numpy()
    if neg_results is not None:
        neg_queries_np = neg_results["queries"].cpu().numpy()
    else:
        neg_queries_np = np.array([])

    if len(neg_points) > 0 and len(neg_queries_np) > 0:
        points, _kept_indices, _info = discriminative_point_suppression(
            points,
            neg_points,
            pos_queries_np,
            neg_queries_np,
            image_size=image.size,
            pixel_threshold=_DPS_PIXEL_THRESHOLD,
            similarity_threshold=_DPS_SIMILARITY_THRESHOLD,
        )

    return points


def _draw_points(image, points, point_radius, point_color):
    """Return a copy of the image with one filled circle per point."""
    img_w, img_h = image.size
    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)
    for point in points:
        # Points are scaled by the image size to get pixel positions.
        x = point[0] * img_w
        y = point[1] * img_h
        draw.ellipse(
            [x - point_radius, y - point_radius, x + point_radius, y + point_radius],
            fill=point_color
        )
    return img_draw


def count_objects(image, instruction, box_threshold, point_radius, point_color):
    """Main inference function for counting objects (natural language mode).

    Args:
        image: Uploaded image (file path or PIL image).
        instruction: Natural language instruction, parsed with Gemini.
        box_threshold: Detection threshold; values <= 0 use the model default.
        point_radius: Radius in pixels of the drawn markers.
        point_color: Fill color of the drawn markers.

    Returns:
        ``(annotated_image, "Count: N", parsed_caption_summary)``.

    Raises:
        gr.Error: If the image is missing/unsupported or the instruction is
            empty.
    """
    image = _prepare_image(image)

    if not instruction or not instruction.strip():
        raise gr.Error("Error: Please enter a counting instruction.")

    # Parse instruction using Gemini
    pos_caption, neg_caption = parse_counting_instruction(instruction)
    # Never send an empty positive caption to the model.
    pos_caption = _normalize_caption(pos_caption) or _normalize_caption(instruction)
    neg_caption = _normalize_caption(neg_caption, default=_NO_NEGATIVE_CAPTION)
    parsed_info = f"Positive: {pos_caption}\nNegative: {neg_caption}"

    points = _detect_points(image, pos_caption, neg_caption, box_threshold)
    img_draw = _draw_points(image, points, point_radius, point_color)
    count = len(points)

    # Save uploaded data for collection
    save_uploaded_data(image, instruction, pos_caption, neg_caption, count, points)

    return img_draw, f"Count: {count}", parsed_info


def count_objects_manual(image, pos_caption, neg_caption, box_threshold, point_radius, point_color):
    """Manual mode: directly use provided positive and negative captions.

    Captions are stripped and given a trailing period; a blank negative
    caption becomes "None.".

    Args:
        image: Uploaded image (file path or PIL image).
        pos_caption: Objects to count.
        neg_caption: Objects to exclude; may be empty.
        box_threshold: Detection threshold; values <= 0 use the model default.
        point_radius: Radius in pixels of the drawn markers.
        point_color: Fill color of the drawn markers.

    Returns:
        ``(annotated_image, "Count: N", parsed_caption_summary)``.

    Raises:
        gr.Error: If the image is missing/unsupported or the positive prompt
            is empty.
    """
    image = _prepare_image(image)

    pos_caption = _normalize_caption(pos_caption)
    if not pos_caption:
        raise gr.Error("Error: Please enter a positive prompt.")
    neg_caption = _normalize_caption(neg_caption, default=_NO_NEGATIVE_CAPTION)
    parsed_info = f"Positive: {pos_caption}\nNegative: {neg_caption}"

    points = _detect_points(image, pos_caption, neg_caption, box_threshold)
    img_draw = _draw_points(image, points, point_radius, point_color)
    count = len(points)

    instruction = f"[MANUAL] pos: {pos_caption} | neg: {neg_caption}"
    save_uploaded_data(image, instruction, pos_caption, neg_caption, count, points)

    return img_draw, f"Count: {count}", parsed_info


def create_demo():
    """Build the Gradio Blocks UI and return it (without launching)."""
    with gr.Blocks(title="CountEx: Discriminative Visual Counting") as demo:
        gr.Markdown("""
        # CountEx: Fine-Grained Counting via Exemplars and Exclusion
        Count specific objects in images using text prompts with exclusion capability.
        """)

        # Tracks which input tab is active so one button can serve both modes.
        current_mode = gr.State(value="natural_language")

        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(type="filepath", label="Input Image (JPG, PNG only)")

                with gr.Tabs() as input_tabs:
                    with gr.TabItem("Natural Language", id=0) as tab_nl:
                        instruction = gr.Textbox(
                            label="Counting Instruction",
                            placeholder="e.g., Count apples, not green apples",
                            value="Count apples, not green apples",
                            lines=2
                        )
                        gr.Markdown("""
                        **Examples:**
                        - "Count apples, not green apples"
                        - "Count red and black beans, exclude white beans"
                        - "Count people, not women"
                        """)

                    with gr.TabItem("Manual Input", id=1) as tab_manual:
                        pos_caption = gr.Textbox(
                            label="Positive Prompt (objects to count)",
                            placeholder="e.g., apple",
                            value="apple."
                        )
                        neg_caption = gr.Textbox(
                            label="Negative Prompt (objects to exclude)",
                            placeholder="e.g., green apple",
                            value="None."
                        )

                submit_btn = gr.Button("Count Objects", variant="primary", size="lg")

                with gr.Accordion("Advanced Settings", open=False):
                    box_threshold = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.42,
                        step=0.01,
                        label="Threshold"
                    )
                    point_radius = gr.Slider(
                        minimum=1,
                        maximum=20,
                        value=5,
                        step=1,
                        label="Point Radius"
                    )
                    point_color = gr.Dropdown(
                        choices=["blue", "red", "green", "yellow", "cyan", "magenta", "white", "orange"],
                        value="blue",
                        label="Point Color"
                    )

            with gr.Column(scale=1):
                output_image = gr.Image(type="pil", label="Result")
                count_output = gr.Textbox(label="Count Result")
                parsed_output = gr.Textbox(label="Parsed Captions", lines=2)

        gr.Markdown("### Examples (Natural Language)")
        gr.Examples(
            examples=[
                ["examples/apples.png", "Count apples, not green apples"],
                ["examples/apples.png", "Count apples, exclude red apples"],
                ["examples/apple.jpg", "Count green apples"],
                ["examples/apple.jpg", "Count apples, exclude green apples"],
                ["examples/apple.jpg", "Count apples, exclude red apples"],
                ["examples/blue_straw_peach.png", "Count blueberries"],
                ["examples/blue_straw_peach.png", "Count leaf"],
                ["examples/blue_straw_peach.png", "Count blueberries and cherry"],
                ["examples/blue_straw_peach.png", "Count blueberries and cherry and strawberry"],
                ["examples/black_beans.jpg", "Count black beans and soy beans"],
                ["examples/black_beans.jpg", "Count beans"],
                ["examples/black_beans.jpg", "Count pig"],
                ["examples/candy.jpg", "Count brown coffee candy, exclude black coffee candy"],
                ["examples/candy.jpg", "Count candy"],
                ["examples/candy.jpg", "Count brown coffee candy and black coffee candy"],
                ["examples/candy.jpg", "Count sausage"],
                ["examples/strawberry.jpg", "Count blueberries and strawberry"],
                ["examples/strawberry.jpg", "Count book"],
                ["examples/strawberry2.jpg", "Count blueberries, exclude strawberry"],
                ["examples/women.jpg", "Count people, not women"],
                ["examples/women.jpg", "Count people, not man"],
                ["examples/boat-1.jpg", "Count boats, exclude blue boats"],
                ["examples/boat-1.jpg", "Count boats, exclude red boats"],
            ],
            inputs=[input_image, instruction],
            outputs=[output_image, count_output, parsed_output],
            fn=lambda img, instr: count_objects(img, instr, 0.42, 5, "blue"),
            cache_examples=False,
        )

        def set_mode_nl():
            return "natural_language"

        def set_mode_manual():
            return "manual"

        tab_nl.select(fn=set_mode_nl, outputs=[current_mode])
        tab_manual.select(fn=set_mode_manual, outputs=[current_mode])

        def handle_submit(mode, image, instr, pos_cap, neg_cap, threshold, radius, color):
            # Dispatch on the active tab recorded in the State component.
            if mode == "natural_language":
                return count_objects(image, instr, threshold, radius, color)
            return count_objects_manual(image, pos_cap, neg_cap, threshold, radius, color)

        submit_btn.click(
            fn=handle_submit,
            inputs=[current_mode, input_image, instruction, pos_caption, neg_caption,
                    box_threshold, point_radius, point_color],
            outputs=[output_image, count_output, parsed_output]
        )

    return demo


if __name__ == "__main__":
    # Initialize HuggingFace API
    print("Initializing HuggingFace API...")
    init_hf_api()

    # Load model at startup
    print("Loading model...")
    load_model()
    print("Model loaded!")

    # Create and launch demo
    demo = create_demo()
    demo.launch()