Pujan-Dev committed on
Commit
1548e30
·
verified ·
1 Parent(s): 99d3d8f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -60
app.py CHANGED
@@ -1,55 +1,80 @@
1
- from fastapi import FastAPI, HTTPException, Depends
2
  from fastapi.security import HTTPBearer
3
  from pydantic import BaseModel
4
  from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Config
5
- import torch
6
- import asyncio
7
  from contextlib import asynccontextmanager
8
 
9
- # FastAPI app instance
10
- app = FastAPI()
11
-
12
- # Global model and tokenizer variables
13
- model, tokenizer = None, None
14
-
15
- # HTTPBearer instance for security
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  bearer_scheme = HTTPBearer()
17
 
18
- # Function to load model and tokenizer
19
- def load_model():
20
- model_path = "./Ai-Text-Detector/model"
21
- weights_path = "./Ai-Text-Detector/model_weights.pth"
22
 
 
 
 
23
  try:
24
- tokenizer = GPT2TokenizerFast.from_pretrained(model_path)
25
- config = GPT2Config.from_pretrained(model_path)
26
- model = GPT2LMHeadModel(config)
27
- model.load_state_dict(torch.load(weights_path, map_location=torch.device("cpu")))
28
- model.eval() # Set model to evaluation mode
 
 
29
  except Exception as e:
30
  raise RuntimeError(f"Error loading model: {str(e)}")
31
 
32
- return model, tokenizer
33
-
34
- # Load model on app startup
35
  @asynccontextmanager
36
  async def lifespan(app: FastAPI):
37
- global model, tokenizer
38
- model, tokenizer = load_model()
39
  yield
40
 
41
- # Attach startup loader
42
  app = FastAPI(lifespan=lifespan)
43
 
44
- # Input schema
45
- class TextInput(BaseModel):
46
- text: str
 
47
 
48
- # Sync text classification
49
- def classify_text(sentence: str):
50
- inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
51
- input_ids = inputs["input_ids"]
52
- attention_mask = inputs["attention_mask"]
 
 
 
 
 
53
 
54
  with torch.no_grad():
55
  outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
@@ -57,49 +82,88 @@ def classify_text(sentence: str):
57
  perplexity = torch.exp(loss).item()
58
 
59
  if perplexity < 60:
60
- result = "AI-generated"
61
  elif perplexity < 80:
62
- result = "Probably AI-generated"
63
  else:
64
- result = "Human-written"
65
 
66
- return result, perplexity
 
 
67
 
68
- # POST route to analyze text with Bearer token
69
  @app.post("/analyze")
70
  async def analyze_text(data: TextInput, token: str = Depends(bearer_scheme)):
71
- user_input = data.text.strip()
 
72
 
73
- if not user_input:
 
74
  raise HTTPException(status_code=400, detail="Text cannot be empty")
75
 
76
- # Check if there are at least two words
77
- word_count = len(user_input.split())
78
- if word_count < 2:
79
  raise HTTPException(status_code=400, detail="Text must contain at least two words")
80
-
81
- # The token is automatically extracted from the Authorization header
82
- # You can validate the token here if needed
83
- print(f"Received Bearer Token: {token}")
84
 
85
- # Run classification asynchronously to prevent blocking
86
- result, perplexity = await asyncio.to_thread(classify_text, user_input)
87
-
88
- return {
89
- "result": result,
90
- "perplexity": round(perplexity, 2),
91
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
- # Health check route
94
  @app.get("/health")
95
- async def health_check():
96
  return {"status": "ok"}
97
 
98
- # Simple index route
99
  @app.get("/")
100
  def index():
101
  return {
102
- "message": "FastAPI API is up.",
103
- "try": "/docs to test the API.",
104
- "status": "OK"
105
  }
 
1
# Standard library
import asyncio
import logging
import math
import os
import secrets
from contextlib import asynccontextmanager
from io import BytesIO

# Third-party
import docx
import fitz  # PyMuPDF
import torch
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Depends, UploadFile, File
from fastapi.security import HTTPBearer
from pydantic import BaseModel
from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Config
16
+
17
+ # Setup logging
18
+ logging.basicConfig(level=logging.DEBUG)
19
+
20
+ # Load environment variables
21
+ load_dotenv()
22
+ SECRET_TOKEN = os.getenv("SECRET_TOKEN")
23
+
24
+ # File Paths
25
+ MODEL_PATH = "./AI-MODEL/model"
26
+ WEIGHTS_PATH = "./AI-MODEL/model_weights.pth"
27
+
28
+ # Global model and tokenizer
29
+ model = None
30
+ tokenizer = None
31
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+
33
+ # Security
34
  bearer_scheme = HTTPBearer()
35
 
36
+ # Text input schema
37
+ class TextInput(BaseModel):
38
+ text: str
 
39
 
40
+ # Load model and tokenizer
41
+ def load_model():
42
+ global model, tokenizer
43
  try:
44
+ tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_PATH)
45
+ config = GPT2Config.from_pretrained(MODEL_PATH)
46
+ model_instance = GPT2LMHeadModel(config)
47
+ model_instance.load_state_dict(torch.load(WEIGHTS_PATH, map_location=device))
48
+ model_instance.to(device)
49
+ model_instance.eval()
50
+ model = model_instance
51
  except Exception as e:
