Pujan-Dev committed on
Commit
ab7cc71
·
verified ·
1 Parent(s): 351c13d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -38
app.py CHANGED
@@ -1,57 +1,47 @@
1
  import torch
2
  from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Config
3
- from fastapi import FastAPI, HTTPException, Request
4
  from pydantic import BaseModel
5
  from contextlib import asynccontextmanager
6
  import asyncio
7
 
8
- from slowapi import Limiter
9
- from slowapi.util import get_remote_address
10
- from slowapi.errors import RateLimitExceeded
11
- from fastapi.responses import JSONResponse
12
 
13
- # ๐ŸŒ Rate limiter for abuse prevention
14
- limiter = Limiter(key_func=get_remote_address)
15
 
16
- # ๐Ÿš€ FastAPI app instance
17
- app = FastAPI(lifespan=lambda app: load_lifespan(app))
18
- app.state.limiter = limiter
19
-
20
- # ๐Ÿ“ฆ Global model and tokenizer
21
- model = None
22
- tokenizer = None
23
-
24
- # ๐Ÿง  Optimize CPU usage (only 1 thread for free tier)
25
- torch.set_num_threads(1)
26
-
27
- # ๐Ÿ“ฆ Load model/tokenizer once
28
  def load_model():
29
  model_path = "./Ai-Text-Detector/model"
30
  weights_path = "./Ai-Text-Detector/model_weights.pth"
31
 
32
  try:
33
- tok = GPT2TokenizerFast.from_pretrained(model_path)
34
  config = GPT2Config.from_pretrained(model_path)
35
- mdl = GPT2LMHeadModel(config)
36
- mdl.load_state_dict(torch.load(weights_path, map_location=torch.device("cpu")))
37
- mdl.eval()
38
  except Exception as e:
39
  raise RuntimeError(f"Error loading model: {str(e)}")
40
 
41
- return mdl, tok
42
 
43
- # โš™๏ธ Load during app lifespan
44
  @asynccontextmanager
45
- async def load_lifespan(app: FastAPI):
46
  global model, tokenizer
47
  model, tokenizer = load_model()
48
  yield
49
 
50
- # ๐Ÿ“˜ Input schema
 
 
 
51
  class TextInput(BaseModel):
52
  text: str
53
 
54
- # ๐Ÿš€ Inference function
55
  def classify_text(sentence: str):
56
  inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
57
  input_ids = inputs["input_ids"]
@@ -71,28 +61,27 @@ def classify_text(sentence: str):
71
 
72
  return result, perplexity
73
 
74
- # ๐Ÿ›ก๏ธ Rate limit error handler
75
- @app.exception_handler(RateLimitExceeded)
76
- async def rate_limit_handler(request: Request, exc):
77
- return JSONResponse(status_code=429, content={"detail": "Too many requests. Please slow down."})
78
-
79
- # ๐Ÿ” Inference endpoint with rate limiting
80
  @app.post("/analyze")
81
- @limiter.limit("2/second")
82
  async def analyze_text(data: TextInput):
83
  user_input = data.text.strip()
84
  if not user_input:
85
  raise HTTPException(status_code=400, detail="Text cannot be empty")
86
 
 
87
  result, perplexity = await asyncio.to_thread(classify_text, user_input)
88
- return {"result": result, "perplexity": round(perplexity, 2)}
 
 
 
 
89
 
90
- # โœ… Health check
91
  @app.get("/health")
92
  async def health_check():
93
  return {"status": "ok"}
94
 
95
- # โ„น๏ธ Home
96
  @app.get("/")
97
  def index():
98
  return {
 
1
  import torch
2
  from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Config
3
+ from fastapi import FastAPI, HTTPException
4
  from pydantic import BaseModel
5
  from contextlib import asynccontextmanager
6
  import asyncio
7
 
8
# Global model and tokenizer, populated at startup by the lifespan handler.
# NOTE(review): the original also built a bare `FastAPI()` here that was
# immediately shadowed by `app = FastAPI(lifespan=lifespan)` below; the dead
# instance has been dropped.
model, tokenizer = None, None
13
 
14
+ # Function to load model and tokenizer
 
 
 
 
 
 
 
 
 
 
 
15
def load_model():
    """Load the GPT-2 detector model and tokenizer from local artifact files.

    Returns:
        tuple: ``(model, tokenizer)`` with the model in eval mode on CPU.

    Raises:
        RuntimeError: if any artifact fails to load; chained to the original
            exception so the startup failure is debuggable.
    """
    model_path = "./Ai-Text-Detector/model"
    weights_path = "./Ai-Text-Detector/model_weights.pth"

    try:
        tokenizer = GPT2TokenizerFast.from_pretrained(model_path)
        config = GPT2Config.from_pretrained(model_path)
        model = GPT2LMHeadModel(config)
        # CPU-only deployment: remap weights off whatever device they were saved on.
        model.load_state_dict(torch.load(weights_path, map_location=torch.device("cpu")))
        model.eval()  # inference mode: disables dropout
    except Exception as e:
        # `from e` preserves the causal traceback (the original raise dropped it).
        raise RuntimeError(f"Error loading model: {str(e)}") from e

    return model, tokenizer
29
 
30
+ # Load model on app startup
31
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Populate the module-level model/tokenizer once at startup, then serve."""
    global model, tokenizer
    loaded_model, loaded_tokenizer = load_model()
    model, tokenizer = loaded_model, loaded_tokenizer
    yield
36
 
37
# Attach startup loader: this instance carries the lifespan handler, so the
# model/tokenizer are loaded exactly once when the server starts.
app = FastAPI(lifespan=lifespan)

40
# Input schema
class TextInput(BaseModel):
    """Request body for /analyze."""
    text: str  # raw text to classify; rejected with 400 if empty after stripping
43
 
44
+ # Sync text classification
45
  def classify_text(sentence: str):
46
  inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
47
  input_ids = inputs["input_ids"]
 
61
 
62
  return result, perplexity
63
 
64
# POST route to analyze text
@app.post("/analyze")
async def analyze_text(data: TextInput):
    """Classify the submitted text and report the model's perplexity."""
    text = data.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    # Run the (blocking) model call in a worker thread so the event loop stays free.
    result, perplexity = await asyncio.to_thread(classify_text, text)

    return {
        "result": result,
        "perplexity": round(perplexity, 2),
    }
78
 
79
# Health check route
@app.get("/health")
async def health_check():
    """Liveness probe: always reports the service as up."""
    payload = {"status": "ok"}
    return payload
83
 
84
+ # Simple index route
85
  @app.get("/")
86
  def index():
87
  return {