Pujan-Dev committed on
Commit
8adf4f2
·
verified ·
1 Parent(s): 378ec35

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -59
app.py CHANGED
@@ -1,71 +1,43 @@
1
  import torch
2
  from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Config
3
- from fastapi import FastAPI, HTTPException, Header
4
  from pydantic import BaseModel
5
- import asyncio
6
- from concurrent.futures import ThreadPoolExecutor
7
  from contextlib import asynccontextmanager
8
- from dotenv import dotenv_values
9
 
10
- # FastAPI instance
11
  app = FastAPI()
12
- executor = ThreadPoolExecutor(max_workers=20)
13
 
14
- # Load .env file
15
- env = dotenv_values(".env")
16
- EXPECTED_TOKEN = env.get("SECRET_TOKEN")
17
-
18
- # Global variables for model and tokenizer
19
  model, tokenizer = None, None
20
 
21
- # Function to verify token
22
-
23
-
24
- def verify_token(auth: str):
25
- if auth != f"Bearer {EXPECTED_TOKEN}":
26
- raise HTTPException(status_code=403, detail="Unauthorized")
27
-
28
-
29
  # Function to load model and tokenizer
30
-
31
  def load_model():
32
  model_path = "./Ai-Text-Detector/model"
33
  weights_path = "./Ai-Text-Detector/model_weights.pth"
34
 
35
- # Load tokenizer and config from your custom model path
36
  tokenizer = GPT2TokenizerFast.from_pretrained(model_path)
37
  config = GPT2Config.from_pretrained(model_path)
38
-
39
- # Initialize model from config (don't load any weights from Hugging Face)
40
  model = GPT2LMHeadModel(config)
41
-
42
- # Load your saved PyTorch weights
43
  model.load_state_dict(torch.load(weights_path, map_location=torch.device("cpu")))
44
-
45
- model.eval() # Set to evaluation mode
46
  return model, tokenizer
47
 
 
48
  @asynccontextmanager
49
  async def lifespan(app: FastAPI):
50
  global model, tokenizer
51
  model, tokenizer = load_model()
52
  yield
53
 
54
-
55
- # Attach the lifespan context manager
56
  app = FastAPI(lifespan=lifespan)
57
 
58
- # Request body for input data
59
-
60
-
61
  class TextInput(BaseModel):
62
  text: str
63
 
64
-
65
- # Sync function to classify text
66
-
67
-
68
- def classify_text_sync(sentence: str):
69
  inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
70
  input_ids = inputs["input_ids"]
71
  attention_mask = inputs["attention_mask"]
@@ -84,44 +56,25 @@ def classify_text_sync(sentence: str):
84
 
85
  return result, perplexity
86
 
87
-
88
- # Async wrapper for text classification
89
-
90
-
91
- async def classify_text(sentence: str):
92
- loop = asyncio.get_event_loop()
93
- return await loop.run_in_executor(executor, classify_text_sync, sentence)
94
-
95
-
96
  # POST route to analyze text
97
-
98
-
99
  @app.post("/analyze")
100
- async def analyze_text(data: TextInput, authorization: str = Header(default="")):
101
- verify_token(authorization) # Token verification
102
  user_input = data.text.strip()
103
-
104
  if not user_input:
105
  raise HTTPException(status_code=400, detail="Text cannot be empty")
106
 
107
- result, perplexity = await classify_text(user_input)
108
-
109
  return {
110
  "result": result,
111
  "perplexity": round(perplexity, 2),
112
  }
113
 
114
-
115
  # Health check route
116
-
117
-
118
  @app.get("/health")
119
  async def health_check():
120
  return {"status": "ok"}
121
 
122
-
123
  # Simple index route
124
-
125
  @app.get("/")
126
  def index():
127
  return {
@@ -129,4 +82,3 @@ def index():
129
  "try": "/docs to test the API.",
130
  "status": "OK"
131
  }
132
-
 
1
import asyncio
from contextlib import asynccontextmanager

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Config
 
6
 
7
+ # FastAPI app instance
8
  app = FastAPI()
 
9
 
10
+ # Global model and tokenizer
 
 
 
 
11
  model, tokenizer = None, None
12
 
 
 
 
 
 
 
 
 
13
  # Function to load model and tokenizer
 
14
  def load_model():
15
  model_path = "./Ai-Text-Detector/model"
16
  weights_path = "./Ai-Text-Detector/model_weights.pth"
17
 
 
18
  tokenizer = GPT2TokenizerFast.from_pretrained(model_path)
19
  config = GPT2Config.from_pretrained(model_path)
 
 
20
  model = GPT2LMHeadModel(config)
 
 
21
  model.load_state_dict(torch.load(weights_path, map_location=torch.device("cpu")))
22
+ model.eval()
 
23
  return model, tokenizer
24
 
25
+ # Load on app startup
26
  @asynccontextmanager
27
  async def lifespan(app: FastAPI):
28
  global model, tokenizer
29
  model, tokenizer = load_model()
30
  yield
31
 
32
+ # Attach startup loader
 
33
  app = FastAPI(lifespan=lifespan)
34
 
35
+ # Input schema
 
 
36
  class TextInput(BaseModel):
37
  text: str
38
 
39
+ # Sync text classification
40
+ def classify_text(sentence: str):
 
 
 
41
  inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
42
  input_ids = inputs["input_ids"]
43
  attention_mask = inputs["attention_mask"]
 
56
 
57
  return result, perplexity
58
 
 
 
 
 
 
 
 
 
 
59
  # POST route to analyze text
 
 
60
  @app.post("/analyze")
61
+ async def analyze_text(data: TextInput):
 
62
  user_input = data.text.strip()
 
63
  if not user_input:
64
  raise HTTPException(status_code=400, detail="Text cannot be empty")
65
 
66
+ result, perplexity = classify_text(user_input)
 
67
  return {
68
  "result": result,
69
  "perplexity": round(perplexity, 2),
70
  }
71
 
 
72
  # Health check route
 
 
73
  @app.get("/health")
74
  async def health_check():
75
  return {"status": "ok"}
76
 
 
77
  # Simple index route
 
78
  @app.get("/")
79
  def index():
80
  return {
 
82
  "try": "/docs to test the API.",
83
  "status": "OK"
84
  }