| | import torch |
| | from transformers import GPT2LMHeadModel, GPT2TokenizerFast |
| | from fastapi import FastAPI, HTTPException, Header |
| | from pydantic import BaseModel |
| | import asyncio |
| | from concurrent.futures import ThreadPoolExecutor |
| | from contextlib import asynccontextmanager |
| | from dotenv import dotenv_values |
| |
|
| | |
# NOTE(review): the bare `app = FastAPI()` that used to live here was dead
# code — it was immediately shadowed by `app = FastAPI(lifespan=lifespan)`
# further down, and every route decorator binds to that later instance.

# Thread pool that runs the blocking model inference off the event loop.
executor = ThreadPoolExecutor(max_workers=20)

# Bearer-token secret read from .env. EXPECTED_TOKEN is None when the file
# or the key is missing — NOTE(review): clients would then have to send the
# literal header "Bearer None"; consider failing fast at startup instead.
env = dotenv_values(".env")
EXPECTED_TOKEN = env.get("SECRET_TOKEN")

# Populated exactly once, by load_model() inside the lifespan handler.
model, tokenizer = None, None
| |
|
| | |
| |
|
| |
|
def verify_token(auth: str) -> None:
    """Validate the raw Authorization header against the configured secret.

    The header must be exactly ``Bearer <SECRET_TOKEN>``.

    Raises:
        HTTPException: 403 when the header does not match.
    """
    import hmac

    expected = f"Bearer {EXPECTED_TOKEN}"
    # Constant-time comparison so an attacker cannot recover the token
    # byte-by-byte via response timing; encode to bytes because
    # compare_digest rejects non-ASCII str inputs.
    if not hmac.compare_digest(auth.encode("utf-8"), expected.encode("utf-8")):
        raise HTTPException(status_code=403, detail="Unauthorized")
| |
|
| |
|
| | |
| |
|
| |
|
def load_model():
    """Build the detector: a GPT-2 LM head with fine-tuned weights.

    Loads the tokenizer from the local model directory, the base "gpt2"
    architecture from the hub cache, then overlays the fine-tuned state
    dict on CPU and switches to eval mode.

    Returns:
        tuple: ``(model, tokenizer)`` ready for inference.
    """
    model_path = "./Ai-Text-Detector/model"
    weights_path = "./Ai-Text-Detector/model_weights.pth"
    tok = GPT2TokenizerFast.from_pretrained(model_path)
    lm = GPT2LMHeadModel.from_pretrained("gpt2")
    # weights_only=True restricts unpickling to tensors/containers, closing
    # the arbitrary-code-execution hole of a tampered checkpoint file; a
    # plain state dict loads fine under this restriction (torch >= 1.13).
    state = torch.load(
        weights_path, map_location=torch.device("cpu"), weights_only=True
    )
    lm.load_state_dict(state)
    lm.eval()
    return lm, tok
| |
|
| |
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the model once at startup.

    Fills the module-level ``model`` and ``tokenizer`` globals before the
    app starts serving; there is no shutdown work after the ``yield``.
    """
    global model, tokenizer
    model, tokenizer = load_model()
    yield
| |
|
| |
|
| | |
| | app = FastAPI(lifespan=lifespan) |
| |
|
| | |
| |
|
| |
|
class TextInput(BaseModel):
    """Request body for /analyze: the text to classify."""

    # Raw input text; emptiness (after strip) is rejected by the endpoint.
    text: str
| |
|
| |
|
| | |
| |
|
| |
|
def classify_text_sync(sentence: str):
    """Classify `sentence` as AI- or human-written via GPT-2 perplexity.

    Tokenizes the text, computes the language-model loss with the input as
    its own labels, and exponentiates it into a perplexity score. Lower
    perplexity (text the model finds predictable) suggests AI generation.

    Returns:
        tuple: ``(label, perplexity)`` where label is one of the three
        detector strings.
    """
    encoded = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
    ids = encoded["input_ids"]
    mask = encoded["attention_mask"]

    # Inference only — no gradient bookkeeping needed.
    with torch.no_grad():
        out = model(ids, attention_mask=mask, labels=ids)
    perplexity = torch.exp(out.loss).item()

    # First threshold the score falls under wins; above both → human.
    for bound, label in ((60, "AI-generated*"), (80, "Probably AI-generated*")):
        if perplexity < bound:
            return label, perplexity
    return "Human-written*", perplexity
| |
|
| |
|
| | |
| |
|
| |
|
async def classify_text(sentence: str):
    """Run the blocking classifier in the thread pool without stalling the loop.

    Returns:
        tuple: ``(label, perplexity)`` from :func:`classify_text_sync`.
    """
    # get_running_loop() is the correct call inside a coroutine; the old
    # get_event_loop() is deprecated in that context since Python 3.10.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, classify_text_sync, sentence)
| |
|
| |
|
| | |
| |
|
| |
|
@app.post("/analyze")
async def analyze_text(data: TextInput, authorization: str = Header(default="")):
    """Classify submitted text as AI- or human-written.

    Requires a ``Bearer <token>`` Authorization header. Rejects bodies that
    are empty after stripping with a 400; otherwise responds with the
    detector label and the perplexity rounded to two decimals.
    """
    verify_token(authorization)

    text = data.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    label, perplexity = await classify_text(text)
    return {"result": label, "perplexity": round(perplexity, 2)}
| |
|
| |
|
| | |
| |
|
| |
|
@app.get("/health")
async def health_check():
    """Liveness probe: constant OK payload, no model involvement."""
    payload = {"status": "ok"}
    return payload
| |
|
| |
|
| | |
| |
|
| |
|
@app.get("/")
def index():
    """Root endpoint: short identification message for casual visitors."""
    message = {"message": "It's an API"}
    return message
| |
|
| |
|
| | |
# Development/ops entry point when run directly (e.g. `python main.py`).
if __name__ == "__main__":
    import uvicorn

    # NOTE(review): workers=4 forks four processes, each loading its own
    # copy of the GPT-2 model — confirm memory headroom. The "main:app"
    # import string also assumes this file is named main.py; verify.
    uvicorn.run("main:app", host="0.0.0.0", port=8000, workers=4)
| |
|