from __future__ import annotations import os import re import secrets from contextlib import asynccontextmanager from typing import Annotated import torch from fastapi import FastAPI, HTTPException, Security from fastapi.middleware.cors import CORSMiddleware from fastapi.security import APIKeyHeader from pydantic import BaseModel, Field from transformers import pipeline # ─── Config ──────────────────────────────────────────────────────────────────── MODEL_ID = "openai-community/roberta-base-openai-detector" # Read from HuggingFace Space secret (Settings → Variables and secrets) API_KEY = os.environ.get("API_KEY", "") if not API_KEY: raise RuntimeError( "API_KEY environment variable is not set. " "Add it in your HuggingFace Space → Settings → Variables and secrets." ) # Header scheme — clients send: X-API-Key: api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) def verify_api_key(key: str | None = Security(api_key_header)) -> str: """Dependency: rejects requests with a missing or wrong API key.""" if not key or not secrets.compare_digest(key, API_KEY): raise HTTPException( status_code=401, detail="Invalid or missing API key. Pass it as the X-API-Key header.", ) return key # ─── Lifespan ────────────────────────────────────────────────────────────────── classifier = None @asynccontextmanager async def lifespan(app: FastAPI): global classifier print(f"Loading model {MODEL_ID} …") classifier = pipeline( "text-classification", model=MODEL_ID, device=0 if torch.cuda.is_available() else -1, ) print("Model ready.") yield # ─── App ─────────────────────────────────────────────────────────────────────── app = FastAPI( title="AI Text Detector API", description="Detects whether text is human-written or AI-generated. Requires X-API-Key header.", version="2.0.0", lifespan=lifespan, ) app.add_middleware( CORSMiddleware, allow_origins=["*"], # lock this down to your domain in production allow_methods=["POST", "GET"], allow_headers=["*"], ) # ─── Helpers ─────────────────────────────────────────────────────────────────── def split_into_chunks(text: str) -> list[str]: chunks: list[str] = [] paragraphs = [p.strip() for p in text.split("\n") if p.strip()] or [text.strip()] for para in paragraphs: sentences = re.split(r"(?<=[.!?])\s+", para) current = "" for sent in sentences: if len((current + " " + sent).split()) > 80: if current.strip(): chunks.append(current.strip()) current = sent else: current = (current + " " + sent).strip() if current.strip(): chunks.append(current.strip()) return chunks or [text.strip()] # ─── Schemas ─────────────────────────────────────────────────────────────────── class DetectRequest(BaseModel): text: Annotated[ str, Field( min_length=1, max_length=10_000, description="Text to analyse (max 10,000 characters)", ), ] class ChunkResult(BaseModel): text: str ai_probability: float human_probability: float label: str # "AI" | "Human" confidence: float class DetectResponse(BaseModel): label: str ai_probability: float human_probability: float confidence: float chunks: list[ChunkResult] total_chunks: int ai_chunks: int human_chunks: int # ─── Routes ──────────────────────────────────────────────────────────────────── @app.get("/", tags=["health"]) async def health(): """Public health-check — no API key required.""" return {"status": "ok", "model": MODEL_ID} @app.post( "/detect", response_model=DetectResponse, tags=["detection"], dependencies=[Security(verify_api_key)], ) async def detect(body: DetectRequest): if classifier is None: raise HTTPException(status_code=503, detail="Model not loaded yet — try again shortly.") chunks = split_into_chunks(body.text) raw = classifier(chunks, truncation=True, max_length=512, batch_size=8) chunk_results: list[ChunkResult] = [] ai_probs: list[float] = [] word_counts: list[int] = [] for chunk, res in zip(chunks, raw): ai_prob = res["score"] if res["label"] == "Fake" else 1.0 - res["score"] human_prob = 1.0 - ai_prob is_ai = ai_prob >= 0.5 label = "AI" if is_ai else "Human" conf = ai_prob if is_ai else human_prob chunk_results.append( ChunkResult( text=chunk, ai_probability=round(ai_prob, 4), human_probability=round(human_prob, 4), label=label, confidence=round(conf, 4), ) ) ai_probs.append(ai_prob) word_counts.append(len(chunk.split())) total_words = sum(word_counts) avg_ai = sum(p * w for p, w in zip(ai_probs, word_counts)) / total_words avg_human = 1.0 - avg_ai overall_label = "AI" if avg_ai >= 0.5 else "Human" overall_conf = avg_ai if overall_label == "AI" else avg_human ai_chunks = sum(1 for p in ai_probs if p >= 0.5) return DetectResponse( label=overall_label, ai_probability=round(avg_ai, 4), human_probability=round(avg_human, 4), confidence=round(overall_conf, 4), chunks=chunk_results, total_chunks=len(chunks), ai_chunks=ai_chunks, human_chunks=len(chunks) - ai_chunks, )