Pujan-Dev committed on
Commit
da25d43
·
verified ·
1 Parent(s): 0d58107

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -11
app.py CHANGED
@@ -3,11 +3,12 @@ from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Config
3
  from fastapi import FastAPI, HTTPException
4
  from pydantic import BaseModel
5
  from contextlib import asynccontextmanager
 
6
 
7
  # FastAPI app instance
8
  app = FastAPI()
9
 
10
- # Global model and tokenizer
11
  model, tokenizer = None, None
12
 
13
  # Function to load model and tokenizer
@@ -15,14 +16,18 @@ def load_model():
15
  model_path = "./Ai-Text-Detector/model"
16
  weights_path = "./Ai-Text-Detector/model_weights.pth"
17
 
18
- tokenizer = GPT2TokenizerFast.from_pretrained(model_path)
19
- config = GPT2Config.from_pretrained(model_path)
20
- model = GPT2LMHeadModel(config)
21
- model.load_state_dict(torch.load(weights_path, map_location=torch.device("cpu")))
22
- model.eval()
 
 
 
 
23
  return model, tokenizer
24
 
25
- # Load on app startup
26
  @asynccontextmanager
27
  async def lifespan(app: FastAPI):
28
  global model, tokenizer
@@ -48,11 +53,11 @@ def classify_text(sentence: str):
48
  perplexity = torch.exp(loss).item()
49
 
50
  if perplexity < 60:
51
- result = "AI-generated*"
52
  elif perplexity < 80:
53
- result = "Probably AI-generated*"
54
  else:
55
- result = "Human-written*"
56
 
57
  return result, perplexity
58
 
@@ -63,7 +68,9 @@ async def analyze_text(data: TextInput):
63
  if not user_input:
64
  raise HTTPException(status_code=400, detail="Text cannot be empty")
65
 
66
- result, perplexity = classify_text(user_input)
 
 
67
  return {
68
  "result": result,
69
  "perplexity": round(perplexity, 2),
 
3
  from fastapi import FastAPI, HTTPException
4
  from pydantic import BaseModel
5
  from contextlib import asynccontextmanager
6
+ import asyncio
7
 
8
  # FastAPI app instance
9
  app = FastAPI()
10
 
11
+ # Global model and tokenizer variables
12
  model, tokenizer = None, None
13
 
14
  # Function to load model and tokenizer
 
16
  model_path = "./Ai-Text-Detector/model"
17
  weights_path = "./Ai-Text-Detector/model_weights.pth"
18
 
19
+ try:
20
+ tokenizer = GPT2TokenizerFast.from_pretrained(model_path)
21
+ config = GPT2Config.from_pretrained(model_path)
22
+ model = GPT2LMHeadModel(config)
23
+ model.load_state_dict(torch.load(weights_path, map_location=torch.device("cpu")))
24
+ model.eval() # Set model to evaluation mode
25
+ except Exception as e:
26
+ raise RuntimeError(f"Error loading model: {str(e)}")
27
+
28
  return model, tokenizer
29
 
30
+ # Load model on app startup
31
  @asynccontextmanager
32
  async def lifespan(app: FastAPI):
33
  global model, tokenizer
 
53
  perplexity = torch.exp(loss).item()
54
 
55
  if perplexity < 60:
56
+ result = "AI-generated"
57
  elif perplexity < 80:
58
+ result = "Probably AI-generated"
59
  else:
60
+ result = "Human-written"
61
 
62
  return result, perplexity
63
 
 
68
  if not user_input:
69
  raise HTTPException(status_code=400, detail="Text cannot be empty")
70
 
71
+ # Run classification asynchronously to prevent blocking
72
+ result, perplexity = await asyncio.to_thread(classify_text, user_input)
73
+
74
  return {
75
  "result": result,
76
  "perplexity": round(perplexity, 2),