Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| import threading | |
| import torch | |
| import base64 | |
| import io | |
| import requests | |
| import uuid | |
| import numpy as np | |
| from typing import List, Dict, Any, Optional, Union | |
| from fastapi import FastAPI, HTTPException, Depends, Request | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| from pydantic import BaseModel, Field, field_validator | |
| from dotenv import load_dotenv | |
| from huggingface_hub import snapshot_download | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from collections import deque | |
| from PIL import Image | |
| from tensorflow.keras.models import load_model | |
| from urllib.request import urlretrieve | |
| import uvicorn | |
| import logging | |
# --- Logging and environment ------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
load_dotenv()

# Working directories used further down in this file: Jinja templates, static
# assets and the downloaded image model. Created up front so later code can
# assume they exist.
for _required_dir in ("templates", "static", "image_model"):
    os.makedirs(_required_dir, exist_ok=True)

# --- Application object -----------------------------------------------------
app = FastAPI(
    title="Multimodal AI Content Moderation API",
    description="An advanced, multilingual, and multimodal content moderation API.",
    version="1.0.0",
)

# --- Request metrics ----------------------------------------------------------
# `request_lock` guards updates to `concurrent_requests`; `request_times`
# keeps only the 100 most recent request durations (in seconds).
request_lock = threading.Lock()
concurrent_requests = 0
request_times = deque(maxlen=100)
@app.middleware("http")
async def track_metrics(request: Request, call_next):
    """Count in-flight requests and record how long each successful one takes.

    Registered as HTTP middleware so that it wraps every request; a bare
    function definition would never be invoked by the application.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request to the rest of the app.

    Returns:
        The response produced by ``call_next``.
    """
    global concurrent_requests
    with request_lock:
        concurrent_requests += 1
    # perf_counter is monotonic, so the measured duration cannot go negative
    # or jump when the system clock is adjusted.
    started = time.perf_counter()
    try:
        response = await call_next(request)
        request_times.append(time.perf_counter() - started)
        return response
    finally:
        # Always release the slot, even when the downstream handler raises;
        # otherwise the concurrent counter would drift upwards forever.
        with request_lock:
            concurrent_requests -= 1
# HTML pages are rendered from ./templates; files in ./static are served
# under the /static URL prefix.
templates = Jinja2Templates(directory="templates")
app.mount("/static", StaticFiles(directory="static"), name="static")

# Use the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
logger.info(f"Using device: {device}")
def download_file(url, path):
    """Download *url* to *path* unless *path* already exists.

    The payload is first written to ``<path>.part`` and only renamed to
    *path* once the transfer has finished. An interrupted download therefore
    never leaves a truncated file behind that a later run would mistake for
    a complete one.

    Args:
        url: Location to fetch.
        path: Destination file path.

    Raises:
        Whatever ``urlretrieve`` or ``os.replace`` raise; the partial file is
        removed before the exception propagates.
    """
    if os.path.exists(path):
        return

    logger.info(f"Downloading {os.path.basename(path)}...")
    partial_path = f"{path}.part"
    try:
        urlretrieve(url, partial_path)
        os.replace(partial_path, path)
    finally:
        # After a successful rename the partial file no longer exists, so
        # this only fires when the download or the rename failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)
logger.info("Downloading and loading models...")

# Registry of loaded models, keyed by the model's public name.
MODELS: Dict[str, Any] = {}

# --- Text model: Detoxify (multilingual checkpoint) --------------------------
logger.info("Loading text moderation model: detoxify-multilingual")
from detoxify import Detoxify

MODELS["detoxify-multilingual"] = Detoxify("multilingual", device=device)
logger.info("Detoxify model loaded.")

# --- Gemma: repository id and local snapshot directory -----------------------
GEMMA_REPO = "daniel-dona/gemma-3-270m-it"
LOCAL_GEMMA_DIR = os.path.join(os.getcwd(), "gemma_model")

