File size: 3,070 Bytes
5792ae0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import asyncio
import secrets
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager

import torch
from dotenv import dotenv_values
from fastapi import FastAPI, HTTPException, Header
from pydantic import BaseModel
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

# NOTE: the FastAPI app itself is created further down, once the lifespan
# hook exists — creating a bare FastAPI() here as well would just make a
# dead instance that is immediately shadowed.

# Thread pool used to run the blocking model inference off the event loop.
executor = ThreadPoolExecutor(max_workers=20)

# Load secrets from the local .env file (not from the process environment).
env = dotenv_values(".env")
EXPECTED_TOKEN = env.get("SECRET_TOKEN")  # None when .env/key is missing

# Model and tokenizer are populated once at startup by the lifespan hook.
model, tokenizer = None, None

def verify_token(auth: str) -> None:
    """Validate the ``Authorization`` header against the configured secret.

    Args:
        auth: Raw header value; must equal ``Bearer <SECRET_TOKEN>``.

    Raises:
        HTTPException: 403 when the header does not match, or when no
            secret is configured at all.
    """
    # Bug fix: with SECRET_TOKEN missing from .env, EXPECTED_TOKEN is None
    # and the original f-string compare would accept the literal header
    # "Bearer None". Refuse everything instead.
    if EXPECTED_TOKEN is None:
        raise HTTPException(status_code=403, detail="Unauthorized")
    # Constant-time comparison avoids a timing side channel on the token.
    if not secrets.compare_digest(auth, f"Bearer {EXPECTED_TOKEN}"):
        raise HTTPException(status_code=403, detail="Unauthorized")


def load_model():
    """Load the fine-tuned GPT-2 detector and its tokenizer (CPU only).

    Returns:
        tuple: ``(model, tokenizer)`` with the model in eval mode.
    """
    model_path = "./Ai-Text-Detector/model"
    weights_path = "./Ai-Text-Detector/model_weights.pth"
    tokenizer = GPT2TokenizerFast.from_pretrained(model_path)
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    # weights_only=True stops torch.load from unpickling arbitrary objects —
    # a code-execution risk if the weights file is ever tampered with.
    state_dict = torch.load(
        weights_path, map_location=torch.device("cpu"), weights_only=True
    )
    model.load_state_dict(state_dict)
    model.eval()  # inference mode: disables dropout etc.
    return model, tokenizer


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Populate the module-level model/tokenizer once at startup.

    There is no teardown work; control simply returns after the yield
    when the application shuts down.
    """
    global model, tokenizer
    loaded = load_model()
    model, tokenizer = loaded
    yield


# Attach the lifespan context manager
app = FastAPI(lifespan=lifespan)

class TextInput(BaseModel):
    """Request body for the /analyze endpoint."""

    # The text to classify; the route rejects values that are empty after
    # stripping whitespace.
    text: str


def classify_text_sync(sentence: str):
    """Score ``sentence`` with the detector model.

    Computes the model's perplexity on the text and maps it onto a coarse
    label — lower perplexity means the text looks more like the model's
    own output, i.e. more likely AI-generated.

    Returns:
        tuple: ``(label, perplexity)``.
    """
    encoded = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)

    with torch.no_grad():
        outputs = model(
            encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            labels=encoded["input_ids"],
        )
        perplexity = torch.exp(outputs.loss).item()

    # Thresholds checked from the top down; equivalent to the original
    # ascending < 60 / < 80 chain.
    if perplexity >= 80:
        verdict = "Human-written*"
    elif perplexity >= 60:
        verdict = "Probably AI-generated*"
    else:
        verdict = "AI-generated*"

    return verdict, perplexity


async def classify_text(sentence: str):
    """Run the blocking classifier in the thread pool, off the event loop.

    Returns:
        tuple: ``(label, perplexity)`` from ``classify_text_sync``.
    """
    # get_running_loop() is the supported call inside a coroutine;
    # asyncio.get_event_loop() here is deprecated since Python 3.10.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, classify_text_sync, sentence)


@app.post("/analyze")
async def analyze_text(data: TextInput, authorization: str = Header(default="")):
    """Classify the submitted text as AI-generated or human-written.

    Requires a ``Bearer`` token in the Authorization header; rejects
    bodies whose text is empty after stripping whitespace.
    """
    verify_token(authorization)

    text = data.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    label, perplexity = await classify_text(text)
    return {"result": label, "perplexity": round(perplexity, 2)}


@app.get("/health")
async def health_check():
    """Liveness probe: always reports the service as up."""
    return dict(status="ok")


@app.get("/")
def index():
    """Root endpoint returning a short identification message."""
    payload = {"message": "It's an API"}
    return payload


# Start the app directly (production deployments typically invoke uvicorn
# themselves).
if __name__ == "__main__":
    import uvicorn

    # The app is passed as an import string ("main:app") because uvicorn
    # needs to re-import it in each worker process when workers > 1 —
    # note each worker therefore loads its own copy of the model.
    uvicorn.run("main:app", host="0.0.0.0", port=8000, workers=4)