# NOTE: Hugging Face Spaces status banner ("Sleeping") captured when this file
# was scraped — not part of the source code.
import asyncio
import hmac
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager

import torch
from dotenv import dotenv_values
from fastapi import FastAPI, Header, HTTPException
from pydantic import BaseModel
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
# Thread pool used to run blocking model inference off the event loop.
executor = ThreadPoolExecutor(max_workers=20)

# Load configuration from .env; SECRET_TOKEN guards the analyze endpoint.
env = dotenv_values(".env")
EXPECTED_TOKEN = env.get("SECRET_TOKEN")

# Model and tokenizer are populated once at startup by the lifespan handler.
# NOTE: the original `app = FastAPI()` here was dead code — it was
# unconditionally shadowed by `app = FastAPI(lifespan=lifespan)` below.
model, tokenizer = None, None
# Function to verify token
def verify_token(auth: str):
    """Validate the Authorization header against the configured bearer token.

    Raises:
        HTTPException: 403 when the header does not match, or when no
        SECRET_TOKEN is configured at all.
    """
    # If SECRET_TOKEN is missing from .env, EXPECTED_TOKEN is None and the
    # original check would accept the literal header "Bearer None" — refuse
    # all requests instead of falling back to a guessable token.
    if not EXPECTED_TOKEN:
        raise HTTPException(status_code=403, detail="Unauthorized")
    # Constant-time comparison avoids leaking token bytes via timing.
    if not hmac.compare_digest(auth, f"Bearer {EXPECTED_TOKEN}"):
        raise HTTPException(status_code=403, detail="Unauthorized")
# Function to load model and tokenizer
def load_model():
    """Load the detector tokenizer and the fine-tuned GPT-2 model.

    Returns:
        (model, tokenizer): the eval-mode GPT2LMHeadModel and its tokenizer.
    """
    model_path = "./Ai-Text-Detector/model"
    weights_path = "./Ai-Text-Detector/model_weights.pth"
    tokenizer = GPT2TokenizerFast.from_pretrained(model_path)
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    # weights_only=True restricts torch.load to tensor data, preventing
    # arbitrary-code execution from a tampered pickle checkpoint.
    state = torch.load(
        weights_path, map_location=torch.device("cpu"), weights_only=True
    )
    model.load_state_dict(state)
    model.eval()  # inference mode: disables dropout
    return model, tokenizer
# FastAPI's `lifespan=` parameter requires an async context-manager factory;
# the @asynccontextmanager decorator (already imported, but never applied in
# the original) is what turns this async generator into one.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup hook: load the model and tokenizer once before serving."""
    global model, tokenizer
    model, tokenizer = load_model()
    yield


# Attach the lifespan context manager
app = FastAPI(lifespan=lifespan)
# Request body for input data
class TextInput(BaseModel):
    # Raw text to classify; the route strips surrounding whitespace and
    # rejects empty input with a 400.
    text: str
# Sync function to classify text
def classify_text_sync(sentence: str):
    """Score `sentence` with the language model and label it by perplexity.

    Returns:
        (label, perplexity): classification string and the model perplexity.
        Lower perplexity under the detector model is treated as evidence
        of AI generation.
    """
    encoded = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
    ids = encoded["input_ids"]
    mask = encoded["attention_mask"]
    with torch.no_grad():
        output = model(ids, attention_mask=mask, labels=ids)
    perplexity = torch.exp(output.loss).item()
    # Threshold ladder: the first band the perplexity falls under wins.
    for limit, label in ((60, "AI-generated*"), (80, "Probably AI-generated*")):
        if perplexity < limit:
            return label, perplexity
    return "Human-written*", perplexity
# Async wrapper for text classification
async def classify_text(sentence: str):
    """Run the blocking classifier in the worker pool without blocking the loop.

    Returns:
        The (label, perplexity) tuple from classify_text_sync.
    """
    # get_running_loop() is the supported call inside a coroutine;
    # get_event_loop() is deprecated in this context since Python 3.10.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, classify_text_sync, sentence)
# POST route to analyze text
# NOTE(review): the route decorator was missing, leaving this handler
# unregistered; path "/analyze" is inferred from the name — confirm against
# actual clients.
@app.post("/analyze")
async def analyze_text(data: TextInput, authorization: str = Header(default="")):
    """Classify the submitted text; requires a valid Bearer token.

    Returns:
        dict with the classification `result` and rounded `perplexity`.

    Raises:
        HTTPException: 403 on a bad token, 400 on empty text.
    """
    verify_token(authorization)  # Token verification
    user_input = data.text.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    result, perplexity = await classify_text(user_input)
    return {
        "result": result,
        "perplexity": round(perplexity, 2),
    }
# Health check route
# NOTE(review): decorator was missing in the original — confirm the path.
@app.get("/health")
async def health_check():
    """Liveness probe; returns a static OK payload."""
    return {"status": "ok"}
# Simple index route
# NOTE(review): decorator was missing in the original — confirm the path.
@app.get("/")
def index():
    """Root endpoint returning a short identification message."""
    return {"message": "It's an API"}
# Start the app (run with uvicorn)
if __name__ == "__main__":
    import uvicorn

    # "main:app" assumes this file is saved as main.py — TODO confirm.
    # workers=4 forks separate processes; each one loads its own model copy
    # via the lifespan handler.
    uvicorn.run("main:app", host="0.0.0.0", port=8000, workers=4)