Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Image Tagging Server using PyTorch (timm) and FastAPI. | |
| This script sets up a web server that provides endpoints for tagging images | |
| using a pre-trained timm model. It supports single image processing, batch | |
| processing, and can download model artifacts from a Hugging Face repository. | |
| """ | |
| import argparse | |
| import logging | |
| import math | |
| import os | |
| import pathlib | |
| import time | |
| import types | |
| import typing | |
| from contextlib import asynccontextmanager | |
| from io import BytesIO | |
| from pathlib import Path | |
| from typing import Any, Dict, List | |
| import huggingface_hub | |
| import numpy as np | |
| import pandas as pd | |
| import timm | |
| import torch | |
| import uvicorn | |
| from fastapi import FastAPI, File, HTTPException, UploadFile | |
| from fastapi.responses import RedirectResponse | |
| from PIL import Image | |
| from pydantic import BaseModel, Field | |
| from pydantic_settings import BaseSettings | |
| from timm.data import create_transform, resolve_data_config | |
| from torch import nn | |
| from torch.nn import functional as F | |
| # --- Configuration Management --- | |
class Settings(BaseSettings):
    """Application configuration, loaded by pydantic-settings.

    Every field may also be supplied through an environment variable that
    carries the ``TAGGER_`` prefix (see ``Config.env_prefix``).
    """

    host: str = Field(default="0.0.0.0", description="Server host.")
    port: int = Field(default=8080, description="Server port.")
    # Passed to Tagger, which uses it as the number of data-parallel devices.
    instances: int = Field(
        default=1, description="Number of devices used for data-parallel inference."
    )
    triton: int = Field(default=0, description="Enable triton / torch.compile()")
    log_level: str = Field(default="INFO", description="Logging level.")
    # Optional: the default is None, so the annotation must admit None or an
    # explicitly passed None (as the CLI does) fails validation.
    model_repo: str | None = Field(
        default=None, description="HuggingFace repository for model files."
    )
    model_file: str = Field(
        default="model.safetensors", description="Model weights filename."
    )
    tags_file: str = Field(
        default="selected_tags.csv", description="CSV file with tag names."
    )
    thresholds_file: str = Field(
        default="thresholds.csv", description="CSV file with category thresholds."
    )
    backend: str = Field(
        default="cpu",
        description="Inference backend ('cpu', 'cuda', 'tensorrt').",
        pattern="^(cpu|cuda|tensorrt)$",
    )
    token: str | None = Field(default=None, description="HuggingFace Token.")

    class Config:
        env_prefix = "TAGGER_"
| # --- Logging Setup --- | |
class CustomFormatter(logging.Formatter):
    """Log formatter that exposes a colorized, padded level name.

    ``format`` attaches a ``levelprefix`` attribute to each record so that
    format strings can reference ``%(levelprefix)s``.
    """

    RESET = "\x1b[0m"
    LEVEL_COLORS = {
        logging.DEBUG: "\x1b[38;20m",  # Grey
        logging.INFO: "\x1b[32m",  # Green
        logging.WARNING: "\x1b[33;20m",  # Yellow
        logging.ERROR: "\x1b[31;20m",  # Red
        logging.CRITICAL: "\x1b[31;1m",  # Bold Red
    }

    def format(self, record: logging.LogRecord) -> str:
        """Set ``record.levelprefix`` and delegate to the base formatter."""
        # Levels without a configured color get no escape prefix, only RESET.
        ansi_start = self.LEVEL_COLORS.get(record.levelno, "")
        padded_level = f"{record.levelname:<8}"
        record.levelprefix = "".join((ansi_start, padded_level, self.RESET))
        return super().format(record)
def setup_logging(log_level: str):
    """Configure the root logger with a single colorized stream handler.

    Args:
        log_level: Level name (or value) handed to ``Logger.setLevel``.

    Returns:
        The root logger, whose handlers have been replaced by one
        ``StreamHandler`` using ``CustomFormatter``.
    """
    root = logging.getLogger()
    root.setLevel(log_level)

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(CustomFormatter("%(levelprefix)s | %(message)s"))
    root.handlers = [stream_handler]

    # Suppress verbose logs from other libraries
    for library_logger in ("uvicorn", "uvicorn.access"):
        logging.getLogger(library_logger).handlers = []
    return root
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return ``image`` in RGB mode, compositing any alpha onto white.

    Images that are neither RGB nor RGBA are first converted: to RGBA when
    their ``info`` carries a ``transparency`` entry, otherwise to RGB.
    """
    if image.mode not in ("RGB", "RGBA"):
        target_mode = "RGBA" if "transparency" in image.info else "RGB"
        image = image.convert(target_mode)

    if image.mode != "RGBA":
        return image

    # Flatten transparency against a white background before dropping alpha.
    background = Image.new("RGBA", image.size, (255, 255, 255))
    background.alpha_composite(image)
    return background.convert("RGB")
