import asyncio
import hmac
import math
import os
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager

import torch
from dotenv import dotenv_values
from fastapi import FastAPI, Header, HTTPException
from pydantic import BaseModel
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

# The FastAPI application object is created once, after ``lifespan`` is
# defined, so that the model-loading hook can be passed to it.

# Worker pool used to run blocking model inference off the event loop.
# NOTE(review): each inference may also use torch's own intra-op threads, so
# 20 concurrent workers can oversubscribe CPU cores — consider tuning.
executor = ThreadPoolExecutor(max_workers=20)

# Configuration: prefer the local .env file, but fall back to the process
# environment (dotenv_values() does not read os.environ on its own).
env = dotenv_values(".env")
EXPECTED_TOKEN = env.get("SECRET_TOKEN") or os.environ.get("SECRET_TOKEN")

# Populated by the lifespan handler at startup; None until then.
model, tokenizer = None, None


def verify_token(auth: str) -> None:
    """Reject the request unless ``auth`` equals ``"Bearer <SECRET_TOKEN>"``.

    Args:
        auth: Raw value of the ``Authorization`` request header.

    Raises:
        HTTPException: 403 when the header does not match, or when no
            secret token is configured at all (fail closed).
    """
    # Without this guard a missing token turns the expected header into the
    # literal "Bearer None" (or "Bearer " for an empty value), which any
    # client could send to bypass authentication.
    if not EXPECTED_TOKEN:
        raise HTTPException(status_code=403, detail="Unauthorized")

    expected = f"Bearer {EXPECTED_TOKEN}".encode()
    provided = (auth or "").encode()
    # Constant-time comparison so response timing does not leak the token.
    if not hmac.compare_digest(provided, expected):
        raise HTTPException(status_code=403, detail="Unauthorized")


def load_model():
    """Load the detector model and its tokenizer onto the CPU.

    The tokenizer is read from ``./Ai-Text-Detector/model``. The model
    architecture is instantiated from the ``"gpt2"`` checkpoint and then all
    of its parameters are replaced by the state dict stored in
    ``./Ai-Text-Detector/model_weights.pth``.

    Returns:
        ``(model, tokenizer)`` with the model already in eval mode.
    """
    model_path = "./Ai-Text-Detector/model"
    weights_path = "./Ai-Text-Detector/model_weights.pth"

    tokenizer = GPT2TokenizerFast.from_pretrained(model_path)

    # NOTE(review): "gpt2" is resolved by transformers (presumably the Hugging
    # Face Hub or its local cache), so a cold start may need network access.
    # Only the architecture is used; load_state_dict (strict) overwrites every
    # parameter below.
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    # weights_only=True limits unpickling to tensors and plain containers —
    # all a state dict needs — instead of executing arbitrary pickled objects.
    state_dict = torch.load(
        weights_path, map_location=torch.device("cpu"), weights_only=True
    )
    model.load_state_dict(state_dict)
    model.eval()  # inference mode for layers such as dropout
    return model, tokenizer


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook that loads the detector before serving.

    The loaded objects are published through the module-level ``model`` and
    ``tokenizer`` globals that the inference code reads. Nothing is torn down
    when the application shuts down.
    """
    global model, tokenizer
    loaded_model, loaded_tokenizer = load_model()
    model, tokenizer = loaded_model, loaded_tokenizer
    yield


# The application object; ``lifespan`` loads the model during startup.
app = FastAPI(
    lifespan=lifespan,
)


# Request body for POST /analyze. The endpoint strips surrounding whitespace
# and rejects text that is empty afterwards.
class TextInput(BaseModel):
    text: str  # the passage to classify


def classify_text_sync(sentence: str):
    """Score ``sentence`` with the language model and label it by perplexity.

    This call blocks; ``classify_text`` runs it on the worker pool.

    Args:
        sentence: Text to score. ``truncation=True`` with no ``max_length``
            truncates to the tokenizer's ``model_max_length``.

    Returns:
        ``(label, perplexity)`` where ``label`` is ``"AI-generated*"`` when
        perplexity < 60, ``"Probably AI-generated*"`` when < 80, and
        ``"Human-written*"`` otherwise; ``perplexity`` is a Python float.
    """
    # A single sequence never needs padding. Omitting padding=True also avoids
    # the tokenizer's "no padding token" error (stock GPT-2 defines none).
    encoded = tokenizer(sentence, return_tensors="pt", truncation=True)
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]

    with torch.no_grad():
        # Using the inputs as labels makes the model return the mean
        # next-token cross-entropy; its exponential is the perplexity.
        outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
    # NOTE(review): a one-token input leaves no next-token targets, so the
    # loss (and perplexity) is presumably NaN — callers should check finiteness.
    perplexity = torch.exp(outputs.loss).item()

    # Lower perplexity means the model found the text more predictable.
    if perplexity < 60:
        result = "AI-generated*"
    elif perplexity < 80:
        result = "Probably AI-generated*"
    else:
        result = "Human-written*"
    return result, perplexity


async def classify_text(sentence: str):
    """Run ``classify_text_sync`` on the worker pool without blocking the loop.

    Returns the same ``(label, perplexity)`` tuple as the synchronous version.
    """
    # Inside a coroutine the running loop always exists; get_running_loop()
    # is the documented, unambiguous way to obtain it.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, classify_text_sync, sentence)


# POST /analyze: authenticate, validate, then score the text off the event
# loop. Responds with {"result": <label>, "perplexity": <float, 2 decimals>};
# 403 on a bad token, 400 on empty or unscorable text.
@app.post("/analyze")
async def analyze_text(data: TextInput, authorization: str = Header(default="")):
    verify_token(authorization)
    user_input = data.text.strip()

    if not user_input:
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    result, perplexity = await classify_text(user_input)

    # Very short inputs (e.g. a single token) give the model nothing to
    # predict, which can make the perplexity NaN/inf. JSON has no such values
    # and the default JSON response refuses them, so report a client error
    # instead of failing with a 500.
    if not math.isfinite(perplexity):
        raise HTTPException(status_code=400, detail="Text is too short to analyze")

    return {
        "result": result,
        "perplexity": round(perplexity, 2),
    }


# Lightweight liveness endpoint: it consults neither the model nor the token
# and always answers with the same payload.
@app.get("/health")
async def health_check():
    payload = {"status": "ok"}
    return payload


# Root endpoint: returns a fixed banner message.
@app.get("/")
def index():
    banner = "It's an API"
    return {"message": banner}