"""AI-text-detection API.

Serves a fine-tuned GPT-2 perplexity classifier over FastAPI: text with low
perplexity under the model is flagged as likely AI-generated.
"""

import asyncio
import secrets
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager

import torch
from dotenv import dotenv_values
from fastapi import FastAPI, Header, HTTPException
from pydantic import BaseModel
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

# Thread pool for running blocking model inference off the event loop.
executor = ThreadPoolExecutor(max_workers=20)

# Shared secret from .env; requests must present "Bearer <SECRET_TOKEN>".
env = dotenv_values(".env")
EXPECTED_TOKEN = env.get("SECRET_TOKEN")

# Populated at startup by the lifespan handler.
model, tokenizer = None, None


def verify_token(auth: str) -> None:
    """Raise 403 unless *auth* is exactly "Bearer <SECRET_TOKEN>".

    Uses a constant-time comparison (secrets.compare_digest) so the check
    does not leak the token through response timing. If SECRET_TOKEN is
    missing from .env, every request is rejected — the original code would
    have accepted the literal header "Bearer None" in that case.
    """
    if EXPECTED_TOKEN is None or not secrets.compare_digest(
        auth, f"Bearer {EXPECTED_TOKEN}"
    ):
        raise HTTPException(status_code=403, detail="Unauthorized")


def load_model():
    """Load the tokenizer and fine-tuned GPT-2 weights (CPU, eval mode).

    Returns:
        (model, tokenizer) tuple ready for inference.
    """
    model_path = "./Ai-Text-Detector/model"
    weights_path = "./Ai-Text-Detector/model_weights.pth"
    tokenizer = GPT2TokenizerFast.from_pretrained(model_path)
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    # weights_only=True: the checkpoint is a plain state dict, so refuse
    # arbitrary pickled objects (torch.load is pickle-based and otherwise
    # executes whatever the file contains). Requires torch >= 1.13.
    model.load_state_dict(
        torch.load(weights_path, map_location=torch.device("cpu"), weights_only=True)
    )
    model.eval()  # inference only: disable dropout etc.
    return model, tokenizer


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model once at startup; release the worker pool at shutdown."""
    global model, tokenizer
    model, tokenizer = load_model()
    yield
    # Shutdown phase: stop accepting new inference jobs and free the threads.
    executor.shutdown(wait=False)


# Single app instance with the lifespan attached at construction.
# (The original created a throwaway FastAPI() first and then rebound `app`,
# leaving a dead instance that any route registered in between would
# silently attach to.)
app = FastAPI(lifespan=lifespan)


class TextInput(BaseModel):
    """Request body for /analyze."""

    text: str


def classify_text_sync(sentence: str):
    """Score *sentence* with GPT-2 perplexity and bucket it into a label.

    Lower perplexity means the text is more predictable to the model and is
    therefore treated as more likely machine-generated.

    Returns:
        (label, perplexity) where label is one of the three "*"-suffixed
        strings below.
    """
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    with torch.no_grad():
        # Using the inputs as labels yields the LM loss (mean token NLL);
        # exp(loss) is the perplexity.
        outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
        loss = outputs.loss
        perplexity = torch.exp(loss).item()

    # Heuristic cutoffs for this fine-tuned model — TODO confirm they were
    # calibrated against held-out data.
    if perplexity < 60:
        result = "AI-generated*"
    elif perplexity < 80:
        result = "Probably AI-generated*"
    else:
        result = "Human-written*"

    return result, perplexity


async def classify_text(sentence: str):
    """Run the blocking classifier in the thread pool without blocking the loop."""
    # get_event_loop() is deprecated inside coroutines (Python 3.10+);
    # get_running_loop() is the supported call here.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, classify_text_sync, sentence)


@app.post("/analyze")
async def analyze_text(data: TextInput, authorization: str = Header(default="")):
    """Classify the posted text; requires a valid Bearer token.

    Raises:
        HTTPException 403 on a bad/missing token, 400 on empty text.
    """
    verify_token(authorization)
    user_input = data.text.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    result, perplexity = await classify_text(user_input)
    return {
        "result": result,
        "perplexity": round(perplexity, 2),
    }


@app.get("/health")
async def health_check():
    """Liveness probe."""
    return {"status": "ok"}


@app.get("/")
def index():
    """Root route."""
    return {"message": "It's an API"}


if __name__ == "__main__":
    import uvicorn

    # workers>1 requires the import-string form ("main:app"); each worker
    # process loads its own copy of the model via the lifespan handler.
    uvicorn.run("main:app", host="0.0.0.0", port=8000, workers=4)