# Only sets the variable when the environment has not already defined it.
# NOTE(review): huggingface_hub is imported earlier in this file; confirm the
# library still honours this variable when it is set after import.
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
def ensure_local_model(repo_id: str, local_dir: str) -> str:
    """Fetch a snapshot of *repo_id* into *local_dir* and return *local_dir*.

    Args:
        repo_id: Hub repository identifier passed to ``snapshot_download``.
        local_dir: Directory to create (if missing) and download into.

    Returns:
        The same ``local_dir`` that was passed in.
    """
    os.makedirs(local_dir, exist_ok=True)

    # NOTE(review): `local_dir_use_symlinks` and `resume_download` may be
    # deprecated in newer huggingface_hub releases - confirm against the
    # installed version before removing them.
    download_options = {
        "repo_id": repo_id,
        "local_dir": local_dir,
        "local_dir_use_symlinks": False,
        "resume_download": True,
    }
    snapshot_download(**download_options)
    return local_dir
# --- Text model: Gemma -------------------------------------------------------
logger.info("Loading text moderation model: gemma-3-270m-it")

# Download (or reuse) the snapshot first, then load strictly from disk.
gemma_path = ensure_local_model(GEMMA_REPO, LOCAL_GEMMA_DIR)
gemma_tokenizer = AutoTokenizer.from_pretrained(gemma_path, local_files_only=True)
gemma_model = AutoModelForCausalLM.from_pretrained(
    gemma_path,
    local_files_only=True,
    torch_dtype=torch.float32,
    device_map=device,
)
gemma_model.eval()  # inference only

# Stored as a (model, tokenizer) pair.
MODELS["gemma-3-270m-it"] = (gemma_model, gemma_tokenizer)
logger.info("Gemma model loaded.")
# --- Image model: NSFW classifier --------------------------------------------
NSFW_MODEL_URL = "https://teachablemachine.withgoogle.com/models/gJOADmf_u/keras_model.h5"
NSFW_LABELS_URL = "https://teachablemachine.withgoogle.com/models/gJOADmf_u/labels.txt"
NSFW_MODEL_PATH = "image_model/keras_model.h5"
NSFW_LABELS_PATH = "image_model/labels.txt"

download_file(NSFW_MODEL_URL, NSFW_MODEL_PATH)
download_file(NSFW_LABELS_URL, NSFW_LABELS_PATH)

logger.info("Loading image moderation model: nsfw-image-classifier")
nsfw_model = load_model(NSFW_MODEL_PATH, compile=False)

# Each non-empty line is treated as "<index> <label>". Splitting on the first
# space only keeps multi-word labels intact, blank lines (for example a
# trailing newline) are skipped instead of raising IndexError, and a line
# without an index is used as the label verbatim.
with open(NSFW_LABELS_PATH, "r", encoding="utf-8") as labels_file:
    nsfw_labels = [
        entry.partition(" ")[2].strip() or entry
        for entry in (raw_line.strip() for raw_line in labels_file)
        if entry
    ]

# Stored as a (model, labels) pair; labels are in file order.
MODELS['nsfw-image-classifier'] = (nsfw_model, nsfw_labels)
logger.info("NSFW image model loaded.")
class InputItem(BaseModel):
    """A single item to moderate: text, an image URL, or base64 image data.

    At most one of the three fields may be set.
    """

    text: Optional[str] = None
    image_url: Optional[str] = None
    image_base64: Optional[str] = None

    @field_validator("text", "image_url", "image_base64")
    @classmethod
    def check_one_field(cls, v, info):
        """Reject items that populate more than one content field.

        ``info.data`` only holds the fields validated *before* the current
        one, so the value under validation has to be counted separately.
        Counting ``info.data`` alone would let e.g. ``text`` together with
        ``image_base64`` slip through.

        Raises:
            ValueError: If this field is set and an earlier field is set too.
        """
        earlier_fields_set = sum(
            1 for value in info.data.values() if value is not None
        )
        if v is not None and earlier_fields_set >= 1:
            raise ValueError("Only one of text, image_url, or image_base64 can be provided.")
        return v
class ModerationRequest(BaseModel):
    """Request body: one string, or a list of up to 10 strings / items."""

    # The 10-item cap is enforced by the validator below rather than by
    # Field(max_length=10): a length constraint on the whole union also
    # applies to the plain-string form and would reject any text longer
    # than 10 characters.
    input: Union[str, List[Union[str, InputItem]]]
    model: str = "auto"

    @field_validator("input")
    @classmethod
    def limit_batch_size(cls, v):
        """Allow at most 10 entries when ``input`` is given as a list.

        Raises:
            ValueError: If a list with more than 10 entries is supplied.
        """
        if isinstance(v, list) and len(v) > 10:
            raise ValueError("Maximum of 10 items per request is allowed.")
        return v
