Spaces:
Sleeping
Sleeping
import asyncio
import math
from contextlib import asynccontextmanager

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Config
# The FastAPI application object is created once, further down, together with
# its lifespan handler; building a second throwaway instance here would only
# be overwritten by that later assignment.

# Global model and tokenizer variables.
# Both stay None until the lifespan handler assigns the loaded objects.
model, tokenizer = None, None
# Function to load model and tokenizer
def load_model(
    model_path="./Ai-Text-Detector/model",
    weights_path="./Ai-Text-Detector/model_weights.pth",
):
    """Build the GPT-2 language model and its tokenizer from local files.

    The tokenizer and model configuration are read from ``model_path``; the
    model is instantiated from that configuration and its parameters are then
    filled in from the state dict stored at ``weights_path``.

    Args:
        model_path: Directory passed to ``from_pretrained`` for both the
            tokenizer and the config.
        weights_path: File passed to ``torch.load`` to obtain the state dict.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is on CPU and in eval mode.

    Raises:
        RuntimeError: If any loading step fails. The underlying exception is
            attached as ``__cause__``.
    """
    try:
        tokenizer = GPT2TokenizerFast.from_pretrained(model_path)
        config = GPT2Config.from_pretrained(model_path)
        model = GPT2LMHeadModel(config)
        # NOTE(review): torch.load unpickles the file, so the weights file must
        # come from a trusted source; consider weights_only=True if the
        # installed torch version supports it.
        state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
        model.load_state_dict(state_dict)
        model.eval()  # Set model to evaluation mode
    except Exception as e:
        # Chain explicitly so the root cause stays attached to the wrapper.
        raise RuntimeError(f"Error loading model: {e}") from e
    return model, tokenizer
# Load model on app startup
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Populate the module-level ``model`` and ``tokenizer`` before serving.

    Everything before ``yield`` runs once at application startup. Nothing
    follows the ``yield``, so no work is done at shutdown.

    Args:
        app: The application instance the handler is attached to (unused).
    """
    global model, tokenizer
    model, tokenizer = load_model()
    yield


# Attach startup loader
app = FastAPI(lifespan=lifespan)
# Input schema
# Request body consumed by analyze_text: one field carrying the text to score.
class TextInput(BaseModel):
    # Raw text; analyze_text strips it and rejects it when nothing remains.
    text: str
# Sync text classification
def classify_text(sentence: str):
    """Label *sentence* by the perplexity the loaded language model assigns it.

    Uses the module-level ``model`` and ``tokenizer``. The sentence is
    tokenized (with truncation), the model is run with the token ids as both
    input and labels, and the exponential of the returned loss is taken as
    the perplexity.

    Returns:
        A ``(label, perplexity)`` tuple. The label is ``"AI-generated"`` below
        60, ``"Probably AI-generated"`` below 80, and ``"Human-written"``
        otherwise. Any value that fails both comparisons, NaN included, gets
        the last label.
    """
    encoded = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
    token_ids = encoded["input_ids"]
    mask = encoded["attention_mask"]

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        lm_loss = model(token_ids, attention_mask=mask, labels=token_ids).loss
        perplexity = torch.exp(lm_loss).item()

    # Lower perplexity means the model found the text more predictable.
    if perplexity < 60:
        return "AI-generated", perplexity
    if perplexity < 80:
        return "Probably AI-generated", perplexity
    return "Human-written", perplexity
# POST route to analyze text
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app` as a POST endpoint.
async def analyze_text(data: TextInput):
    # Reject blank submissions before doing any model work.
    user_input = data.text.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    # Run classification in a worker thread so inference does not block the event loop
    result, perplexity = await asyncio.to_thread(classify_text, user_input)

    # A NaN or infinite perplexity is not a meaningful score and cannot be
    # represented in a JSON response body, so report it as a client error
    # instead of returning a misleading label.
    if not math.isfinite(perplexity):
        raise HTTPException(
            status_code=400,
            detail="Perplexity could not be computed for this text; try a longer input",
        )

    return {
        "result": result,
        "perplexity": round(perplexity, 2),
    }
# Health check route
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app`.
async def health_check():
    # Fixed payload: nothing is probed here, so a reply only shows that the
    # application is answering requests.
    status_payload = {"status": "ok"}
    return status_payload
# Simple index route
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app`.
def index():
    # Static landing payload pointing callers at the interactive docs page.
    landing = {}
    landing["message"] = "FastAPI API is up."
    landing["try"] = "/docs to test the API."
    landing["status"] = "OK"
    return landing