import hmac
import io
import os
import threading
import time
import uuid
from collections import defaultdict, deque
from datetime import datetime, timedelta

import numpy as np
import requests
from detoxify import Detoxify
from flask import Flask, request, jsonify, render_template
from PIL import Image
from tensorflow.keras.models import load_model
# Flask app serving the moderation endpoints plus a static docs page.
app = Flask(__name__, static_folder='static', template_folder='templates')
app.logger.setLevel('INFO')

# Bearer token required by every endpoint; fail fast at startup if unset.
API_KEY = os.environ.get('API_KEY')
if not API_KEY:
    raise ValueError("API_KEY environment variable not set.")

# Text moderation model (Detoxify multilingual), loaded eagerly at import time.
print("Loading Detoxify model for text moderation...")
detoxify_model = Detoxify('multilingual')
print("Detoxify model loaded successfully.")

# Optional image moderation model (Teachable Machine Keras export). If it
# cannot be loaded, the API still runs but image URLs cannot be classified.
MODEL_PATH = 'keras_model.h5'
LABELS_PATH = 'labels.txt'
image_model = None
image_labels = None
try:
    print("Loading Teachable Machine model for image moderation...")
    image_model = load_model(MODEL_PATH, compile=False)
    with open(LABELS_PATH, 'r') as f:
        # Each labels.txt line looks like "<index> <name>"; keep only the name.
        image_labels = [line.strip().split(' ')[1] for line in f.readlines()]
    print("Image moderation model loaded successfully.")
except Exception as e:
    app.logger.warning(f"Could not load image moderation model. Image moderation will be disabled. Error: {e}")
    image_model = None
    image_labels = None

# In-memory metrics state (process-local; reset on every restart).
request_durations = deque(maxlen=100)    # last 100 request latencies, in seconds
request_timestamps = deque(maxlen=1000)  # timestamps of the most recent requests
daily_requests = defaultdict(int)        # "YYYY-MM-DD" -> request count
concurrent_requests = 0
concurrent_requests_lock = threading.Lock()  # guards concurrent_requests
def is_url(string):
    """Return True when *string* is a str that, after stripping surrounding
    whitespace, begins with an http:// or https:// scheme."""
    if not isinstance(string, str):
        return False
    return string.strip().startswith(('http://', 'https://'))
def classify_image(image_bytes):
    """Run the Teachable Machine model on raw image bytes.

    Returns a dict mapping lowercased label -> confidence score.
    Raises RuntimeError if the image model failed to load at startup.
    """
    if not image_model or not image_labels:
        raise RuntimeError("Image moderation model is not available.")
    # The model expects a single 224x224 RGB image normalized to [-1, 1].
    img = Image.open(io.BytesIO(image_bytes)).convert('RGB').resize((224, 224))
    pixels = np.asarray(img, dtype=np.float32)
    batch = np.expand_dims(pixels / 127.5 - 1.0, axis=0)
    prediction = image_model.predict(batch)
    return {label.lower(): float(score) for label, score in zip(image_labels, prediction[0])}
def transform_text_predictions(prediction_dict, threshold=0.5):
    """Convert raw Detoxify scores into an OpenAI-style moderation triple.

    Args:
        prediction_dict: mapping of Detoxify category name -> score; any
            missing category defaults to 0.0.
        threshold: score above which a category counts as flagged. Defaults
            to 0.5, the previously hard-coded cutoff, so existing callers
            are unaffected.

    Returns:
        (flagged, categories, scores) where `flagged` is True if any
        category exceeds `threshold`, `categories` maps name -> bool, and
        `scores` maps name -> float score.
    """
    category_keys = (
        "toxicity", "severe_toxicity", "obscene", "threat",
        "insult", "identity_attack", "sexual_explicit",
    )
    scores = {key: float(prediction_dict.get(key, 0.0)) for key in category_keys}
    categories = {key: score > threshold for key, score in scores.items()}
    return any(categories.values()), categories, scores