class ModerationResponse(BaseModel):
    """Response envelope: an id, the model name, and one result per item."""

    # Identifier of this moderation call.
    id: str
    # Name of the model reported back to the caller.
    model: str
    # One result dictionary per input item.
    results: List[Dict[str, Any]]
def format_openai_result(flagged: bool, categories: Dict[str, bool], scores: Dict[str, float]):
    """Bundle one moderation verdict into a result dictionary.

    Args:
        flagged: Overall verdict for the item.
        categories: Per-category boolean flags.
        scores: Per-category scores.

    Returns:
        A dict with the keys ``flagged``, ``categories`` and
        ``category_scores`` (in that order); the two mappings are stored
        as passed in, without copying.
    """
    return dict(flagged=flagged, categories=categories, category_scores=scores)
def classify_text_detoxify(text: str):
    """Score *text* with the Detoxify model and map it onto the result schema.

    Every Detoxify score is converted to a plain Python ``float`` once, and
    both the boolean flags and the reported scores are derived from those
    floats. The flags are therefore built-in ``bool`` values and not
    NumPy booleans, which the JSON encoder cannot serialise. A missing
    score is treated as 0.

    Args:
        text: The text to classify.

    Returns:
        The dictionary produced by ``format_openai_result``.
    """
    predictions = MODELS['detoxify-multilingual'].predict(text)

    def score(label: str) -> float:
        return float(predictions.get(label, 0.0))

    toxicity = score('toxicity')
    severe_toxicity = score('severe_toxicity')
    identity_attack = score('identity_attack')
    insult = score('insult')
    threat = score('threat')
    sexual_explicit = score('sexual_explicit')

    categories = {
        "hate": identity_attack > 0.5 or toxicity > 0.7,
        "hate/threatening": threat > 0.5,
        "harassment": insult > 0.5,
        "harassment/threatening": threat > 0.5,
        "self-harm": severe_toxicity > 0.6,
        "sexual": sexual_explicit > 0.5,
        # The text model provides no signal for this category.
        "sexual/minors": False,
        "violence": toxicity > 0.8,
        "violence/graphic": severe_toxicity > 0.8,
    }
    scores = {
        "hate": max(identity_attack, toxicity),
        "hate/threatening": threat,
        "harassment": insult,
        "harassment/threatening": threat,
        "self-harm": severe_toxicity,
        "sexual": sexual_explicit,
        "sexual/minors": 0.0,
        "violence": toxicity,
        "violence/graphic": severe_toxicity,
    }
    flagged = any(categories.values())
    return format_openai_result(flagged, categories, scores)
def process_image(image_data: bytes) -> np.ndarray:
    """Decode raw image bytes into a normalised single-image batch.

    The image is converted to RGB, resized to 224x224, cast to float32 and
    rescaled from [0, 255] to [-1, 1]; a leading batch axis is then added.

    Args:
        image_data: Encoded image bytes readable by PIL.

    Returns:
        A float32 array with a leading axis of length 1.
    """
    rgb_image = Image.open(io.BytesIO(image_data)).convert("RGB")
    resized = rgb_image.resize((224, 224))
    pixels = np.asarray(resized).astype(np.float32)
    scaled = (pixels / 127.5) - 1
    return scaled[np.newaxis, ...]
def classify_image(image_data: bytes):
    """Classify an image with the NSFW model and map it onto the result schema.

    The model only yields a generic "nsfw" score, so that score is reported
    under ``sexual`` alone. It says nothing about the age of the people
    depicted or about violence, so ``sexual/minors`` and
    ``violence/graphic`` are left unflagged instead of mirroring the nsfw
    verdict.

    Args:
        image_data: Encoded image bytes.

    Returns:
        The dictionary produced by ``format_openai_result``.
    """
    model, labels = MODELS['nsfw-image-classifier']
    batch = process_image(image_data)
    prediction = model.predict(batch, verbose=0)

    # Labels come from a downloaded text file; compare case-insensitively so
    # a differently-cased "NSFW" label does not silently disable flagging.
    scores = {
        label.lower(): float(score)
        for label, score in zip(labels, prediction[0])
    }
    nsfw_score = scores.get('nsfw', 0.0)
    is_nsfw = nsfw_score > 0.7

    category_names = (
        "hate", "hate/threatening", "harassment", "harassment/threatening",
        "self-harm", "sexual", "sexual/minors", "violence", "violence/graphic",
    )
    categories = {name: False for name in category_names}
    category_scores = {name: 0.0 for name in category_names}
    categories["sexual"] = is_nsfw
    category_scores["sexual"] = nsfw_score

    return format_openai_result(is_nsfw, categories, category_scores)
