"""FastAPI service that labels text as AI-generated or human-written.

The label is derived from the perplexity a GPT-2 language model assigns to
the submitted text (or to the text extracted from an uploaded file).
"""

import asyncio
import logging
import os
import secrets
from contextlib import asynccontextmanager
from io import BytesIO

import docx
import fitz  # PyMuPDF
import torch
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, File, HTTPException, UploadFile
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel
from transformers import GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast

# Load environment variables
load_dotenv()

SECRET_TOKEN = os.getenv("SECRET_TOKEN")
bearer_scheme = HTTPBearer()

# Ai-Text-Detector
MODEL_PATH = "./Ai-Text-Detector/model"
WEIGHTS_PATH = "./Ai-Text-Detector/model_weights.pth"

# Classification settings used by classify_text / upload_file
MAX_TOKENS = 512
AI_PERPLEXITY_THRESHOLD = 60
PROBABLY_AI_PERPLEXITY_THRESHOLD = 80
MAX_UPLOAD_CHARS = 10000

# Global model and tokenizer variables (populated by load_model at startup)
model, tokenizer = None, None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Logging setup
logging.basicConfig(level=logging.DEBUG)


def load_model():
    """Load the tokenizer and GPT-2 weights into the module-level globals.

    Everything is loaded into locals first and only published to the globals
    once every step has succeeded, so a failure never leaves a tokenizer
    without a matching model.

    Raises:
        RuntimeError: if the tokenizer, config or weights cannot be loaded.
    """
    global model, tokenizer
    try:
        loaded_tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_PATH)
        config = GPT2Config.from_pretrained(MODEL_PATH)
        model_instance = GPT2LMHeadModel(config)
        # NOTE(review): torch.load unpickles the file; WEIGHTS_PATH must only
        # ever point at a trusted checkpoint.
        model_instance.load_state_dict(torch.load(WEIGHTS_PATH, map_location=device))
        model_instance.to(device)
        model_instance.eval()
    except Exception as e:
        logging.error("Error loading model: %s", e)
        raise RuntimeError(f"Error loading model: {str(e)}") from e

    model, tokenizer = model_instance, loaded_tokenizer
    logging.info("Model loaded successfully.")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model once when the FastAPI app starts."""
    load_model()
    yield


# FastAPI app instance, with the lifespan attached so the model loads on startup
app = FastAPI(lifespan=lifespan)


class TextInput(BaseModel):
    """Input schema for text analysis."""

    text: str


def classify_text(text: str) -> tuple[str, float]:
    """Classify *text* from its perplexity under the loaded language model.

    Args:
        text: The text to score; it is truncated to MAX_TOKENS tokens.

    Returns:
        A ``(label, perplexity)`` pair where label is "AI-generated",
        "Probably AI-generated" or "Human-written".

    Raises:
        RuntimeError: if the model or tokenizer has not been loaded.
    """
    if model is None or tokenizer is None:
        raise RuntimeError("Model or tokenizer not loaded.")

    # A single sequence needs no padding, so none is requested.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=MAX_TOKENS)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    with torch.no_grad():
        # Passing the inputs as labels makes the model return its LM loss.
        outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)

    perplexity = torch.exp(outputs.loss).item()

    if perplexity < AI_PERPLEXITY_THRESHOLD:
        return "AI-generated", perplexity
    if perplexity < PROBABLY_AI_PERPLEXITY_THRESHOLD:
        return "Probably AI-generated", perplexity
    return "Human-written", perplexity


def _verify_token(token: HTTPAuthorizationCredentials) -> None:
    """Raise a 401 unless the bearer credentials match SECRET_TOKEN."""
    # An unset SECRET_TOKEN rejects everything; compare_digest keeps the
    # comparison time independent of where the strings first differ.
    if not SECRET_TOKEN or not secrets.compare_digest(
        token.credentials.encode("utf-8"), SECRET_TOKEN.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid token")


def _require_analyzable(text: str) -> None:
    """Raise a 400 when *text* is empty or has fewer than two words."""
    if not text:
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    if len(text.split()) < 2:
        raise HTTPException(status_code=400, detail="Text must contain at least two words")


@app.post("/analyze")
async def analyze_text(data: TextInput, token: HTTPAuthorizationCredentials = Depends(bearer_scheme)):
    """Analyze the posted text; requires a valid Bearer token.

    Returns a dict with the label under "result" and the rounded perplexity.

    Raises:
        HTTPException: 401 for a bad token, 400 for empty or single-word text,
            500 when the model fails.
    """
    _verify_token(token)

    text = data.text.strip()
    _require_analyzable(text)

    try:
        # Run inference in a worker thread so the event loop stays responsive.
        label, perplexity = await asyncio.to_thread(classify_text, text)
    except Exception as e:
        logging.exception("Error processing text: %s", e)
        raise HTTPException(status_code=500, detail="Model processing error") from e

    return {"result": label, "perplexity": round(perplexity, 2)}


def parse_docx(file: BytesIO):
    """Return the text of a .docx file, one line per paragraph."""
    doc = docx.Document(file)
    return "".join(f"{para.text}\n" for para in doc.paragraphs)


def parse_pdf(file: BytesIO):
    """Return the concatenated text of every page of a .pdf file.

    Raises:
        HTTPException: 500 if the PDF cannot be opened or read.
    """
    try:
        # The context manager closes the document even if extraction fails.
        with fitz.open(stream=file, filetype="pdf") as doc:
            return "".join(page.get_text() for page in doc)
    except Exception as e:
        logging.error("Error while processing PDF: %s", e)
        raise HTTPException(status_code=500, detail="Error processing PDF file") from e


def parse_txt(file: BytesIO):
    """Return the contents of a .txt file decoded as UTF-8."""
    return file.read().decode("utf-8")


# Maps each accepted upload content type to the function that extracts its text
_PARSERS = {
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document": parse_docx,
    "application/pdf": parse_pdf,
    "text/plain": parse_txt,
}


@app.post("/upload/")
async def upload_file(
    file: UploadFile = File(...),
    token: HTTPAuthorizationCredentials = Depends(bearer_scheme),
):
    """Extract text from an uploaded .docx/.pdf/.txt file and analyze it.

    Requires a valid Bearer token. Returns a dict with the label under
    "result" and the rounded perplexity, or a "message" entry when the
    extracted text exceeds MAX_UPLOAD_CHARS characters.

    Raises:
        HTTPException: 401 for a bad token, 400 for an unsupported file type
            or text that is empty or too short, 500 for any other failure.
    """
    _verify_token(token)

    parser = _PARSERS.get(file.content_type)
    if parser is None:
        raise HTTPException(
            status_code=400,
            detail="Invalid file type. Only .docx, .pdf, and .txt are allowed.",
        )

    try:
        file_contents = parser(BytesIO(await file.read()))
        # Log only the size: the document body may be large or sensitive.
        logging.debug("Extracted %d characters from %s", len(file_contents), file.filename)

        # Check if the text length exceeds 10,000 characters
        if len(file_contents) > MAX_UPLOAD_CHARS:
            return {"message": "File contains more than 10,000 characters."}

        # Collapse line breaks and tabs to single spaces; removing them
        # outright would fuse the words on either side of each break.
        cleaned_text = " ".join(file_contents.split())
        _require_analyzable(cleaned_text)

        label, perplexity = await asyncio.to_thread(classify_text, cleaned_text)
        return {"result": label, "perplexity": round(perplexity, 2)}
    except HTTPException:
        # Deliberate HTTP errors (400s, the PDF 500) keep their own status.
        raise
    except Exception as e:
        logging.exception("Error processing file: %s", e)
        raise HTTPException(status_code=500, detail="Error processing the file") from e


@app.get("/health")
async def health_check():
    """Health check route."""
    return {"status": "ok"}


@app.get("/")
def index():
    """Simple index route describing the service."""
    return {
        "message": "FastAPI AI Text Detector is running.",
        "usage": "Use /docs or /analyze to test the API.",
    }