"""AI-API — app.py

AI-text detection API serving a fine-tuned GPT-2 perplexity classifier.

Author: Pujan Neupane
Revision: final (9216814), 3.07 kB
"""
import asyncio
import hmac
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager

import torch
from dotenv import dotenv_values
from fastapi import FastAPI, HTTPException, Header
from pydantic import BaseModel
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
# NOTE: the FastAPI instance is created further down with the lifespan handler
# attached; the redundant bare FastAPI() that used to live here (immediately
# overwritten, never used) has been removed.

# Thread pool for running blocking model inference off the event loop.
executor = ThreadPoolExecutor(max_workers=20)

# Load secrets from the local .env file (expects SECRET_TOKEN=<bearer token>).
env = dotenv_values(".env")
EXPECTED_TOKEN = env.get("SECRET_TOKEN")

# Model and tokenizer are populated once at startup by the lifespan handler.
model, tokenizer = None, None
# Function to verify token
def verify_token(auth: str):
    """Validate the Authorization header against the configured bearer token.

    Args:
        auth: Raw Authorization header value (e.g. "Bearer <token>").

    Raises:
        HTTPException: 403 when the header does not match
            ``Bearer <SECRET_TOKEN>`` or when no SECRET_TOKEN is configured.
    """
    # Reject outright when no token is configured: otherwise a missing .env
    # would make the literal string "Bearer None" a valid credential.
    if not EXPECTED_TOKEN:
        raise HTTPException(status_code=403, detail="Unauthorized")
    expected = f"Bearer {EXPECTED_TOKEN}"
    # Constant-time comparison so the check does not leak the secret through
    # response-timing differences.
    if not hmac.compare_digest(auth.encode(), expected.encode()):
        raise HTTPException(status_code=403, detail="Unauthorized")
# Function to load model and tokenizer
def load_model():
    """Assemble the GPT-2 detector from local artifacts.

    The tokenizer comes from the local model directory; the architecture is
    the stock "gpt2" with fine-tuned weights loaded from disk onto CPU.

    Returns:
        tuple: (model, tokenizer) with the model switched to eval mode.
    """
    model_path = "./Ai-Text-Detector/model"
    weights_path = "./Ai-Text-Detector/model_weights.pth"
    gpt2_tokenizer = GPT2TokenizerFast.from_pretrained(model_path)
    detector = GPT2LMHeadModel.from_pretrained("gpt2")
    state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
    detector.load_state_dict(state_dict)
    detector.eval()  # inference mode: disables dropout etc.
    return detector, gpt2_tokenizer
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the detector exactly once at startup; no shutdown cleanup needed."""
    global model, tokenizer
    loaded_model, loaded_tokenizer = load_model()
    model, tokenizer = loaded_model, loaded_tokenizer
    # Application serves requests while suspended here.
    yield
# Create the FastAPI instance with the lifespan handler attached so the model
# is loaded once at startup (this rebinding supersedes any earlier instance).
app = FastAPI(lifespan=lifespan)
# Request body for input data
class TextInput(BaseModel):
    """Request payload for the /analyze endpoint."""

    # Raw text to classify; the endpoint rejects it if empty after stripping.
    text: str
# Sync function to classify text
def classify_text_sync(sentence: str):
    """Score *sentence* with the GPT-2 detector and bucket it by perplexity.

    Args:
        sentence: Text to classify (truncated/padded by the tokenizer).

    Returns:
        tuple: (label, perplexity) — lower perplexity suggests AI text.
    """
    encoded = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
    token_ids = encoded["input_ids"]
    mask = encoded["attention_mask"]
    # The LM loss of the sentence against itself yields its perplexity.
    with torch.no_grad():
        output = model(token_ids, attention_mask=mask, labels=token_ids)
    perplexity = torch.exp(output.loss).item()
    # Band table: first threshold the perplexity falls under wins;
    # anything >= 80 is considered human-written.
    bands = ((60, "AI-generated*"), (80, "Probably AI-generated*"))
    label = next(
        (name for limit, name in bands if perplexity < limit),
        "Human-written*",
    )
    return label, perplexity
# Async wrapper for text classification
async def classify_text(sentence: str):
    """Run the blocking classifier on the thread pool, keeping the loop free.

    Args:
        sentence: Text to classify.

    Returns:
        tuple: (label, perplexity) from classify_text_sync.
    """
    # get_running_loop() is the correct call from inside a coroutine;
    # get_event_loop() is deprecated in this context (Python 3.10+).
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, classify_text_sync, sentence)
# POST route to analyze text
@app.post("/analyze")
async def analyze_text(data: TextInput, authorization: str = Header(default="")):
    """Classify submitted text as AI- or human-written.

    Requires a valid bearer token in the Authorization header; rejects
    empty input with a 400.
    """
    verify_token(authorization)  # 403 on bad/missing token
    text = data.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    label, perplexity = await classify_text(text)
    return {
        "result": label,
        "perplexity": round(perplexity, 2),
    }
# Health check route
@app.get("/health")
async def health_check():
    """Liveness probe: always reports the service as up."""
    return dict(status="ok")
# Simple index route
@app.get("/")
def index():
    """Root endpoint: minimal identification message."""
    message = "It's an API"
    return {"message": message}
# Start the app (run with uvicorn)
if __name__ == "__main__":
    import uvicorn

    # The import string must name this module: the file is app.py, so it is
    # "app:app" (the previous "main:app" would fail with "Could not import
    # module main"). workers > 1 requires the import-string form rather than
    # passing the app object directly.
    uvicorn.run("app:app", host="0.0.0.0", port=8000, workers=4)