52
  raise RuntimeError(f"Error loading model: {str(e)}")
53
 
54
+ # Lifespan event to load model on startup
 
 
55
  @asynccontextmanager
56
  async def lifespan(app: FastAPI):
57
+ load_model()
 
58
  yield
59
 
60
+ # FastAPI app instance
61
  app = FastAPI(lifespan=lifespan)
62
 
63
+ # Classification logic
64
+ def classify_text(text: str):
65
+ if not model or not tokenizer:
66
+ raise RuntimeError("Model or tokenizer not loaded.")
67
 
68
+ inputs = tokenizer(
69
+ text,
70
+ return_tensors="pt",
71
+ truncation=True,
72
+ padding="max_length",
73
+ max_length=512
74
+ )
75
+
76
+ input_ids = inputs["input_ids"].to(device)
77
+ attention_mask = inputs["attention_mask"].to(device)
78
 
79
  with torch.no_grad():
80
  outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
 
82
  perplexity = torch.exp(loss).item()
83
 
84
  if perplexity < 60:
85
+ return "AI-generated", perplexity
86
  elif perplexity < 80:
87
+ return "Probably AI-generated", perplexity
88
  else:
89
+ return "Human-written", perplexity
90
 
91
+ # Score converter (optional utility)
92
+ def Perplexity_Converter(perplexity):
93
+ return max(0, min(100, 100 - math.log2(perplexity) * 10))
94
 
95
+ # Analyze text directly
96
  @app.post("/analyze")
97
  async def analyze_text(data: TextInput, token: str = Depends(bearer_scheme)):
98
+ if token.credentials != SECRET_TOKEN:
99
+ raise HTTPException(status_code=401, detail="Invalid token")
100
 
101
+ text = data.text.strip()
102
+ if not text:
103
  raise HTTPException(status_code=400, detail="Text cannot be empty")
104
 
105
+ if len(text.split()) < 2:
 
 
106
  raise HTTPException(status_code=400, detail="Text must contain at least two words")
 
 
 
 
107
 
108
+ try:
109
+ label, perplexity = await asyncio.to_thread(classify_text, text)
110
+ return {"result": label, "perplexity": round(perplexity, 2)}
111
+ except Exception as e:
112
+ logging.error(f"Text analysis failed: {str(e)}")
113
+ raise HTTPException(status_code=500, detail="Model processing error")
114
+
115
+ # -------- File Upload and Parsing -------- #
116
+ def parse_docx(file: BytesIO):
117
+ doc = docx.Document(file)
118
+ return "\n".join(para.text for para in doc.paragraphs)
119
+
120
+ def parse_pdf(file: BytesIO):
121
+ try:
122
+ doc = fitz.open(stream=file, filetype="pdf")
123
+ return "".join([doc.load_page(i).get_text() for i in range(doc.page_count)])
124
+ except Exception as e:
125
+ logging.error(f"PDF error: {str(e)}")
126
+ raise HTTPException(status_code=500, detail="Error processing PDF")
127
+
128
+ def parse_txt(file: BytesIO):
129
+ return file.read().decode("utf-8")
130
+
131
+ @app.post("/upload/")
132
+ async def upload_file(file: UploadFile = File(...), token: str = Depends(bearer_scheme)):
133
+ if token.credentials != SECRET_TOKEN:
134
+ raise HTTPException(status_code=401, detail="Invalid token")
135
+
136
+ try:
137
+ content_type = file.content_type
138
+ content = await file.read()
139
+ if content_type == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document':
140
+ text = parse_docx(BytesIO(content))
141
+ elif content_type == 'application/pdf':
142
+ text = parse_pdf(BytesIO(content))
143
+ elif content_type == 'text/plain':
144
+ text = parse_txt(BytesIO(content))
145
+ else:
146
+ raise HTTPException(status_code=400, detail="Unsupported file type")
147
+
148
+ if len(text) > 10000:
149
+ return {"message": "File contains more than 10,000 characters."}
150
+
151
+ cleaned_text = text.replace("\n", "").replace("\t", "")
152
+ label, perplexity = await asyncio.to_thread(classify_text, cleaned_text)
153
+ return {"result": label, "perplexity": round(perplexity, 2)}
154
+
155
+ except Exception as e:
156
+ logging.error(f"File processing error: {str(e)}")
157
+ raise HTTPException(status_code=500, detail="Error processing file")
158
 
159
+ # Health Check and Index
160
  @app.get("/health")
161
+ def health_check():
162
  return {"status": "ok"}
163
 
 
164
  @app.get("/")
165
  def index():
166
  return {
167
+ "message": "FastAPI AI Text Detector is running.",
168
+ "usage": "Use /docs or /analyze or /upload to test the API."
 
169
  }