Spaces:
Sleeping
Sleeping
File size: 2,530 Bytes
5792ae0 f588c6c 8adf4f2 5792ae0 da25d43 5792ae0 8adf4f2 5792ae0 da25d43 5792ae0 d60213d 2cca8fd da25d43 2cca8fd 5792ae0 da25d43 5792ae0 8adf4f2 5792ae0 8adf4f2 5792ae0 8adf4f2 5792ae0 da25d43 5792ae0 da25d43 5792ae0 da25d43 5792ae0 8adf4f2 5792ae0 da25d43 5792ae0 6386bd6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import asyncio
import math
from contextlib import asynccontextmanager

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Config
# The FastAPI app instance is created further down, after `lifespan` is
# defined, so that the model is loaded at application startup. (A throwaway
# `FastAPI()` instance used to be created here and was immediately replaced.)

# Global model and tokenizer. They are None until `lifespan` calls
# `load_model()` at startup.
model, tokenizer = None, None
# Function to load model and tokenizer
def load_model(
    model_path: str = "./Ai-Text-Detector/model",
    weights_path: str = "./Ai-Text-Detector/model_weights.pth",
):
    """Load the GPT-2 tokenizer and language model used for scoring.

    The tokenizer and model config are read from ``model_path``. The model is
    built from that config and then populated with the state dict stored at
    ``weights_path``, mapped onto the CPU.

    Args:
        model_path: Directory passed to ``from_pretrained`` for both the
            tokenizer and the config.
        weights_path: File holding the model's ``state_dict``.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is in evaluation mode.

    Raises:
        RuntimeError: If any loading step fails. The original exception is
            chained as ``__cause__`` so its traceback is preserved.
    """
    try:
        tokenizer = GPT2TokenizerFast.from_pretrained(model_path)
        config = GPT2Config.from_pretrained(model_path)
        model = GPT2LMHeadModel(config)
        # NOTE(review): torch.load unpickles the file, so only point this at a
        # trusted checkpoint. Consider weights_only=True if the installed torch
        # version supports it.
        state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
        model.load_state_dict(state_dict)
        model.eval()  # Set model to evaluation mode (disables dropout)
    except Exception as err:
        raise RuntimeError(f"Error loading model: {err}") from err
    return model, tokenizer
# Load model on app startup
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Populate the module-level model/tokenizer before the app serves.

    Everything before ``yield`` runs once at startup. Nothing is done at
    shutdown.
    """
    global model, tokenizer
    loaded = load_model()
    model, tokenizer = loaded
    yield


# Attach startup loader: the app is built only once `lifespan` exists.
app = FastAPI(
    lifespan=lifespan,
)
# Input schema
class TextInput(BaseModel):
    """Request body accepted by ``POST /analyze``."""

    # Raw text to classify. The endpoint strips surrounding whitespace itself.
    text: str
# Sync text classification
def classify_text(
    sentence: str,
    ai_threshold: float = 60,
    probable_threshold: float = 80,
):
    """Score *sentence* with the language model and label it by perplexity.

    Lower perplexity means the model found the text more predictable. This
    detector treats that as a sign of machine generation.

    Args:
        sentence: Text to score. It is tokenized with truncation enabled, so
            overly long inputs are cut to the tokenizer's maximum length.
        ai_threshold: Perplexities strictly below this are labelled
            "AI-generated".
        probable_threshold: Perplexities in ``[ai_threshold,
            probable_threshold)`` are labelled "Probably AI-generated".
            Anything else is "Human-written".

    Returns:
        ``(label, perplexity)``. Perplexity is ``exp`` of the model's loss,
        as a Python float. NOTE(review): for text that tokenizes to a single
        token there is no next-token target, so the loss is presumably NaN.
        NaN fails every comparison and is labelled "Human-written".
    """
    # A single sequence never needs padding. Requesting it anyway raises a
    # ValueError when the tokenizer has no pad token (GPT-2's default).
    encoded = tokenizer(sentence, return_tensors="pt", truncation=True)
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]
    with torch.no_grad():
        # Using the inputs as labels makes the model report its own
        # language-modelling loss on the text.
        outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
    perplexity = torch.exp(outputs.loss).item()
    if perplexity < ai_threshold:
        result = "AI-generated"
    elif perplexity < probable_threshold:
        result = "Probably AI-generated"
    else:
        result = "Human-written"
    return result, perplexity
# POST route to analyze text
@app.post("/analyze")
async def analyze_text(data: TextInput):
    """Classify the submitted text as AI-generated or human-written.

    Returns:
        ``{"result": <label>, "perplexity": <float rounded to 2 places>}``.

    Raises:
        HTTPException: 400 if the text is empty or whitespace-only, or if no
            finite perplexity could be computed for it.
    """
    user_input = data.text.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    # Run classification in a worker thread so the synchronous model call
    # does not block the event loop.
    result, perplexity = await asyncio.to_thread(classify_text, user_input)
    # Very short input (e.g. a single token) gives the model nothing to
    # predict, so the perplexity comes back NaN. Starlette's JSONResponse
    # rejects NaN/inf, which would otherwise turn into a 500 error.
    if not math.isfinite(perplexity):
        raise HTTPException(
            status_code=400,
            detail="Could not compute perplexity for this text; try a longer input",
        )
    return {
        "result": result,
        "perplexity": round(perplexity, 2),
    }
# Health check route
@app.get("/health")
async def health_check():
    """Liveness probe that always reports ``{"status": "ok"}``.

    It does not inspect the model or tokenizer.
    """
    payload = {"status": "ok"}
    return payload
# Simple index route
@app.get("/")
def index():
    """Return a small banner that points visitors at the interactive docs."""
    banner = {}
    banner["message"] = "FastAPI API is up."
    banner["try"] = "/docs to test the API."
    banner["status"] = "OK"
    return banner
|