# test-space / app.py
# Author: Pujan-Dev
# Last commit: da25d43 "Update app.py" (verified)
import asyncio
import math
from contextlib import asynccontextmanager

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Config
# Module-level inference state. Both names stay None until the lifespan
# handler below calls load_model() at startup. The FastAPI ``app`` itself is
# created once, after that handler is defined, so routes bind to the app that
# actually has the lifespan attached (a second, unused FastAPI() instance here
# would be immediately discarded).
model, tokenizer = None, None
# Function to load model and tokenizer
def load_model():
model_path = "./Ai-Text-Detector/model"
weights_path = "./Ai-Text-Detector/model_weights.pth"
try:
tokenizer = GPT2TokenizerFast.from_pretrained(model_path)
config = GPT2Config.from_pretrained(model_path)
model = GPT2LMHeadModel(config)
model.load_state_dict(torch.load(weights_path, map_location=torch.device("cpu")))
model.eval() # Set model to evaluation mode
except Exception as e:
raise RuntimeError(f"Error loading model: {str(e)}")
return model, tokenizer
# Load the model once at startup; drop it again at shutdown.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan handler that owns the module-level model/tokenizer.

    Code before ``yield`` runs before the app starts serving requests, so a
    failure in ``load_model`` aborts startup. The ``finally`` clause runs on
    shutdown, even if shutdown is triggered by an error. It clears the globals
    so the model weights are no longer referenced.
    """
    global model, tokenizer
    model, tokenizer = load_model()
    try:
        yield
    finally:
        model, tokenizer = None, None


# The single application instance, with the lifespan handler attached.
app = FastAPI(lifespan=lifespan)
# Request body schema for POST /analyze.
class TextInput(BaseModel):
    """JSON payload carrying the passage to classify."""

    # Raw client text; the /analyze handler strips surrounding whitespace.
    text: str
# Blocking perplexity-based classifier (run it off the event loop).
def classify_text(
    sentence: str,
    *,
    ai_threshold: float = 60,
    probable_ai_threshold: float = 80,
):
    """Score ``sentence`` with the loaded GPT-2 LM and label it by perplexity.

    This runs synchronously. Async callers should offload it, e.g. with
    ``asyncio.to_thread``.

    Args:
        sentence: Text to score. Input longer than the tokenizer's maximum
            length is truncated.
        ai_threshold: Perplexity strictly below this gives "AI-generated".
        probable_ai_threshold: Perplexity below this (but not below
            ``ai_threshold``) gives "Probably AI-generated". Anything else is
            "Human-written".

    Returns:
        ``(label, perplexity)``. ``perplexity`` is ``exp`` of the loss the
        model returns when scored against its own input ids.

    Raises:
        RuntimeError: if the module-level model/tokenizer are not loaded.
        ValueError: if the text tokenizes to fewer than two tokens. The LM
            loss then has no next-token targets and would come out as NaN.
    """
    if model is None or tokenizer is None:
        raise RuntimeError("Model is not loaded")

    # A single sequence needs no padding. Requesting it can raise on GPT-2
    # tokenizers that define no pad token, and any padded positions would be
    # scored as real labels below.
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"]
    if input_ids.shape[-1] < 2:
        raise ValueError("Text is too short to analyze; need at least two tokens.")

    with torch.no_grad():
        outputs = model(
            input_ids,
            attention_mask=inputs["attention_mask"],
            labels=input_ids,
        )
    perplexity = torch.exp(outputs.loss).item()

    # Lower perplexity = text the model finds more predictable.
    if perplexity < ai_threshold:
        label = "AI-generated"
    elif perplexity < probable_ai_threshold:
        label = "Probably AI-generated"
    else:
        label = "Human-written"
    return label, perplexity
# POST /analyze: classify submitted text by GPT-2 perplexity.
@app.post("/analyze")
async def analyze_text(data: TextInput):
    """Classify ``data.text`` and return the label with its perplexity.

    Responses:
        200: ``{"result": <label>, "perplexity": <float rounded to 2 places>}``.
        400: the text is empty or whitespace-only, or it cannot be scored.
            Unscorable means the classifier raised ``ValueError`` (e.g. the
            text is too short), or the perplexity is not finite. NaN/inf cannot
            be encoded as JSON and would otherwise surface as a 500.
    """
    user_input = data.text.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    # Inference is blocking; run it in a worker thread so the event loop is
    # not held up while the model runs.
    try:
        result, perplexity = await asyncio.to_thread(classify_text, user_input)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    if not math.isfinite(perplexity):
        raise HTTPException(status_code=400, detail="Text is too short to analyze")

    return {
        "result": result,
        "perplexity": round(perplexity, 2),
    }
# GET /health: constant response; does not touch the model.
@app.get("/health")
async def health_check():
    """Report that the service is responding."""
    reply = {"status": "ok"}
    return reply
# GET /: static landing payload pointing users at the interactive docs.
@app.get("/")
def index():
    """Return a fixed greeting with a hint to open ``/docs``."""
    return dict(
        [
            ("message", "FastAPI API is up."),
            ("try", "/docs to test the API."),
            ("status", "OK"),
        ]
    )