Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import onnxruntime as ort | |
| import numpy as np | |
| from PIL import Image | |
| import json | |
| from huggingface_hub import hf_hub_download | |
| import torchvision.transforms as transforms | |
# --- Hugging Face Hub coordinates and UI defaults ---
MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
MODEL_FILE = "camie_tagger_initial.onnx"
META_FILE = "metadata.json"
# Starting position of the threshold slider in the UI.
DEFAULT_THRESHOLD = 0.32626262307167053

# Fetch the ONNX weights first, then the tag metadata, caching both under the
# current working directory.
model_path, meta_path = [
    hf_hub_download(repo_id=MODEL_REPO, filename=hub_file, cache_dir=".")
    for hub_file in (MODEL_FILE, META_FILE)
]

# One CPU-only ONNX Runtime session, created at import time and reused by
# every inference call.
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

# Tag metadata; get_tags() reads its "idx_to_tag", "tag_to_category" and
# "category_thresholds" entries.
with open(meta_path, "r", encoding="utf-8") as f:
    metadata = json.loads(f.read())
def escape_tag(tag: str) -> str:
    r"""Rewrite a raw tag for display.

    Underscores are turned into spaces, and every ``(`` or ``)`` gets two
    literal backslash characters placed in front of it (``(`` -> ``\\(``).
    """
    # A single translate() pass is equivalent to the chained replacements
    # because none of the substituted text contains "_" and each character is
    # visited exactly once.
    substitutions = str.maketrans({"_": " ", "(": r"\\(", ")": r"\\)"})
    return tag.translate(substitutions)