def get_api_key(request: Request):
    """Validate the ``Authorization: Bearer <key>`` header against ``API_KEY``.

    Args:
        request: The incoming request whose headers are inspected.

    Returns:
        The API key supplied by the caller.

    Raises:
        HTTPException: 401 if the header is missing or malformed, if the
            ``API_KEY`` environment variable is unset or empty, or if the
            supplied key does not match it.
    """
    import hmac

    scheme = "Bearer "
    header = request.headers.get("Authorization")
    if not header or not header.startswith(scheme):
        raise HTTPException(status_code=401, detail="API key is missing or invalid.")

    # Everything after the scheme is the key; splitting on spaces would
    # silently drop part of a token that itself contains a space.
    api_key = header[len(scheme):]
    env_api_key = os.getenv("API_KEY")

    # compare_digest takes time independent of where the inputs differ, so
    # the comparison does not leak how much of a guessed key was correct.
    # Bytes are compared because the str form only accepts ASCII.
    if not env_api_key or not hmac.compare_digest(
        api_key.encode("utf-8"), env_api_key.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid API key.")
    return api_key
# NOTE(review): the handler had no route decorator, so the page was never
# reachable. "/" is assumed as the path for the landing page - confirm.
@app.get("/", response_class=HTMLResponse)
async def get_home(request: Request):
    """Render the ``index.html`` template for the browser front-end."""
    return templates.TemplateResponse("index.html", {"request": request})
# Registered under the path that the bundled page polls (`/v1/metrics`);
# without a route decorator the handler was unreachable.
@app.get("/v1/metrics")
async def get_metrics(api_key: str = Depends(get_api_key)):
    """Report the in-flight request count and recent average latency.

    Args:
        api_key: Supplied by the ``get_api_key`` dependency, which rejects
            unauthenticated callers.

    Returns:
        A dict with the current concurrent request count, the mean of the
        recorded request durations in milliseconds (0 when nothing has been
        recorded yet), and the number of durations recorded.
    """
    # Snapshot once so the sum and the length refer to the same samples.
    samples = list(request_times)
    avg_time = sum(samples) / len(samples) if samples else 0
    return {
        "concurrent_requests": concurrent_requests,
        "average_response_time_ms_last_100": avg_time * 1000,
        "tracked_request_count": len(samples),
    }
# Upper bound on the size of an image fetched from a caller-supplied URL, so
# that a single request cannot make the server buffer an arbitrarily large body.
_MAX_REMOTE_IMAGE_BYTES = 10 * 1024 * 1024


def _fetch_image_bytes(url: str) -> bytes:
    """Download *url* and return its body, capped at _MAX_REMOTE_IMAGE_BYTES."""
    # SECURITY NOTE(review): the server fetches a caller-supplied URL, which
    # is an SSRF surface (internal hosts, metadata endpoints). Consider an
    # allow-list or blocking private address ranges.
    try:
        with requests.get(url, stream=True, timeout=10) as response:
            response.raise_for_status()
            body = bytearray()
            for chunk in response.iter_content(chunk_size=64 * 1024):
                body.extend(chunk)
                if len(body) > _MAX_REMOTE_IMAGE_BYTES:
                    raise HTTPException(
                        status_code=400,
                        detail="Image at URL exceeds the maximum allowed size.",
                    )
            return bytes(body)
    except requests.RequestException as e:
        raise HTTPException(
            status_code=400, detail=f"Could not fetch image from URL: {e}"
        ) from e


def _classify_item(item):
    """Classify one input entry; return None when it carries no usable content."""
    if isinstance(item, str):
        return classify_text_detoxify(item)
    if not isinstance(item, InputItem):
        return None

    if item.text:
        return classify_text_detoxify(item.text)

    if item.image_url:
        image_bytes = _fetch_image_bytes(item.image_url)
        try:
            return classify_image(image_bytes)
        except (OSError, ValueError) as e:
            # The body was fetched but is not a decodable image: report it
            # as a client error instead of letting it surface as a 500.
            raise HTTPException(
                status_code=400, detail=f"Could not decode image from URL: {e}"
            ) from e

    if item.image_base64:
        payload = item.image_base64
        # Accept a full data URL ("data:image/png;base64,....") as well as
        # the bare base64 payload.
        if payload.startswith("data:"):
            payload = payload.partition(",")[2]
        try:
            return classify_image(base64.b64decode(payload))
        except Exception as e:
            raise HTTPException(
                status_code=400, detail=f"Invalid base64 image data: {e}"
            ) from e

    return None


# Registered under the path that the bundled page posts to; without a route
# decorator the handler was unreachable.
# NOTE(review): model inference and the image download are blocking calls
# inside an async handler, so they occupy the event loop while they run.
@app.post("/v1/moderations")
async def moderate_content(
    request: ModerationRequest,
    api_key: str = Depends(get_api_key)
):
    """Moderate up to 10 text or image inputs and return one result per input.

    Args:
        request: Parsed request body; ``input`` is a string or a list of
            strings / ``InputItem`` objects.
        api_key: Supplied by the ``get_api_key`` dependency.

    Returns:
        A dict with a generated ``id``, the ``model`` name and the list of
        per-item ``results`` (same order as the inputs).

    Raises:
        HTTPException: 400 when more than 10 items are sent, when an item
            has no usable content, or when an image cannot be fetched or
            decoded.
    """
    inputs = request.input
    if isinstance(inputs, str):
        inputs = [inputs]
    if len(inputs) > 10:
        raise HTTPException(status_code=400, detail="Maximum of 10 items per request is allowed.")

    results = []
    for item in inputs:
        result = _classify_item(item)
        if result is None:
            raise HTTPException(status_code=400, detail="Invalid input item format provided.")
        results.append(result)

    model_name = request.model if request.model != "auto" else "multimodal-moderator"
    return {
        "id": f"modr-{uuid.uuid4().hex}",
        "model": model_name,
        "results": results,
    }
# Markup and script for the single-page test UI rendered by the home route.
# The page posts to /v1/moderations and polls /v1/metrics every 3 seconds,
# sending the API key typed into the form as a Bearer token.
_INDEX_HTML = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Multimodal AI Content Moderator</title>
<script src="https://cdn.tailwindcss.com"></script>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css">
<style>
.gradient-bg { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); }
.glass-effect {
background: rgba(255, 255, 255, 0.1);
backdrop-filter: blur(10px);
border-radius: 10px;
border: 1px solid rgba(255, 255, 255, 0.2);
}
</style>
</head>
<body class="min-h-screen gradient-bg text-white font-sans">
<div class="container mx-auto px-4 py-8">
<header class="text-center mb-10">
<h1 class="text-4xl md:text-5xl font-bold mb-4">Multimodal AI Content Moderator</h1>
<p class="text-xl text-gray-200 max-w-3xl mx-auto">
Advanced, multilingual, and multimodal content analysis for text and images.
</p>
</header>
<main class="max-w-6xl mx-auto">
<div class="grid grid-cols-1 lg:grid-cols-5 gap-8">
<div class="lg:col-span-2">
<div class="glass-effect p-6 rounded-xl h-full flex flex-col">
<h2 class="text-2xl font-bold mb-4 flex items-center">
<i class="fas fa-cogs mr-3"></i>Configuration & Status
</h2>
<div class="mb-4">
<label class="block text-sm font-medium mb-2">API Key</label>
<input type="password" id="apiKey" placeholder="Enter your API key"
class="w-full px-4 py-3 rounded-lg bg-white/10 border border-white/20 focus:outline-none focus:ring-2 focus:ring-indigo-400 text-white">
</div>
<div class="mt-4 border-t border-white/20 pt-4">
<h3 class="text-lg font-semibold mb-3">Server Metrics</h3>
<div class="space-y-3 text-sm">
<div class="flex justify-between"><span>Concurrent Requests:</span> <span id="concurrentRequests" class="font-mono">0</span></div>
<div class="flex justify-between"><span>Avg. Response (last 100):</span> <span id="avgResponseTime" class="font-mono">0.00 ms</span></div>
</div>
</div>
<div class="mt-auto pt-4">
<h3 class="text-lg font-semibold mb-2">API Endpoint</h3>
<div class="bg-black/20 p-3 rounded-lg text-xs font-mono">
POST /v1/moderations
</div>
</div>
</div>
</div>
<div class="lg:col-span-3">
<div class="glass-effect p-6 rounded-xl">
<h2 class="text-2xl font-bold mb-4 flex items-center">
<i class="fas fa-vial mr-3"></i>Live Tester
</h2>
<div id="input-container" class="space-y-3 mb-4">
<div class="input-item">
<textarea name="text" rows="2" placeholder="Enter text to analyze..." class="w-full p-2 rounded bg-white/10 border border-white/20 focus:outline-none focus:ring-2 focus:ring-indigo-400"></textarea>
</div>
</div>
<div class="flex space-x-2 mb-6">
<button id="add-text" class="text-sm bg-white/10 hover:bg-white/20 py-1 px-3 rounded"><i class="fas fa-plus mr-1"></i> Text</button>
<button id="add-image-url" class="text-sm bg-white/10 hover:bg-white/20 py-1 px-3 rounded"><i class="fas fa-link mr-1"></i> Image URL</button>
<button id="add-image-file" class="text-sm bg-white/10 hover:bg-white/20 py-1 px-3 rounded"><i class="fas fa-upload mr-1"></i> Image File</button>
</div>
<input type="file" id="image-file-input" class="hidden" accept="image/*">
<button id="analyzeBtn" class="w-full bg-indigo-600 hover:bg-indigo-700 text-white font-bold py-3 px-6 rounded-lg transition duration-300">
<i class="fas fa-search mr-2"></i> Analyze Content
</button>
</div>
</div>
</div>
<div id="resultsSection" class="mt-8 hidden">
<h3 class="text-xl font-bold mb-4">Analysis Results</h3>
<div id="resultsContainer" class="space-y-4"></div>
</div>
</main>
</div>
<script>
const apiKeyInput = document.getElementById('apiKey');
const inputContainer = document.getElementById('input-container');
const analyzeBtn = document.getElementById('analyzeBtn');
const resultsSection = document.getElementById('resultsSection');
const resultsContainer = document.getElementById('resultsContainer');
const concurrentRequestsEl = document.getElementById('concurrentRequests');
const avgResponseTimeEl = document.getElementById('avgResponseTime');
const imageFileInput = document.getElementById('image-file-input');
document.getElementById('add-text').addEventListener('click', () => addInput('text'));
document.getElementById('add-image-url').addEventListener('click', () => addInput('image_url'));
document.getElementById('add-image-file').addEventListener('click', () => imageFileInput.click());
imageFileInput.addEventListener('change', (event) => {
if (event.target.files && event.target.files[0]) {
const file = event.target.files[0];
const reader = new FileReader();
reader.onload = (e) => {
addInput('image_base64', e.target.result);
};
reader.readAsDataURL(file);
}
});
function addInput(type, value = '') {
if (inputContainer.children.length >= 10) {
alert('Maximum of 10 items per request.');
return;
}
const itemDiv = document.createElement('div');
itemDiv.className = 'input-item relative';
let inputHtml = '';
if (type === 'text') {
inputHtml = `<textarea name="text" rows="2" placeholder="Enter text..." class="w-full p-2 rounded bg-white/10 border border-white/20 focus:outline-none focus:ring-2 focus:ring-indigo-400">${value}</textarea>`;
} else if (type === 'image_url') {
inputHtml = `<input type="text" name="image_url" placeholder="Enter image URL..." value="${value}" class="w-full p-2 rounded bg-white/10 border border-white/20 focus:outline-none focus:ring-2 focus:ring-indigo-400">`;
} else if (type === 'image_base64') {
inputHtml = `
<div class="flex items-center space-x-2 p-2 rounded bg-white/10 border border-white/20">
<img src="${value}" class="h-10 w-10 object-cover rounded">
<span class="text-sm truncate">Image File Uploaded</span>
<input type="hidden" name="image_base64" value="${value.split(',')[1]}">
</div>
`;
}
const removeBtn = `<button class="absolute -top-1 -right-1 text-red-400 hover:text-red-200 bg-gray-800 rounded-full h-5 w-5 flex items-center justify-center text-xs" onclick="this.parentElement.remove()"><i class="fas fa-times"></i></button>`;
itemDiv.innerHTML = inputHtml + removeBtn;
inputContainer.appendChild(itemDiv);
}
analyzeBtn.addEventListener('click', async () => {
const apiKey = apiKeyInput.value.trim();
if (!apiKey) {
alert('Please enter your API key.');
return;
}
const inputs = [];
document.querySelectorAll('.input-item').forEach(item => {
const text = item.querySelector('textarea[name="text"]');
const imageUrl = item.querySelector('input[name="image_url"]');
const imageBase64 = item.querySelector('input[name="image_base64"]');
if (text && text.value.trim()) inputs.push({ text: text.value.trim() });
if (imageUrl && imageUrl.value.trim()) inputs.push({ image_url: imageUrl.value.trim() });
if (imageBase64 && imageBase64.value) inputs.push({ image_base64: imageBase64.value });
});
if (inputs.length === 0) {
alert('Please add at least one item to analyze.');
return;
}
analyzeBtn.disabled = true;
analyzeBtn.innerHTML = '<i class="fas fa-spinner fa-spin mr-2"></i> Analyzing...';
try {
const response = await fetch('/v1/moderations', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Authorization': `Bearer ${apiKey}`
},
body: JSON.stringify({ input: inputs })
});
const data = await response.json();
if (!response.ok) {
const detail = typeof data.detail === 'string' ? data.detail : JSON.stringify(data.detail);
throw new Error(detail || 'An error occurred.');
}
displayResults(data.results);
} catch (error) {
alert(`Error: ${error.message}`);
resultsSection.classList.add('hidden');
} finally {
analyzeBtn.disabled = false;
analyzeBtn.innerHTML = '<i class="fas fa-search mr-2"></i> Analyze Content';
}
});
function displayResults(results) {
resultsContainer.innerHTML = '';
results.forEach((result, index) => {
const flagged = result.flagged;
const card = document.createElement('div');
card.className = `glass-effect p-4 rounded-lg border-l-4 ${flagged ? 'border-red-400' : 'border-green-400'}`;
let flaggedCategories = Object.entries(result.categories)
.filter(([_, value]) => value === true)
.map(([key]) => key)
.join(', ');
let scoresHtml = Object.entries(result.category_scores).map(([key, score]) => `
<div class="flex justify-between text-xs my-1">
<span>${key.replace(/_/g, ' ')}</span>
<span class="font-mono">${(score * 100).toFixed(2)}%</span>
</div>
<div class="w-full bg-white/10 rounded-full h-1.5">
<div class="h-1.5 rounded-full ${score > 0.5 ? 'bg-red-400' : 'bg-green-400'}" style="width: ${score * 100}%"></div>
</div>
`).join('');
card.innerHTML = `
<div class="flex justify-between items-center mb-2">
<h4 class="font-bold">Item ${index + 1} - ${flagged ? 'FLAGGED' : 'SAFE'}</h4>
${flagged ? `<span class="text-xs text-red-300">${flaggedCategories}</span>` : ''}
</div>
<div>${scoresHtml}</div>
`;
resultsContainer.appendChild(card);
});
resultsSection.classList.remove('hidden');
}
async function fetchMetrics() {
const apiKey = apiKeyInput.value.trim();
if (!apiKey) return;
try {
const response = await fetch('/v1/metrics', {
headers: { 'Authorization': `Bearer ${apiKey}` }
});
if (response.ok) {
const data = await response.json();
concurrentRequestsEl.textContent = data.concurrent_requests;
avgResponseTimeEl.textContent = `${data.average_response_time_ms_last_100.toFixed(2)} ms`;
}
} catch (error) {
console.error("Failed to fetch metrics");
}
}
setInterval(fetchMetrics, 3000);
</script>
</body>
</html>
"""

# The template is (re)written every time this module is loaded, replacing any
# existing templates/index.html. The page declares UTF-8, so the file is
# written as UTF-8 explicitly instead of in the platform's default encoding.
with open("templates/index.html", "w", encoding="utf-8") as template_file:
    template_file.write(_INDEX_HTML)
# Script entry point: serve the app on all interfaces. The port is read from
# the PORT environment variable and defaults to 7860.
if __name__ == "__main__":
    logger.info("Starting AI Content Moderator API...")
    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)