Pujan Neupane commited on
Commit
bd8b847
·
unverified ·
2 Parent(s): 9d3728d5c92764

Merge pull request #4 from cyberalertnepal/Pujan

Browse files

Fix server crashes and high perplexity issue, refactor app.py, and update documentation

Files changed (1) hide show
  1. app.py +22 -33
app.py CHANGED
@@ -1,9 +1,10 @@
1
- import torch
2
- from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Config
3
- from fastapi import FastAPI, HTTPException
4
  from pydantic import BaseModel
5
- from contextlib import asynccontextmanager
 
6
  import asyncio
 
7
 
8
  # FastAPI app instance
9
  app = FastAPI()
@@ -11,9 +12,10 @@ app = FastAPI()
11
  # Global model and tokenizer variables
12
  model, tokenizer = None, None
13
 
14
- # Function to load model and tokenizer
15
-
16
 
 
17
  def load_model():
18
  model_path = "./Ai-Text-Detector/model"
19
  weights_path = "./Ai-Text-Detector/model_weights.pth"
@@ -22,39 +24,28 @@ def load_model():
22
  tokenizer = GPT2TokenizerFast.from_pretrained(model_path)
23
  config = GPT2Config.from_pretrained(model_path)
24
  model = GPT2LMHeadModel(config)
25
- model.load_state_dict(
26
- torch.load(weights_path, map_location=torch.device("cpu"))
27
- )
28
- model.eval() # Set model to evaluation mode
29
  except Exception as e:
30
  raise RuntimeError(f"Error loading model: {str(e)}")
31
 
32
  return model, tokenizer
33
 
34
-
35
  # Load model on app startup
36
-
37
-
38
  @asynccontextmanager
39
  async def lifespan(app: FastAPI):
40
  global model, tokenizer
41
  model, tokenizer = load_model()
42
  yield
43
 
44
-
45
  # Attach startup loader
46
  app = FastAPI(lifespan=lifespan)
47
 
48
  # Input schema
49
-
50
-
51
  class TextInput(BaseModel):
52
  text: str
53
 
54
-
55
  # Sync text classification
56
-
57
-
58
  def classify_text(sentence: str):
59
  inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
60
  input_ids = inputs["input_ids"]
@@ -74,40 +65,38 @@ def classify_text(sentence: str):
74
 
75
  return result, perplexity
76
 
77
-
78
- # POST route to analyze text
79
-
80
-
81
  @app.post("/analyze")
82
- async def analyze_text(data: TextInput):
83
  user_input = data.text.strip()
 
84
  if not user_input:
85
  raise HTTPException(status_code=400, detail="Text cannot be empty")
 
 
 
 
 
86
 
87
- # Run classification asynchronously to prevent blocking
88
- result, perplexity = await asyncio.to_thread(classify_text, user_input)
89
 
 
 
90
  return {
91
  "result": result,
92
  "perplexity": round(perplexity, 2),
93
  }
94
 
95
-
96
  # Health check route
97
-
98
-
99
  @app.get("/health")
100
  async def health_check():
101
  return {"status": "ok"}
102
 
103
-
104
  # Simple index route
105
-
106
-
107
  @app.get("/")
108
  def index():
109
  return {
110
  "message": "FastAPI API is up.",
111
  "try": "/docs to test the API.",
112
- "status": "OK",
113
  }
 
 
1
+ from fastapi import FastAPI, HTTPException, Depends
2
+ from fastapi.security import HTTPBearer
 
3
  from pydantic import BaseModel
4
+ from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Config
5
+ import torch
6
  import asyncio
7
+ from contextlib import asynccontextmanager
8
 
9
  # FastAPI app instance
10
  app = FastAPI()
 
12
  # Global model and tokenizer variables
13
  model, tokenizer = None, None
14
 
15
+ # HTTPBearer instance for security
16
+ bearer_scheme = HTTPBearer()
17
 
18
+ # Function to load model and tokenizer
19
  def load_model():
20
  model_path = "./Ai-Text-Detector/model"
21
  weights_path = "./Ai-Text-Detector/model_weights.pth"
 
24
  tokenizer = GPT2TokenizerFast.from_pretrained(model_path)
25
  config = GPT2Config.from_pretrained(model_path)
26
  model = GPT2LMHeadModel(config)
27
+ model.load_state_dict(torch.load(weights_path, map_location=torch.device("cpu")))
28
+ model.eval()
 
 
29
  except Exception as e:
30
  raise RuntimeError(f"Error loading model: {str(e)}")
31
 
32
  return model, tokenizer
33
 
 
34
  # Load model on app startup
 
 
35
  @asynccontextmanager
36
  async def lifespan(app: FastAPI):
37
  global model, tokenizer
38
  model, tokenizer = load_model()
39
  yield
40
 
 
41
  # Attach startup loader
42
  app = FastAPI(lifespan=lifespan)
43
 
44
  # Input schema
 
 
45
  class TextInput(BaseModel):
46
  text: str
47
 
 
48
  # Sync text classification
 
 
49
  def classify_text(sentence: str):
50
  inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
51
  input_ids = inputs["input_ids"]
 
65
 
66
  return result, perplexity
67
 
68
+ # POST route to analyze text with Bearer token
 
 
 
69
  @app.post("/analyze")
70
+ async def analyze_text(data: TextInput, token: str = Depends(bearer_scheme)):
71
  user_input = data.text.strip()
72
+
73
  if not user_input:
74
  raise HTTPException(status_code=400, detail="Text cannot be empty")
75
+
76
+ # Check if there are at least two words
77
+ word_count = len(user_input.split())
78
+ if word_count < 2:
79
+ raise HTTPException(status_code=400, detail="Text must contain at least two words")
80
 
 
 
81
 
82
+ result, perplexity = await asyncio.to_thread(classify_text, user_input)
83
+
84
  return {
85
  "result": result,
86
  "perplexity": round(perplexity, 2),
87
  }
88
 
 
89
  # Health check route
 
 
90
  @app.get("/health")
91
  async def health_check():
92
  return {"status": "ok"}
93
 
 
94
  # Simple index route
 
 
95
  @app.get("/")
96
  def index():
97
  return {
98
  "message": "FastAPI API is up.",
99
  "try": "/docs to test the API.",
100
+ "status": "OK"
101
  }
102
+