def transform_image_predictions(prediction_dict):
    """Map the image classifier's output onto OpenAI-style moderation categories.

    Only the 'nsfw' score is meaningful for this model: it drives the
    'sexual' category (score > 0.8) and the stricter 'sexual/minors'
    category (score > 0.9). Every other category is reported but is never
    triggered here, with a fixed 0.0 score.
    """
    nsfw = prediction_dict.get('nsfw', 0.0)
    # Categories driven by the nsfw score: name -> (reported score, flagged?).
    nsfw_driven = {
        "sexual": (nsfw, nsfw > 0.8),
        "sexual/minors": (nsfw, nsfw > 0.9),
    }
    # Fixed output order matches the OpenAI moderation response layout.
    ordered_names = (
        "sexual", "hate", "harassment", "self-harm", "sexual/minors",
        "hate/threatening", "violence/graphic", "self-harm/intent",
        "self-harm/instructions", "harassment/threatening", "violence",
    )
    categories = {}
    category_scores = {}
    for name in ordered_names:
        score, hit = nsfw_driven.get(name, (0.0, False))
        categories[name] = hit
        category_scores[name] = score
    return any(categories.values()), categories, category_scores
def track_request_metrics(start_time):
    """Record one finished request into the module-level metric stores.

    Appends the request latency and timestamp and bumps today's counter.
    *start_time* is the `time.time()` value captured when handling began.
    """
    now = datetime.now()
    request_durations.append(time.time() - start_time)
    request_timestamps.append(now)
    daily_requests[now.strftime("%Y-%m-%d")] += 1
def get_performance_metrics():
    """Assemble a JSON-serializable snapshot of request statistics.

    Reads the module-level metric stores (request_durations,
    request_timestamps, daily_requests, concurrent_requests) and returns
    average/peak latency in ms, per-minute and per-day request counts,
    and the current number of in-flight requests.
    """
    with concurrent_requests_lock:
        in_flight = concurrent_requests
    samples = list(request_durations)
    avg_seconds = sum(samples) / len(samples) if samples else 0
    peak_seconds = max(samples) if samples else 0
    now = datetime.now()
    cutoff = now - timedelta(seconds=60)
    per_minute = sum(1 for stamp in request_timestamps if stamp > cutoff)
    history = [
        {"date": day, "requests": daily_requests.get(day, 0)}
        for day in ((now - timedelta(days=offset)).strftime("%Y-%m-%d") for offset in range(7))
    ]
    return {
        "avg_request_time_ms": avg_seconds * 1000,
        "peak_request_time_ms": peak_seconds * 1000,
        "requests_per_minute": per_minute,
        "concurrent_requests": in_flight,
        "today_requests": daily_requests.get(now.strftime("%Y-%m-%d"), 0),
        "last_7_days": history,
    }
@app.route('/')
def home():
    # Serve the API documentation page (templates/index.html, seeded by
    # create_app_structure when missing).
    return render_template('index.html')
