Pujan-Dev commited on
Commit
52fdab2
·
verified ·
1 Parent(s): 2b290c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -41
app.py CHANGED
@@ -1,9 +1,10 @@
1
- import torch
2
- from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Config
3
- from fastapi import FastAPI, HTTPException
4
  from pydantic import BaseModel
5
- from contextlib import asynccontextmanager
 
6
  import asyncio
 
7
 
8
  # FastAPI app instance
9
  app = FastAPI()
@@ -11,6 +12,9 @@ app = FastAPI()
11
  # Global model and tokenizer variables
12
  model, tokenizer = None, None
13
 
 
 
 
14
  # Function to load model and tokenizer
15
  def load_model():
16
  model_path = "./Ai-Text-Detector/model"
@@ -31,12 +35,8 @@ def load_model():
31
  @asynccontextmanager
32
  async def lifespan(app: FastAPI):
33
  global model, tokenizer
34
- try:
35
- model, tokenizer = load_model()
36
- yield
37
- except Exception as e:
38
- print(f"Startup error: {str(e)}")
39
- raise RuntimeError(f"Failed to start application: {str(e)}")
40
 
41
  # Attach startup loader
42
  app = FastAPI(lifespan=lifespan)
@@ -47,42 +47,42 @@ class TextInput(BaseModel):
47
 
48
  # Sync text classification
49
  def classify_text(sentence: str):
50
- try:
51
- inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
52
- input_ids = inputs["input_ids"]
53
- attention_mask = inputs["attention_mask"]
54
-
55
- with torch.no_grad():
56
- outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
57
- loss = outputs.loss
58
- perplexity = torch.exp(loss).item()
59
-
60
- if perplexity < 60:
61
- result = "AI-generated"
62
- elif perplexity < 80:
63
- result = "Probably AI-generated"
64
- else:
65
- result = "Human-written"
66
-
67
- return result, perplexity
68
- except Exception as e:
69
- raise RuntimeError(f"Error during text classification: {str(e)}")
70
 
71
- # POST route to analyze text
72
  @app.post("/analyze")
73
- async def analyze_text(data: TextInput):
74
  user_input = data.text.strip()
75
  if not user_input:
76
  raise HTTPException(status_code=400, detail="Text cannot be empty")
77
 
78
- try:
79
- result, perplexity = await asyncio.to_thread(classify_text, user_input)
80
- return {
81
- "result": result,
82
- "perplexity": round(perplexity, 2),
83
- }
84
- except Exception as e:
85
- raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
 
 
 
86
 
87
  # Health check route
88
  @app.get("/health")
@@ -96,4 +96,4 @@ def index():
96
  "message": "FastAPI API is up.",
97
  "try": "/docs to test the API.",
98
  "status": "OK"
99
- }
 
1
+ from fastapi import FastAPI, HTTPException, Depends, Security
2
+ from fastapi.security import HTTPBearer
 
3
  from pydantic import BaseModel
4
+ from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Config
5
+ import torch
6
  import asyncio
7
+ from contextlib import asynccontextmanager
8
 
9
  # FastAPI app instance
10
  app = FastAPI()
 
12
  # Global model and tokenizer variables
13
  model, tokenizer = None, None
14
 
15
+ # HTTPBearer instance for security
16
+ bearer_scheme = HTTPBearer()
17
+
18
  # Function to load model and tokenizer
19
  def load_model():
20
  model_path = "./Ai-Text-Detector/model"
 
35
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: load the model and tokenizer once at
    startup (before any request is served), then yield control to the app.

    Populates the module-level ``model`` and ``tokenizer`` globals that
    ``classify_text`` reads. No explicit shutdown cleanup is performed
    after the ``yield``.
    """
    global model, tokenizer
    # Synchronous load is acceptable here: it runs once, before serving.
    model, tokenizer = load_model()
    yield
 
 
 
 
40
 
41
  # Attach startup loader
42
  app = FastAPI(lifespan=lifespan)
 
47
 
48
# Sync text classification
def classify_text(sentence: str):
    """Score *sentence* with the GPT-2 model and label it by perplexity.

    Reads the module-level ``model`` and ``tokenizer`` globals (populated
    during application startup). Returns a ``(label, perplexity)`` tuple,
    where lower perplexity indicates the text is more likely AI-generated.
    """
    encoded = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
    ids = encoded["input_ids"]
    mask = encoded["attention_mask"]

    # Perplexity = exp(language-model loss) of the sentence against itself.
    with torch.no_grad():
        model_out = model(ids, attention_mask=mask, labels=ids)
        perplexity = torch.exp(model_out.loss).item()

    # Threshold ladder, checked from the high end down:
    # >= 80 human, [60, 80) probably AI, < 60 AI.
    if perplexity >= 80:
        verdict = "Human-written"
    elif perplexity >= 60:
        verdict = "Probably AI-generated"
    else:
        verdict = "AI-generated"

    return verdict, perplexity
 
 
 
67
 
68
# POST route to analyze text with Bearer token
@app.post("/analyze")
async def analyze_text(data: TextInput, token: str = Depends(bearer_scheme)):
    """Classify the submitted text as AI- or human-written.

    The request must carry an ``Authorization: Bearer <token>`` header;
    ``bearer_scheme`` extracts it (FastAPI rejects requests without one).
    NOTE(review): despite the ``str`` annotation, HTTPBearer actually
    injects an ``HTTPAuthorizationCredentials`` object — confirm and use
    ``token.credentials`` if the raw string is ever needed for validation.

    Returns:
        dict with ``result`` (label) and ``perplexity`` (rounded to 2 dp).

    Raises:
        HTTPException: 400 if the text is empty after stripping.
    """
    user_input = data.text.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    # SECURITY: never print/log the raw bearer token — credentials written
    # to stdout end up in log files. Log only the fact that one was received;
    # add real validation of `token` here if needed.
    print("Received Bearer Token (value redacted)")

    # Run the blocking model inference off the event loop.
    result, perplexity = await asyncio.to_thread(classify_text, user_input)

    return {
        "result": result,
        "perplexity": round(perplexity, 2),
    }
86
 
87
  # Health check route
88
  @app.get("/health")
 
96
  "message": "FastAPI API is up.",
97
  "try": "/docs to test the API.",
98
  "status": "OK"
99
+ }