import os
import time
import threading
import torch
import base64
import io
import requests
import uuid
import numpy as np
from typing import List, Dict, Any, Optional, Union
from fastapi import FastAPI, HTTPException, Depends, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel, model_validator
from dotenv import load_dotenv
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM
from collections import deque
from PIL import Image
from tensorflow.keras.models import load_model
from urllib.request import urlretrieve
import uvicorn
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

load_dotenv()

os.makedirs("templates", exist_ok=True)
os.makedirs("static", exist_ok=True)
os.makedirs("image_model", exist_ok=True)

app = FastAPI(
    title="Multimodal AI Content Moderation API",
    description="An advanced, multilingual, and multimodal content moderation API.",
    version="1.0.0"
)

# Lightweight in-process metrics: response times of the last 100 requests
# plus a count of requests currently in flight.
request_times = deque(maxlen=100)
concurrent_requests = 0
request_lock = threading.Lock()


@app.middleware("http")
async def track_metrics(request: Request, call_next):
    global concurrent_requests
    with request_lock:
        concurrent_requests += 1
    start_time = time.time()
    try:
        response = await call_next(request)
    finally:
        # Decrement in a finally block so the counter does not leak when a
        # handler raises.
        process_time = time.time() - start_time
        request_times.append(process_time)
        with request_lock:
            concurrent_requests -= 1
    return response


app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")


def download_file(url, path):
    if not os.path.exists(path):
        logger.info(f"Downloading {os.path.basename(path)}...")
        urlretrieve(url, path)


logger.info("Downloading and loading models...")
MODELS = {}

logger.info("Loading text moderation model: detoxify-multilingual")
from detoxify import Detoxify  # imported lazily, right before the weights are needed
MODELS['detoxify-multilingual'] = Detoxify('multilingual', device=device)
logger.info("Detoxify model loaded.")

GEMMA_REPO = "daniel-dona/gemma-3-270m-it"
LOCAL_GEMMA_DIR = os.path.join(os.getcwd(), "gemma_model")
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")


def ensure_local_model(repo_id: str, local_dir: str) -> str:
    os.makedirs(local_dir, exist_ok=True)
    snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
        local_dir_use_symlinks=False,
        resume_download=True,
    )
    return local_dir


logger.info("Loading text moderation model: gemma-3-270m-it")
gemma_path = ensure_local_model(GEMMA_REPO, LOCAL_GEMMA_DIR)
gemma_tokenizer = AutoTokenizer.from_pretrained(gemma_path, local_files_only=True)
gemma_model = AutoModelForCausalLM.from_pretrained(
    gemma_path,
    local_files_only=True,
    torch_dtype=torch.float32,
    device_map=device
)
gemma_model.eval()
MODELS['gemma-3-270m-it'] = (gemma_model, gemma_tokenizer)
logger.info("Gemma model loaded.")

NSFW_MODEL_URL = "https://teachablemachine.withgoogle.com/models/gJOADmf_u/keras_model.h5"
NSFW_LABELS_URL = "https://teachablemachine.withgoogle.com/models/gJOADmf_u/labels.txt"
NSFW_MODEL_PATH = "image_model/keras_model.h5"
NSFW_LABELS_PATH = "image_model/labels.txt"

download_file(NSFW_MODEL_URL, NSFW_MODEL_PATH)
download_file(NSFW_LABELS_URL, NSFW_LABELS_PATH)

logger.info("Loading image moderation model: nsfw-image-classifier")
nsfw_model = load_model(NSFW_MODEL_PATH, compile=False)
# Teachable Machine labels.txt lines look like "0 safe" / "1 nsfw";
# keep only the class name after the index.
with open(NSFW_LABELS_PATH, "r") as f:
    nsfw_labels = [line.strip().split(' ')[1] for line in f]
MODELS['nsfw-image-classifier'] = (nsfw_model, nsfw_labels)
logger.info("NSFW image model loaded.")
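
# The Gemma model is registered in MODELS above but never wired into an
# endpoint. A minimal sketch of how it could be prompted as a second
# text-moderation opinion; the function name, prompt wording, and decoding
# settings here are assumptions, not part of the original design:
def classify_text_gemma(text: str) -> str:
    model, tokenizer = MODELS['gemma-3-270m-it']
    messages = [{
        "role": "user",
        "content": f"Answer with exactly one word, SAFE or UNSAFE:\n\n{text}"
    }]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    with torch.no_grad():
        output = model.generate(input_ids, max_new_tokens=4, do_sample=False)
    # Decode only the newly generated tokens, not the prompt.
    return tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()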

class InputItem(BaseModel):
    text: Optional[str] = None
    image_url: Optional[str] = None
    image_base64: Optional[str] = None

    @model_validator(mode="after")
    def check_one_field(self):
        # A per-field validator cannot see fields validated after it, so the
        # exactly-one-modality check must run on the whole model.
        provided = [v for v in (self.text, self.image_url, self.image_base64) if v is not None]
        if len(provided) != 1:
            raise ValueError("Exactly one of text, image_url, or image_base64 must be provided.")
        return self


class ModerationRequest(BaseModel):
    # The 10-item batch limit is enforced in the endpoint after a plain string
    # has been wrapped in a list; a Field(max_length=10) constraint here would
    # also cap plain string inputs at 10 characters.
    input: Union[str, List[Union[str, InputItem]]]
    model: str = "auto"


class ModerationResponse(BaseModel):
    id: str
    model: str
    results: List[Dict[str, Any]]


def format_openai_result(flagged: bool, categories: Dict[str, bool], scores: Dict[str, float]):
    return {
        "flagged": flagged,
        "categories": categories,
        "category_scores": scores
    }


def classify_text_detoxify(text: str):
    predictions = MODELS['detoxify-multilingual'].predict(text)
    # Map Detoxify's labels onto OpenAI-style moderation categories using
    # heuristic thresholds.
    categories = {
        "hate": predictions['identity_attack'] > 0.5 or predictions['toxicity'] > 0.7,
        "hate/threatening": predictions['threat'] > 0.5,
        "harassment": predictions['insult'] > 0.5,
        "harassment/threatening": predictions['threat'] > 0.5,
        "self-harm": predictions['severe_toxicity'] > 0.6,
        "sexual": predictions['sexual_explicit'] > 0.5,
        "sexual/minors": False,
        "violence": predictions['toxicity'] > 0.8,
        "violence/graphic": predictions['severe_toxicity'] > 0.8,
    }
    scores = {
        "hate": float(max(predictions.get('identity_attack', 0), predictions.get('toxicity', 0))),
        "hate/threatening": float(predictions.get('threat', 0)),
        "harassment": float(predictions.get('insult', 0)),
        "harassment/threatening": float(predictions.get('threat', 0)),
        "self-harm": float(predictions.get('severe_toxicity', 0)),
        "sexual": float(predictions.get('sexual_explicit', 0)),
        "sexual/minors": 0.0,
        "violence": float(predictions.get('toxicity', 0)),
        "violence/graphic": float(predictions.get('severe_toxicity', 0)),
    }
    flagged = any(categories.values())
    return format_openai_result(flagged, categories, scores)


def process_image(image_data: bytes) -> np.ndarray:
    # Resize to the 224x224 input the Teachable Machine model expects and
    # normalize pixels to [-1, 1].
    image = Image.open(io.BytesIO(image_data)).convert("RGB")
    image = image.resize((224, 224))
    image_array = np.asarray(image)
    normalized_image_array = (image_array.astype(np.float32) / 127.5) - 1
    return np.expand_dims(normalized_image_array, axis=0)


def classify_image(image_data: bytes):
    model, labels = MODELS['nsfw-image-classifier']
    processed_image = process_image(image_data)
    prediction = model.predict(processed_image, verbose=0)
    scores = {label: float(score) for label, score in zip(labels, prediction[0])}
    is_nsfw = scores.get('nsfw', 0.0) > 0.7
    categories = {
        "hate": False,
        "hate/threatening": False,
        "harassment": False,
        "harassment/threatening": False,
        "self-harm": False,
        "sexual": is_nsfw,
        "sexual/minors": is_nsfw,
        "violence": False,
        "violence/graphic": is_nsfw,
    }
    category_scores = {
        "hate": 0.0,
        "hate/threatening": 0.0,
        "harassment": 0.0,
        "harassment/threatening": 0.0,
        "self-harm": 0.0,
        "sexual": scores.get('nsfw', 0.0),
        "sexual/minors": scores.get('nsfw', 0.0),
        "violence": 0.0,
        "violence/graphic": scores.get('nsfw', 0.0),
    }
    return format_openai_result(is_nsfw, categories, category_scores)


def get_api_key(request: Request):
    api_key = request.headers.get("Authorization")
    if not api_key or not api_key.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="API key is missing or invalid.")
    api_key = api_key.split(" ")[1]
    env_api_key = os.getenv("API_KEY")
    if not env_api_key or api_key != env_api_key:
        raise HTTPException(status_code=401, detail="Invalid API key.")
    return api_key
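
# Authentication compares the bearer token against the API_KEY environment
# variable loaded by load_dotenv(). A minimal .env for local testing might
# look like this (the key value is an illustrative placeholder):
#
#   API_KEY=change-me-local-dev-key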

@app.get("/", response_class=HTMLResponse)
async def get_home(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})


@app.get("/v1/metrics", response_class=JSONResponse)
async def get_metrics(api_key: str = Depends(get_api_key)):
    avg_time = sum(request_times) / len(request_times) if request_times else 0
    return {
        "concurrent_requests": concurrent_requests,
        "average_response_time_ms_last_100": avg_time * 1000,
        "tracked_request_count": len(request_times)
    }


@app.post("/v1/moderations", response_model=ModerationResponse)
async def moderate_content(
    request: ModerationRequest,
    api_key: str = Depends(get_api_key)
):
    inputs = request.input
    if isinstance(inputs, str):
        inputs = [inputs]
    if len(inputs) > 10:
        raise HTTPException(status_code=400, detail="Maximum of 10 items per request is allowed.")
    results = []
    for item in inputs:
        result = None
        if isinstance(item, str):
            result = classify_text_detoxify(item)
        elif isinstance(item, InputItem):
            if item.text:
                result = classify_text_detoxify(item.text)
            elif item.image_url:
                try:
                    response = requests.get(item.image_url, stream=True, timeout=10)
                    response.raise_for_status()
                    image_bytes = response.content
                    result = classify_image(image_bytes)
                except requests.RequestException as e:
                    raise HTTPException(status_code=400, detail=f"Could not fetch image from URL: {e}")
            elif item.image_base64:
                try:
                    image_bytes = base64.b64decode(item.image_base64)
                    result = classify_image(image_bytes)
                except Exception as e:
                    raise HTTPException(status_code=400, detail=f"Invalid base64 image data: {e}")
        if result:
            results.append(result)
        else:
            raise HTTPException(status_code=400, detail="Invalid input item format provided.")
    model_name = request.model if request.model != "auto" else "multimodal-moderator"
    response_data = {
        "id": f"modr-{uuid.uuid4().hex}",
        "model": model_name,
        "results": results,
    }
    return response_data
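
# Each entry in "results" mirrors the OpenAI moderation result shape built by
# format_openai_result(); the values below are illustrative, not real output:
#
#   {
#     "flagged": true,
#     "categories":      {"hate": false, "harassment": true, ...},
#     "category_scores": {"hate": 0.02,  "harassment": 0.91, ...}
#   }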

# Write a minimal landing page for the home route. The original template is
# truncated in the source; everything around the surviving tagline below is a
# placeholder reconstruction.
with open("templates/index.html", "w") as f:
    f.write("""<!DOCTYPE html>
<html>
<head><title>Multimodal AI Content Moderation API</title></head>
<body>
<h1>Multimodal AI Content Moderation API</h1>
<p>Advanced, multilingual, and multimodal content analysis for text and images.</p>
</body>
</html>""")

if __name__ == "__main__":
    # Entrypoint assumed from the otherwise unused uvicorn import; the
    # host and port are guesses, not taken from the source.
    uvicorn.run(app, host="0.0.0.0", port=8000)
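
# Example call once the server is running (assumes port 8000 and
# API_KEY=change-me-local-dev-key in .env):
#
#   curl -s http://localhost:8000/v1/moderations \
#     -H "Authorization: Bearer change-me-local-dev-key" \
#     -H "Content-Type: application/json" \
#     -d '{"input": ["you are awful", {"image_url": "https://example.com/img.jpg"}]}'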