File size: 2,530 Bytes
5792ae0
f588c6c
8adf4f2
5792ae0
 
da25d43
5792ae0
8adf4f2
5792ae0
 
da25d43
5792ae0
 
 
 
 
d60213d
2cca8fd
da25d43
 
 
 
 
 
 
 
 
2cca8fd
5792ae0
da25d43
5792ae0
 
 
 
 
 
8adf4f2
5792ae0
 
8adf4f2
5792ae0
 
 
8adf4f2
 
5792ae0
 
 
 
 
 
 
 
 
 
da25d43
5792ae0
da25d43
5792ae0
da25d43
5792ae0
 
 
 
 
8adf4f2
5792ae0
 
 
 
da25d43
 
 
5792ae0
 
 
 
 
 
 
 
 
 
 
 
 
6386bd6
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import asyncio
import math
from contextlib import asynccontextmanager

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast

# Module-level model and tokenizer. They are None until the `lifespan` startup
# hook assigns the objects returned by `load_model()`.
# (The FastAPI app itself is created once, further down, with that lifespan
# attached; a throwaway instance here would just be shadowed.)
model, tokenizer = None, None

def load_model(
    model_path="./Ai-Text-Detector/model",
    weights_path="./Ai-Text-Detector/model_weights.pth",
):
    """Build the GPT-2 language model and tokenizer used for scoring text.

    The tokenizer and model config are read from ``model_path``. The model is
    instantiated from that config, then populated with the state dict stored
    at ``weights_path`` (mapped onto the CPU) and set to evaluation mode.

    Args:
        model_path: Directory passed to ``from_pretrained`` for the tokenizer
            and the config.
        weights_path: File holding the model's ``state_dict``.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        RuntimeError: If any loading step fails. The underlying exception is
            chained as ``__cause__`` so its traceback is preserved.
    """
    try:
        tokenizer = GPT2TokenizerFast.from_pretrained(model_path)
        config = GPT2Config.from_pretrained(model_path)
        model = GPT2LMHeadModel(config)
        # weights_only=True limits unpickling to tensors and plain containers,
        # which is all a state_dict needs, so a tampered checkpoint file
        # cannot execute arbitrary code on load.
        state_dict = torch.load(
            weights_path, map_location=torch.device("cpu"), weights_only=True
        )
        model.load_state_dict(state_dict)
        model.eval()  # Set model to evaluation mode
    except Exception as e:
        raise RuntimeError(f"Error loading model: {e}") from e

    return model, tokenizer

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook that fills in the module-level model globals.

    ``load_model()`` runs before ``yield``, so the globals are set before
    the application starts serving. Nothing is done after ``yield``.
    """
    global model, tokenizer
    loaded_model, loaded_tokenizer = load_model()
    model = loaded_model
    tokenizer = loaded_tokenizer
    yield

# The application object that every route below registers on. `lifespan`
# is attached so the model is loaded before requests are handled.
app = FastAPI(
    lifespan=lifespan,
)

# Input schema: request body for POST /analyze, a JSON object with a single
# string field. Documented with comments rather than a class docstring,
# because Pydantic copies a model docstring into the generated OpenAPI schema.
class TextInput(BaseModel):
    text: str  # raw text to classify; analyze_text strips it and rejects empty input

def classify_text(sentence: str):
    """Score ``sentence`` with the language model and map its perplexity to a label.

    The text is tokenized (truncated to the tokenizer's maximum length) and
    run through the model with the input ids doubling as ``labels``. The
    returned ``loss`` is exponentiated to get the perplexity. Lower perplexity
    means the model found the text more predictable, which this heuristic
    treats as a sign of machine generation:

    * perplexity < 60  -> "AI-generated"
    * perplexity < 80  -> "Probably AI-generated"
    * otherwise        -> "Human-written"

    This is blocking (a full forward pass), so async callers should run it
    off the event loop.

    Args:
        sentence: Text to classify.

    Returns:
        A ``(label, perplexity)`` tuple, where ``perplexity`` is a float.

    Raises:
        RuntimeError: If the model or tokenizer has not been loaded yet.
        ValueError: If the text tokenizes to fewer than two tokens. A causal
            LM has no next-token prediction to score then, and the loss would
            be NaN, which would otherwise be silently labelled "Human-written".
    """
    if model is None or tokenizer is None:
        raise RuntimeError("Model is not loaded")

    # A single sequence never needs padding, and requesting padding fails
    # when the tokenizer has no pad token (the stock GPT-2 tokenizer has none).
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    if input_ids.shape[-1] < 2:
        raise ValueError("Text is too short to analyze")

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
        perplexity = torch.exp(outputs.loss).item()

    if perplexity < 60:
        result = "AI-generated"
    elif perplexity < 80:
        result = "Probably AI-generated"
    else:
        result = "Human-written"

    return result, perplexity

@app.post("/analyze")
async def analyze_text(data: TextInput):
    """Classify the submitted text as AI-generated or human-written.

    Leading and trailing whitespace is stripped before scoring. The blocking
    model call runs in a worker thread via ``asyncio.to_thread``, so the event
    loop stays responsive.

    Returns:
        ``{"result": <label>, "perplexity": <perplexity rounded to 2 dp>}``.

    Raises:
        HTTPException: 400 if the text is empty after stripping, or if it
            cannot be turned into a finite perplexity.
    """
    user_input = data.text.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    try:
        result, perplexity = await asyncio.to_thread(classify_text, user_input)
    except ValueError as exc:
        # A ValueError from scoring means this input could not be scored;
        # report it as a client error rather than a 500.
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    # A NaN/inf perplexity (e.g. text that tokenizes to a single token)
    # carries no signal and cannot be encoded by the JSON response, which
    # would surface as a 500. Reject it explicitly instead.
    if not math.isfinite(perplexity):
        raise HTTPException(
            status_code=400,
            detail="Text is too short to compute a perplexity",
        )

    return {
        "result": result,
        "perplexity": round(perplexity, 2),
    }

@app.get("/health")
async def health_check():
    """Report that the service is running.

    Always returns the same fixed payload; it does not inspect the model.
    """
    status = "ok"
    return {"status": status}

@app.get("/")
def index():
    """Return a small JSON greeting that points callers at the interactive docs."""
    greeting = {"message": "FastAPI API is up."}
    greeting["try"] = "/docs to test the API."
    greeting["status"] = "OK"
    return greeting