def preprocess_image(pil_image: Image.Image, image_size: int = 512) -> np.ndarray:
    """
    Letterbox an image onto a square black canvas and convert it to an array.

    The image is converted to RGB, scaled so that its longer side equals
    ``image_size`` (aspect ratio preserved, LANCZOS resampling), centred on an
    ``image_size`` x ``image_size`` black canvas, and passed through
    torchvision's ``ToTensor`` with no normalization step.

    Args:
        pil_image: Source image in any PIL mode.
        image_size: Side length, in pixels, of the square canvas. Defaults to
            512, the value that was previously hard-coded.

    Returns:
        The NumPy form of the tensor ``ToTensor`` produces for the padded
        canvas (channels-first; the caller adds the batch axis).
    """
    # Normalise every non-RGB mode up front (RGBA, P, L, LA, CMYK, 1, ...), not
    # just RGBA and P, so resizing and pasting always work on three channels.
    img = pil_image if pil_image.mode == "RGB" else pil_image.convert("RGB")

    width, height = img.size
    aspect_ratio = width / height

    # Fit the longer side to image_size. The shorter side is clamped to at
    # least one pixel: for extreme aspect ratios int() would truncate it to 0,
    # and Image.resize() rejects empty sizes.
    if aspect_ratio > 1:
        new_width = image_size
        new_height = max(1, int(new_width / aspect_ratio))
    else:
        new_height = image_size
        new_width = max(1, int(new_height * aspect_ratio))

    img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

    # Centre the resized image on a black square; the borders act as padding.
    canvas = Image.new("RGB", (image_size, image_size), (0, 0, 0))
    offset = ((image_size - new_width) // 2, (image_size - new_height) // 2)
    canvas.paste(img, offset)

    # ToTensor only (no normalization), then hand back a NumPy array.
    return transforms.ToTensor()(canvas).numpy()
def run_inference(pil_image: Image.Image) -> np.ndarray:
    """
    Feed a single image through the shared ONNX Runtime session.

    The image is preprocessed, given a leading batch axis, and bound to the
    model's first declared input. The model is expected to return exactly two
    outputs: the first (initial logits) is discarded, and the first batch item
    of the second (refined logits) is returned.
    """
    # (C, H, W) -> (1, C, H, W)
    batch = preprocess_image(pil_image)[np.newaxis, ...]
    feed = {session.get_inputs()[0].name: batch}
    _initial_logits, refined_logits = session.run(None, feed)
    return refined_logits[0]
def get_tags(refined_logits: np.ndarray, metadata: dict, default_threshold: float):
    """
    Turn logits into probabilities and bucket the tags that pass their threshold.

    A tag's cut-off is its category's entry in ``metadata["category_thresholds"]``
    when one exists; ``default_threshold`` applies only to categories without
    an entry. Tags missing from ``metadata["tag_to_category"]`` are filed under
    the category "unknown".

    Args:
        refined_logits: One-dimensional array (or plain sequence) of per-tag
            logits; position ``i`` is named by ``metadata["idx_to_tag"][str(i)]``.
        metadata: Mapping with a required "idx_to_tag" entry and optional
            "tag_to_category" and "category_thresholds" entries.
        default_threshold: Cut-off for categories that have no per-category
            threshold in the metadata.

    Returns:
        results_by_cat: Dictionary mapping each category to a list of
            (tag, probability) pairs at or above its threshold.
        prompt_tags_by_cat: The same pairs restricted to the "character" and
            "general" categories (both keys are always present).
        all_artist_tags: Every "artist" tag with its probability, regardless
            of threshold.

    Raises:
        KeyError: If "idx_to_tag" is missing or has no entry for an index.
    """
    # np.exp overflows to inf for strongly negative logits; the quotient still
    # evaluates to the correct limit (0.0), so only the warning is suppressed.
    with np.errstate(over="ignore"):
        probs = 1 / (1 + np.exp(-np.asarray(refined_logits)))

    idx_to_tag = metadata["idx_to_tag"]
    tag_to_category = metadata.get("tag_to_category", {})
    category_thresholds = metadata.get("category_thresholds", {})

    results_by_cat = {}
    # For prompt style, only character and general tags are collected here
    # (artists are handled separately by the formatter).
    prompt_tags_by_cat = {"character": [], "general": []}
    all_artist_tags = []

    # tolist() converts every probability to a Python float exactly once.
    for idx, prob in enumerate(probs.tolist()):
        tag = idx_to_tag[str(idx)]
        cat = tag_to_category.get(tag, "unknown")
        entry = (tag, prob)
        if cat == "artist":
            all_artist_tags.append(entry)
        if prob >= category_thresholds.get(cat, default_threshold):
            results_by_cat.setdefault(cat, []).append(entry)
            if cat in prompt_tags_by_cat:
                prompt_tags_by_cat[cat].append(entry)
    return results_by_cat, prompt_tags_by_cat, all_artist_tags
def format_prompt_tags(prompt_tags_by_cat: dict, all_artist_tags: list) -> str:
    """
    Build the comma-separated, prompt-style tag string.

    The output starts with the single most probable artist tag (chosen from
    ``all_artist_tags`` regardless of threshold), followed by the "character"
    tags and then the "general" tags, each group ordered by descending
    probability. Every tag is passed through escape_tag(). The input
    containers are not modified.

    Args:
        prompt_tags_by_cat: Mapping of category name to a list of
            (tag, probability) pairs; only "character" and "general" are read.
        all_artist_tags: List of (tag, probability) pairs for artist tags.

    Returns:
        The joined tag string, or "No tags predicted." when there is nothing
        to show.
    """
    ordered_tags = []

    # Best artist first, whether or not it cleared its threshold.
    if all_artist_tags:
        best_artist = max(all_artist_tags, key=lambda item: item[1])
        best_artist_tag = escape_tag(best_artist[0])
        if best_artist_tag:
            ordered_tags.append(best_artist_tag)

    # sorted() (rather than list.sort()) leaves the caller's lists untouched.
    for cat in ("character", "general"):
        ranked = sorted(
            prompt_tags_by_cat.get(cat, []),
            key=lambda item: item[1],
            reverse=True,
        )
        ordered_tags.extend(escape_tag(tag) for tag, _ in ranked)

    return ", ".join(ordered_tags) if ordered_tags else "No tags predicted."
def format_detailed_output(results_by_cat: dict, all_artist_tags: list) -> str:
    """
    Render the thresholded tags as a Markdown listing grouped by category.

    When no artist tag cleared its threshold but artist candidates exist, the
    most probable one is listed under an "artist" category anyway. Within each
    category the tags are ordered by descending probability. The input
    containers are not modified.

    Args:
        results_by_cat: Mapping of category name to a list of
            (tag, probability) pairs that passed their threshold.
        all_artist_tags: List of (tag, probability) pairs for every artist
            tag, regardless of threshold.

    Returns:
        The Markdown text, or "No tags predicted for this image." when
        ``results_by_cat`` is empty (this check happens before the artist
        fallback is considered).
    """
    if not results_by_cat:
        return "No tags predicted for this image."

    # Shallow copy: the fallback artist entry must not leak into the caller's dict.
    categories = dict(results_by_cat)
    if "artist" not in categories and all_artist_tags:
        best_artist_tag, best_artist_prob = max(all_artist_tags, key=lambda item: item[1])
        categories["artist"] = [(best_artist_tag, best_artist_prob)]

    lines = ["**Predicted Tags by Category:** \n"]
    for cat, tag_list in categories.items():
        lines.append(f"**Category: {cat}** – {len(tag_list)} tags")
        # sorted() builds a new list, so the caller's ordering is preserved.
        ranked = sorted(tag_list, key=lambda item: item[1], reverse=True)
        lines.extend(f"- {escape_tag(tag)} (Prob: {prob:.3f})" for tag, prob in ranked)
        lines.append("")  # blank line between categories
    return "\n".join(lines)
def tag_image(pil_image: Image.Image, output_format: str, threshold: float) -> str:
    """
    Gradio callback: tag one image and return the result as display text.

    With no image supplied, a prompt to upload one is returned. Otherwise the
    image is run through the model and the tags are rendered as a
    comma-separated prompt when ``output_format`` is "Prompt-style Tags", or
    as the detailed per-category listing for any other value. ``threshold``
    (the slider value) is forwarded to get_tags() as its ``default_threshold``.
    """
    if pil_image is None:
        return "Please upload an image."

    logits = run_inference(pil_image)
    results_by_cat, prompt_tags_by_cat, all_artist_tags = get_tags(
        logits, metadata, default_threshold=threshold
    )

    wants_prompt = output_format == "Prompt-style Tags"
    if not wants_prompt:
        return format_detailed_output(results_by_cat, all_artist_tags)
    return format_prompt_tags(prompt_tags_by_cat, all_artist_tags)
# Gradio Blocks UI: input controls in the left column, rendered tags on the right.
with gr.Blocks(theme="gradio/soft") as demo:
    # Page header and usage note.
    gr.Markdown(
        "# 🏷️ Camie Tagger – Anime Image Tagging\n"
        "This demo uses an ONNX model of Camie Tagger to label anime illustrations with tags. "
        "Upload an image, adjust the threshold, and click **Tag Image** to see predictions."
    )
    gr.Markdown(
        "*(Note: In prompt-style output, only the top artist tag is displayed along with all character and general tags.)*"
    )

    with gr.Row():
        with gr.Column():
            image_in = gr.Image(type="pil", label="Input Image")
            format_choice = gr.Radio(
                label="Output Format",
                choices=["Prompt-style Tags", "Detailed Output"],
                value="Prompt-style Tags",
            )
            # The slider value is handed to tag_image() as its threshold argument.
            threshold_slider = gr.Slider(
                label="Threshold",
                value=DEFAULT_THRESHOLD,
                minimum=0.0,
                maximum=1.0,
                step=0.05,
            )
            tag_button = gr.Button("🔍 Tag Image")
        with gr.Column():
            # Markdown component that shows the formatted results.
            output_box = gr.Markdown("")

    # Clicking the button sends image, output format and threshold to tag_image().
    tag_button.click(
        fn=tag_image,
        inputs=[image_in, format_choice, threshold_slider],
        outputs=output_box,
    )

    # Footer with model credits.
    gr.Markdown(
        "----\n"
        "**Model:** [Camie Tagger ONNX](https://huggingface.co/AngelBottomless/camie-tagger-onnxruntime) • "
        "**Base Model:** Camais03/camie-tagger (61% F1 on 70k tags) • "
        "**ONNX Runtime:** for efficient CPU inference • "
        "*Demo built with Gradio Blocks.*"
    )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()