|
|
import torch |
|
|
from transformers import GPT2LMHeadModel, GPT2TokenizerFast |
|
|
from fastapi import FastAPI, HTTPException, Header |
|
|
from pydantic import BaseModel |
|
|
import asyncio |
|
|
from concurrent.futures import ThreadPoolExecutor |
|
|
from contextlib import asynccontextmanager |
|
|
from dotenv import dotenv_values |
|
|
|
|
|
|
|
|
# Thread pool used to run blocking model inference off the event loop.
# NOTE: the served `app` instance (with the model-loading lifespan) is created
# further down; a redundant bare `app = FastAPI()` previously created here was
# removed because it was immediately shadowed and never handled requests.
executor = ThreadPoolExecutor(max_workers=20)

# Configuration comes from the local .env file; the API bearer token lives there.
env = dotenv_values(".env")
EXPECTED_TOKEN = env.get("SECRET_TOKEN")  # None when SECRET_TOKEN is not set

# Model and tokenizer are populated once by the lifespan handler at startup.
model, tokenizer = None, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def verify_token(auth: str):
    """Validate the Authorization header against the configured bearer token.

    Args:
        auth: Raw ``Authorization`` header value, expected as ``Bearer <token>``.

    Raises:
        HTTPException: 403 when the token is missing, wrong, or unconfigured.
    """
    # Reject outright when no token is configured: previously a missing
    # SECRET_TOKEN made EXPECTED_TOKEN None, so the literal header
    # "Bearer None" would have authenticated successfully.
    if EXPECTED_TOKEN is None or auth != f"Bearer {EXPECTED_TOKEN}":
        raise HTTPException(status_code=403, detail="Unauthorized")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model():
    """Load the fine-tuned GPT-2 detector and its tokenizer from disk.

    Returns:
        tuple: ``(model, tokenizer)`` — the GPT-2 LM head model with the
        fine-tuned weights applied, in eval mode on CPU.
    """
    tok = GPT2TokenizerFast.from_pretrained("./Ai-Text-Detector/model")

    # Start from the stock gpt2 architecture, then overlay the fine-tuned
    # detector weights. NOTE(review): torch.load unpickles the file —
    # only ever point this at trusted weight files.
    detector = GPT2LMHeadModel.from_pretrained("gpt2")
    state = torch.load(
        "./Ai-Text-Detector/model_weights.pth",
        map_location=torch.device("cpu"),
    )
    detector.load_state_dict(state)
    detector.eval()

    return detector, tok
|
|
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load the model once before serving requests.

    Populates the module-level ``model`` and ``tokenizer`` globals at
    startup, then tears down the inference thread pool at shutdown.
    """
    global model, tokenizer
    model, tokenizer = load_model()
    yield
    # Previously the pool was never shut down; release its worker threads
    # cleanly once the server stops accepting requests.
    executor.shutdown(wait=True)
|
|
|
|
|
|
|
|
|
|
|
# The served application; `lifespan` loads the model before requests arrive.
app = FastAPI(lifespan=lifespan)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TextInput(BaseModel):
    """Request body for POST /analyze."""

    # Raw text to classify; the route strips surrounding whitespace and
    # rejects empty input with a 400.
    text: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def classify_text_sync(sentence: str):
    """Score *sentence* with the detector model and label it by perplexity.

    Lower perplexity under the fine-tuned GPT-2 is treated as evidence of
    machine-generated text.

    Returns:
        tuple: ``(label, perplexity)`` where label is one of the three
        detector strings.
    """
    # assumes the saved tokenizer has a pad token configured, since
    # padding=True is requested — TODO confirm against the tokenizer config
    encoded = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
    ids = encoded["input_ids"]

    # Perplexity = exp of the LM loss when the input is its own target.
    with torch.no_grad():
        outputs = model(ids, attention_mask=encoded["attention_mask"], labels=ids)
    perplexity = torch.exp(outputs.loss).item()

    # Map perplexity onto labels via ascending thresholds.
    for bound, label in ((60, "AI-generated*"), (80, "Probably AI-generated*")):
        if perplexity < bound:
            return label, perplexity
    return "Human-written*", perplexity
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def classify_text(sentence: str):
    """Run the blocking classifier in the thread pool without blocking the loop.

    Returns:
        tuple: the ``(label, perplexity)`` pair from ``classify_text_sync``.
    """
    # get_running_loop() is the supported call from inside a coroutine;
    # get_event_loop() is deprecated in this context since Python 3.10.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, classify_text_sync, sentence)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/analyze")
async def analyze_text(data: TextInput, authorization: str = Header(default="")):
    """Classify submitted text as AI-generated or human-written.

    Requires an ``Authorization: Bearer <token>`` header. Responds with the
    detector label and the perplexity score rounded to two decimals.

    Raises:
        HTTPException: 403 on a bad token, 400 on empty text.
    """
    verify_token(authorization)

    text = data.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    label, perplexity = await classify_text(text)
    return {"result": label, "perplexity": round(perplexity, 2)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Liveness probe; always reports the service as up."""
    payload = {"status": "ok"}
    return payload
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/")
def index():
    """Root endpoint; a simple message confirming the API is reachable."""
    greeting = {"message": "It's an API"}
    return greeting
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # NOTE(review): "main:app" assumes this file is named main.py — confirm.
    # workers=4 spawns separate processes; each one loads its own model copy
    # via the lifespan handler.
    uvicorn.run("main:app", host="0.0.0.0", port=8000, workers=4)
|
|
|