def pil_pad_square(image: Image.Image) -> Image.Image:
    """Center ``image`` on a white square whose side is its longer edge."""
    width, height = image.size
    side = max(width, height)
    # Integer division keeps the paste offset on a whole pixel.
    offset = ((side - width) // 2, (side - height) // 2)
    square = Image.new("RGB", (side, side), (255, 255, 255))
    square.paste(image, offset)
    return square
# Module-level logger, configured at DEBUG when the module is imported.
# main() rebinds it using the configured log level.
logger = setup_logging("DEBUG")
| # --- API Models (Pydantic) --- | |
class Timing(BaseModel):
    """Timing figures attached to a tagging response."""

    # Elapsed seconds for the request as measured by the endpoint.
    total_seconds: float
    # Elapsed seconds spent in the prediction call alone.
    processing_seconds: float
# Per-image result: category name -> list of tag entries. The entries built in
# this file are {"name": ..., "confidence": ...} dicts.
TAG_RESPONSE = dict[str, list[dict[str, Any]]]


class TaggingResponse(BaseModel):
    """Response schema carrying the tags for one image plus timing."""

    # NOTE(review): no endpoint visible in this file returns this model - confirm
    # whether it is still in use.
    tags: TAG_RESPONSE
    timing: Timing
class BatchTaggingResponse(BaseModel):
    """Response schema for a batch tagging request."""

    # Number of per-image results in ``results``.
    batch_size: int
    # One category -> tag-list mapping per image.
    results: list[TAG_RESPONSE]
    timing: Timing
class StatusResponse(BaseModel):
    """Response schema for the status endpoint."""

    status: str
    # Model repository of the loaded tagger; the annotation allows None.
    model_name: str | None
class TaggerArgs(BaseModel):
    """Optional arguments accepted by the batch tagging endpoint."""

    # When true, per-tag thresholds (if the labels file provides them) are meant
    # to replace the per-category threshold during filtering.
    tags_threshold: bool = False
| # --- Core Logic: Tags & Tagger Classes --- | |
| class Tags: | |
| """Handles loading and processing of tag data and thresholds.""" | |
| DEFAULT_CATEGORIES = { | |
| 0: {"name": "general", "threshold": 0.35}, | |
| 4: {"name": "character", "threshold": 0.85}, | |
| 9: {"name": "rating", "threshold": 0.0}, | |
| } | |
| def __init__(self, labels_path: Path, threshold_path: Path | None = None): | |
| logger.info(f"Loading labels from '{labels_path}'...") | |
| start_time = time.time() | |
| tags_df = pd.read_csv(labels_path) | |
| self.tag_names = tags_df["name"].tolist() | |
| self.tag_names_ndarray = np.array(self.tag_names) | |
| self.categories: Dict[int, Dict[str, Any]] = {} | |
| if "best_threshold" in tags_df: | |
| self.tag_thresholds = np.array(tags_df["best_threshold"].tolist()) | |
| else: | |
| self.tag_thresholds = None | |
| if ( | |
| threshold_path | |
| and threshold_path.is_file() | |
| and threshold_path.stat().st_size > 0 | |
| ): | |
| logger.info(f"Loading thresholds from '{threshold_path}'.") | |
| for item in pd.read_csv(threshold_path).to_dict("records"): | |
| if item["category"] not in self.categories: | |
| self.categories[item["category"]] = { | |
| "name": item["name"], | |
| "threshold": item["threshold"], | |
| } | |
| else: | |
| logger.info("No valid threshold file found. Using default categories.") | |
| self.categories = self.DEFAULT_CATEGORIES | |
| for cat_id, cat_info in self.categories.items(): | |
| cat_info["indices"] = list(np.where(tags_df["category"] == cat_id)[0]) | |
| logger.info( | |
| f"Loaded {len(self.tag_names)} tags and {len(self.categories)} categories in {time.time() - start_time:.2f}s." | |
| ) | |
| def process_predictions( | |
| self, | |
| preds: np.ndarray, | |
| tag_indices: List[int], | |
| threshold: float, | |
| tags_threshold: bool = False, | |
| ) -> List[List[dict[str, Any]]]: | |
| """Filters and sorts predictions based on a threshold.""" | |
| tag_names = self.tag_names_ndarray | |
| # preds = np.asarray(preds) | |
| tag_scores = preds[:, tag_indices] | |
| tag_names_sel = tag_names[tag_indices] | |
| if tags_threshold and self.tag_thresholds is not None: | |
| mask = tag_scores > self.tag_thresholds[tag_indices] | |
| tag_scores = np.where(mask, tag_scores, -np.inf) | |
| else: | |
| if threshold is not None: | |
| mask = tag_scores > threshold | |
| tag_scores = np.where(mask, tag_scores, -np.inf) | |
| sorted_idx = np.argsort(-tag_scores, axis=1) | |
| sorted_names = tag_names_sel[sorted_idx] | |
| sorted_scores = np.take_along_axis(tag_scores, sorted_idx, axis=1) | |
| return [ | |
| [ | |
| {"name": name, "confidence": float(score)} | |
| for name, score in zip(names, scores) | |
| if not math.isinf(float(score)) | |
| ] | |
| for names, scores in zip(sorted_names, sorted_scores) | |
| ] | |
| def resolve_batch_probs( | |
| self, probs: np.ndarray, tags_threshold: bool = False | |
| ) -> list[dict[str, list[dict[str, Any]]]]: | |
| """Resolves raw probabilities into categorized tag predictions.""" | |
| logger.info(f"Shapery: {probs.shape[0]}") | |
| results_batched: dict[str, Any] = { | |
| cat_info["name"]: [] for cat_info in self.categories.values() | |
| } | |
| for cat_info in self.categories.values(): | |
| for _, result in enumerate( | |
| self.process_predictions( | |
| probs, | |
| cat_info["indices"], | |
| cat_info["threshold"], | |
| tags_threshold=tags_threshold, | |
| ) | |
| ): | |
| # {k: [dic[k] for dic in LD] for k in LD[0]} | |
| results_batched[cat_info["name"]].append(result) | |
| results_list = [ | |
| dict(zip(results_batched, t)) for t in zip(*results_batched.values()) | |
| ] | |
| return results_list | |
class Tagger:
    """Manages the timm (PyTorch) model, image preprocessing, and inference."""

    def __init__(
        self,
        model_repo: str,
        tags: Tags,
        backend: str = "cpu",
        instances: int = 1,
        triton: bool = False,
    ):
        """Load the model from the Hugging Face Hub and prepare it for inference.

        Args:
            model_repo: Hub repository id of the timm model.
            tags: Label/threshold data used to resolve probabilities.
            backend: Requested backend; only ``"cuda"`` (when available) selects
                the GPU, anything else runs on the CPU.
            instances: Number of CUDA devices to spread batches over.
            triton: Compile the model with ``nn.Module.compile`` when true.

        Raises:
            ValueError: If ``model_repo`` is empty.
        """
        if not model_repo:
            raise ValueError("A model repository is required to load the model.")
        self.tags_data = tags
        self.model_repo = model_repo

        use_cuda = backend == "cuda" and torch.cuda.is_available()
        if backend != "cpu" and not use_cuda:
            logger.warning(f"Backend '{backend}' is not available. Falling back to CPU.")
        self.device = torch.device("cuda" if use_cuda else "cpu")

        logger.info(f"Loading model from HuggingFace repo: {model_repo}...")
        self.model: nn.Module = timm.create_model(
            "hf-hub:" + model_repo, pretrained=False
        )
        self.swap_colorspace = False
        if model_repo.startswith("animetimm/"):
            logger.warning("Detected animetimm model. Enabling color swap.")
            self.swap_colorspace = True
        state_dict = timm.models.load_state_dict_from_hf(model_repo)
        self.model.load_state_dict(state_dict)
        self.model = self.model.eval().to(self.device)
        if triton:
            self.model.compile(
                fullgraph=True,
            )
        # Must be resolved before the model is wrapped: the wrapper does not
        # expose ``pretrained_cfg``.
        self.transform = create_transform(
            **resolve_data_config(self.model.pretrained_cfg, model=self.model)
        )
        if use_cuda:
            # DataParallel requires the module to live on device_ids[0], so it
            # is only applied for the CUDA backend, with ids that really exist.
            available = torch.cuda.device_count()
            device_count = max(1, min(instances, available))
            if device_count != instances:
                logger.warning(
                    f"Requested {instances} devices but using {device_count}."
                )
            self.model = nn.DataParallel(
                self.model, device_ids=list(range(device_count))
            )
        logger.info("Model loaded and ready.")

    def _create_model(
        self, model_repo: str, backend: str, index: int
    ) -> torch.nn.Module:
        """Create a standalone copy of the model in eval mode.

        The model is moved to CUDA device ``index`` (as float32) only when
        ``backend`` is ``"cuda"``.
        """
        model: torch.nn.Module = timm.create_model(
            "hf-hub:" + model_repo, pretrained=False
        )
        state_dict = timm.models.load_state_dict_from_hf(model_repo)
        model.load_state_dict(state_dict)
        model = model.eval()
        if backend == "cuda":
            model = model.to(torch.device(backend, index), dtype=torch.float32)
        return model

    def preprocess_batch(self, image_batch: np.ndarray) -> torch.Tensor:
        """Convert a batch of NHWC images into a model-ready tensor.

        Args:
            image_batch: 4-D array of images. Non-``uint8`` input is treated as
                values in [0, 1] and scaled to 0-255; ``uint8`` input is used
                as-is.

        Returns:
            The stacked, transformed images on ``self.device``.

        Raises:
            ValueError: If the batch is not 4-D or contains no images.
        """
        image_batch = np.asarray(image_batch)
        if image_batch.ndim != 4 or image_batch.shape[0] == 0:
            raise ValueError(
                f"Expected a non-empty batch of images in NHWC format, got shape {image_batch.shape}."
            )
        if image_batch.dtype != np.uint8:
            # Clip first so out-of-range values saturate instead of wrapping
            # around in the uint8 cast.
            image_batch = (np.clip(image_batch, 0.0, 1.0) * 255).astype(np.uint8)

        pil_images = [Image.fromarray(img) for img in image_batch]
        images = [pil_pad_square(pil_ensure_rgb(im)) for im in pil_images]
        tensors = [self.transform(im) for im in images]
        batch = torch.stack(tensors, dim=0)
        if self.swap_colorspace:
            logger.debug(f"Swapping color channels for batch of shape {tuple(batch.shape)}")
            # Reverse the channel axis (dim 1).
            batch = batch[:, [2, 1, 0], :, :]
        return batch.to(self.device)

    def predict_batch(
        self, image_batch: np.ndarray, tags_threshold=False
    ) -> List[dict[str, list[dict[str, Any]]]]:
        """Run inference on a batch and return categorized tags per image."""
        batch_tensor = self.preprocess_batch(image_batch)
        on_cuda = self.device.type == "cuda"
        with (
            torch.inference_mode(),
            # bfloat16 autocast applies on CUDA only; the CPU path stays in
            # full precision and no longer requests an unavailable device type.
            torch.autocast(
                device_type=self.device.type, dtype=torch.bfloat16, enabled=on_cuda
            ),
        ):
            logits = self.model(batch_tensor)
        probs = torch.sigmoid(logits).cpu().to(torch.float32).numpy()
        return self.tags_data.resolve_batch_probs(probs, tags_threshold=tags_threshold)
| # --- FastAPI Application Setup --- | |
class AppState:
    """Holder for objects shared across the application: settings and tagger."""

    def __init__(self, settings: Settings):
        # No tagger yet; it is assigned after construction, once one is built.
        self.tagger: Tagger | None = None
        self.settings = settings
def download_file(repo: str, filename: str, output_path: Path):
    """Fetch ``filename`` from a Hugging Face repo unless it already exists.

    Args:
        repo: Repository id passed to ``hf_hub_download``.
        filename: Name of the file inside the repository.
        output_path: Where the file must end up; nothing happens if it exists.

    Raises:
        FileNotFoundError: If the download or the final rename fails for any
            reason (the original exception is chained).
    """
    if output_path.exists():
        return

    logger.info(f"Downloading '{filename}' from repo '{repo}'...")
    try:
        downloaded = huggingface_hub.hf_hub_download(
            repo,
            filename,
            local_dir=output_path.parent,
            local_dir_use_symlinks=False,
        )
        # Ensure the downloaded file is at the expected path
        if Path(downloaded) != output_path:
            os.rename(downloaded, output_path)
    except Exception as e:
        message = f"Failed to download '{filename}' from '{repo}': {e}"
        raise FileNotFoundError(message) from e
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize the Tagger on startup and release it on shutdown.

    Label and threshold files are taken from the model repository directory
    when ``settings.model_repo`` is a local directory, otherwise they are
    downloaded into ``models/<model_repo>``.

    Raises:
        RuntimeError: If the tags file cannot be obtained, no model repository
            is configured, or the tagger fails to initialize. Startup is
            aborted in all of these cases.
    """
    settings: Settings = app.state.settings
    repo = settings.model_repo

    model_dir = Path("models")
    model_dir.mkdir(exist_ok=True)
    repo_is_local_dir = bool(repo) and pathlib.Path(repo).is_dir()
    if repo_is_local_dir:
        model_dir = pathlib.Path(repo)
    elif repo:
        model_dir = model_dir / pathlib.Path(repo)
    logger.info(f"Using directory: {model_dir} for storage...")

    tags_path = model_dir / settings.tags_file
    thresholds_path = model_dir / settings.thresholds_file

    if repo and not repo_is_local_dir:
        try:
            download_file(repo, settings.tags_file, tags_path)
        except FileNotFoundError as e:
            logger.critical(f"Could not start server: {e}")
            # Abort startup with an explicit error instead of returning before
            # the yield, which only surfaces as "generator didn't yield".
            raise RuntimeError(f"Could not start server: {e}") from e
        # Thresholds file is optional, so don't fail if it's not there
        try:
            download_file(repo, settings.thresholds_file, thresholds_path)
        except FileNotFoundError:
            logger.warning(
                f"Optional thresholds file '{settings.thresholds_file}' not found in repo."
            )

    if not tags_path.is_file():
        logger.critical(
            "Model or tags file not found, and no model repository was specified. Exiting."
        )
        raise RuntimeError(f"Tags file '{tags_path}' not found.")
    if not repo:
        # The model itself is always loaded from the repository id.
        logger.critical("No model repository was specified. Exiting.")
        raise RuntimeError("No model repository was specified.")

    try:
        logger.info("Initializing tagger...")
        tags = Tags(labels_path=tags_path, threshold_path=thresholds_path)
        app.state.tagger = Tagger(
            repo,
            tags,
            settings.backend,
            instances=settings.instances,
            triton=bool(settings.triton),
        )
        logger.info("Tagger initialized successfully. Server is ready.")
    except (ValueError, RuntimeError) as e:
        logger.critical(f"Failed to initialize tagger: {e}")
        raise RuntimeError(f"Failed to initialize tagger: {e}") from e

    try:
        yield
    finally:
        # --- Cleanup --- (runs even if the application exits with an error)
        app.state.tagger = None
        logger.info("Server shutting down.")
def create_app(settings: Settings) -> FastAPI:
    """Build the FastAPI application and attach its state container.

    The returned app uses ``lifespan`` for startup/shutdown and has its
    ``state`` replaced by an ``AppState`` wrapping ``settings``.
    """
    options = {
        "title": "Image Tagger API",
        "description": "An API for tagging images using an ONNX model.",
        "version": "1.0.1",  # Incremented version
        "lifespan": lifespan,
        "docs_url": "/docs",
    }
    application = FastAPI(**options)
    application.state = AppState(settings)
    return application
| # --- Dependency for Endpoints --- | |
def get_tagger(app: FastAPI) -> Tagger:
    """Return the application's tagger, or fail with HTTP 503 if it is unset.

    Raises:
        HTTPException: Status 503 when ``app.state.tagger`` is falsy.
    """
    tagger = app.state.tagger
    if tagger:
        return tagger
    raise HTTPException(
        status_code=503,
        detail="Tagger is not initialized. The server may be starting up or has encountered an error.",
    )
| # --- API Endpoints --- | |
def add_endpoints(app: FastAPI):
    """Register the root, batch-tagging and status routes on ``app``.

    Routes:
        GET /        Redirects to the docs (or shows a short welcome page).
        POST /batch  Tags every image in an uploaded ``.npz`` archive.
        GET /status  Reports that the tagger is loaded and which model it uses.
    """
    # Local imports keep this function self-contained; neither name is
    # imported at module level.
    import zipfile

    from fastapi.responses import HTMLResponse

    def tagger_dependency() -> Tagger:
        """Return the initialized tagger or raise HTTP 503."""
        return get_tagger(app)

    # Root welcome/docs page
    @app.get("/", include_in_schema=False)
    async def root():
        if app.docs_url:
            return RedirectResponse(url=app.docs_url)
        elif app.redoc_url:
            return RedirectResponse(url=app.redoc_url)
        return HTMLResponse(
            content="<h1>Welcome to the Tagger API</h1><p>Use /batch to tag images.</p>",
            status_code=200,
        )

    # Tagging endpoint at /batch
    @app.post("/batch", response_model=BatchTaggingResponse)
    async def tag_batch(
        tags_threshold: TaggerArgs = TaggerArgs(),
        file: UploadFile = File(
            ..., description="A .npz file containing a batch of images in NHWC format."
        ),
    ):
        if not file.filename or not file.filename.endswith(".npz"):
            raise HTTPException(
                status_code=400,
                detail="Only .npz files are supported for batch processing.",
            )
        start_time = time.time()
        tagger = tagger_dependency()
        logger.info(f"Processing batch file: {file.filename}")
        contents = await file.read()
        try:
            with np.load(BytesIO(contents)) as npz:
                if not npz.files:
                    raise ValueError("the archive contains no arrays")
                # Only the first array in the archive is used.
                batch = npz[npz.files[0]]
        except (ValueError, OSError, EOFError, zipfile.BadZipFile) as e:
            # A malformed upload is a client error, not a server failure.
            raise HTTPException(
                status_code=400, detail=f"Could not read .npz file: {e}"
            ) from e
        logger.info(f"Loaded batch of shape: {batch.shape}")
        process_start = time.time()
        try:
            # Pass the boolean flag itself: the TaggerArgs object is always
            # truthy and would silently force per-tag thresholds.
            results = tagger.predict_batch(
                batch, tags_threshold=tags_threshold.tags_threshold
            )
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e)) from e
        processing_time = time.time() - process_start
        logger.info(f"Processed batch in {processing_time:.2f}s")
        return BatchTaggingResponse(
            batch_size=len(results),
            results=results,
            timing=Timing(
                total_seconds=time.time() - start_time,
                processing_seconds=processing_time,
            ),
        )

    # Status endpoint
    @app.get("/status", response_model=StatusResponse)
    async def status():
        tagger = tagger_dependency()
        return StatusResponse(
            status="ok",
            model_name=tagger.model_repo,
        )
| def determine_type(field_type: type): | |
| if type(field_type) is types.UnionType: | |
| return typing.get_args(field_type)[0] | |
| return field_type | |
| # --- Main Execution --- | |
def main():
    """Parse arguments, build the settings and the app, and run the server.

    Only options actually given on the command line are passed to
    ``Settings``; everything else is resolved by ``Settings`` itself, so
    ``TAGGER_*`` environment variables and field defaults still apply.
    """
    global logger

    parser = argparse.ArgumentParser(description="Image Tagging Server")
    # Add arguments that correspond to the Settings fields
    for field_name, field in Settings.model_fields.items():
        # argparse %-formats help strings, so literal percent signs are escaped.
        help_text = f"{field.description} (default: {field.default!r})".replace(
            "%", "%%"
        )
        parser.add_argument(
            f"--{field_name.replace('_', '-')}",
            type=determine_type(field.annotation),  # Basic type handling for argparse
            # SUPPRESS leaves unspecified options out of the namespace; passing
            # the defaults explicitly would override environment variables.
            default=argparse.SUPPRESS,
            help=help_text,
        )
    args = parser.parse_args()

    # Create settings from a combination of args, env vars, and defaults
    settings = Settings(**vars(args))
    logger = setup_logging(settings.log_level.upper())

    if settings.token:
        logger.info("Using custom token...")
        os.environ["HF_TOKEN"] = settings.token

    app = create_app(settings)
    add_endpoints(app)
    uvicorn.run(
        app,
        host=settings.host,
        port=settings.port,
        log_config=None,  # Use our custom logger
    )


if __name__ == "__main__":
    main()