@app.route('/v1/moderations', methods=['POST'])
def moderations():
    """OpenAI-style moderation endpoint for text strings and image URLs.

    Requires an `Authorization: Bearer <API_KEY>` header. The JSON body's
    'input' field may be a single string/URL or a list of them. URL items
    are downloaded and classified by the image model; plain strings are
    batched through Detoxify. Results are returned in the caller's input
    order; per-item failures are reported as per-item error objects rather
    than failing the whole request.
    """
    global concurrent_requests
    auth_header = request.headers.get('Authorization')
    # hmac.compare_digest performs a constant-time comparison, preventing
    # a timing side channel on the API key check.
    if (not auth_header or not auth_header.startswith("Bearer ")
            or not hmac.compare_digest(auth_header.split(" ")[1], API_KEY)):
        return jsonify({"error": {"message": "Incorrect API key provided.", "type": "invalid_request_error", "code": "invalid_api_key"}}), 401
    with concurrent_requests_lock:
        concurrent_requests += 1
    start_time = time.time()
    try:
        # silent=True: malformed JSON yields None (turned into our JSON 400
        # below) instead of Flask raising an HTML BadRequest page.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({"error": "Invalid JSON body"}), 400
        raw_input = data.get('input')
        if raw_input is None:
            return jsonify({"error": "'input' field is required"}), 400
        inputs = [raw_input] if isinstance(raw_input, str) else raw_input
        if not isinstance(inputs, list):
            return jsonify({"error": "'input' must be a string or a list of strings/URLs"}), 400
        results = []            # list of (original_index, result_dict)
        texts_to_process = []   # plain-text items, classified in one batch
        text_indices = []       # original positions of those text items
        for i, item in enumerate(inputs):
            if is_url(item):
                # Image item: download, classify, and record the result (or
                # a per-item error) immediately.
                try:
                    response = requests.get(item, timeout=10)
                    response.raise_for_status()
                    image_scores = classify_image(response.content)
                    flagged, categories, category_scores = transform_image_predictions(image_scores)
                    results.append((i, {"flagged": flagged, "categories": categories, "category_scores": category_scores}))
                except requests.RequestException as e:
                    results.append((i, {"error": f"Failed to download image: {e}"}))
                except Exception as e:
                    results.append((i, {"error": f"Failed to process image: {e}"}))
            elif isinstance(item, str):
                texts_to_process.append(item)
                text_indices.append(i)
            else:
                results.append((i, {"error": "Invalid input type. Must be a string or URL."}))
        if texts_to_process:
            # Detoxify returns {category: [score per batched input]}.
            text_predictions = detoxify_model.predict(texts_to_process)
            for i, original_index in enumerate(text_indices):
                single_prediction = {key: value[i] for key, value in text_predictions.items()}
                flagged, categories, category_scores = transform_text_predictions(single_prediction)
                results.append((original_index, {"flagged": flagged, "categories": categories, "category_scores": category_scores}))
        # Restore the caller's input order before building the response.
        results.sort(key=lambda x: x[0])
        final_results = [res for _, res in results]
        response_data = {
            "id": "modr-" + uuid.uuid4().hex[:24],
            "model": "smart-moderator-multimodal-v1",
            "results": final_results
        }
        return jsonify(response_data)
    except Exception as e:
        app.logger.error(f"An error occurred: {e}", exc_info=True)
        return jsonify({"error": "An internal server error occurred."}), 500
    finally:
        # Runs for every authenticated request, including error paths.
        track_request_metrics(start_time)
        with concurrent_requests_lock:
            concurrent_requests -= 1
@app.route('/v1/metrics', methods=['GET'])
def metrics():
    """Return the performance-metrics snapshot; requires the Bearer API key."""
    auth_header = request.headers.get('Authorization')
    # Constant-time key comparison, consistent with /v1/moderations, to
    # avoid leaking key prefixes through response timing.
    if (not auth_header or not auth_header.startswith("Bearer ")
            or not hmac.compare_digest(auth_header.split(" ")[1], API_KEY)):
        return jsonify({"error": "Unauthorized"}), 401
    return jsonify(get_performance_metrics())
def create_app_structure():
    """Ensure the templates/ and static/ directories exist and seed a
    default documentation page at templates/index.html.

    An existing index.html is never overwritten.
    """
    for directory in ('templates', 'static'):
        os.makedirs(directory, exist_ok=True)
    index_html_content = r'''
Smart Moderator API
API Documentation
Endpoint
POST /v1/moderations
Headers
Authorization: Bearer YOUR_API_KEY Content-Type: application/json
Request Body
The `input` field can be a single string/URL or a list of strings/URLs.
{"input": "Text to moderate"}
Usage Example (cURL)
Text Moderation
curl -X POST https://nixaut-codelabs-smart-moderator.hf.space/v1/moderations \
-H "Authorization: Bearer YOUR_API_KEY" \
-H "Content-Type: application/json" \
-d '{"input": "You are stupid and I hate you."}'
Multimodal Moderation (Text + Image)
curl -X POST https://nixaut-codelabs-smart-moderator.hf.space/v1/moderations \
-H "Authorization: Bearer YOUR_API_KEY" \
-H "Content-Type: application/json" \
-d '{"input": [
"This is a perfectly normal sentence.",
"https://upload.wikimedia.org/wikipedia/commons/3/3f/Fronalpstock_big.jpg"
]}'
'''
    index_path = os.path.join('templates', 'index.html')
    if os.path.exists(index_path):
        return
    with open(index_path, 'w', encoding='utf-8') as f:
        f.write(index_html_content)
if __name__ == '__main__':
    # Seed templates/static before Flask starts serving.
    create_app_structure()
    # PORT defaults to 7860 (the Hugging Face Spaces convention).
    port = int(os.environ.get('PORT', 7860))
    # For production, use a proper WSGI server like Gunicorn
    # gunicorn --bind 0.0.0.0:7860 app:app
    app.run(host='0.0.0.0', port=port, debug=False)