Jobsforce committed on
Commit
ae41e6f
·
verified ·
1 Parent(s): 395c41a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -28
app.py CHANGED
@@ -19,33 +19,42 @@ logging.basicConfig(
19
  )
20
  logger = logging.getLogger(__name__)
21
 
22
- # Health check endpoint
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  @app.route('/health')
24
  def health_check():
25
  logger.info("Health check requested")
26
  return jsonify({"status": "healthy"}), 200
27
 
28
- # Initialize model only when needed
29
  def load_model():
30
- model_name = "priyabrat/AI.or.Human.text.classification"
31
- logger.info(f"Loading model and tokenizer from {model_name}")
32
- tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir='/app/hf_cache')
33
- model = AutoModelForSequenceClassification.from_pretrained(model_name, cache_dir='/app/hf_cache')
34
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
- model.to(device).eval()
36
- logger.info(f"Model loaded on device: {device}")
37
- return tokenizer, model, device
38
-
39
-
40
- tokenizer, model, device = load_model()
41
- labels = ["AI-generated", "Human-written"]
42
- lock = threading.Lock()
43
-
44
- sessions = {}
45
- queues = {}
46
 
47
  def classify_line(text):
48
  with lock, torch.no_grad():
 
49
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
50
  inputs = {k: v.to(device) for k, v in inputs.items()}
51
  outputs = model(**inputs)
@@ -61,7 +70,7 @@ def classify_line(text):
61
  def background_worker(user_id, text):
62
  logger.info(f"Processing started for user_id={user_id}")
63
  sessions[user_id]['status'] = "processing"
64
-
65
  try:
66
  if '\n' not in text:
67
  lines = sent_tokenize(text)
@@ -83,10 +92,8 @@ def background_worker(user_id, text):
83
  sessions[user_id]['status'] = "done"
84
  logger.info(f"Processing finished for user_id={user_id}")
85
  time.sleep(1)
86
- if user_id in sessions:
87
- del sessions[user_id]
88
- if user_id in queues:
89
- del queues[user_id]
90
 
91
  @app.route('/start-session', methods=['POST'])
92
  def start_session():
@@ -136,11 +143,6 @@ def session_status(user_id):
136
  logger.info(f"Status request for user_id={user_id}: {status}")
137
  return jsonify({"status": status})
138
 
139
- @app.route('/')
140
- def index():
141
- logger.info("Index page requested")
142
- return "Server is running!"
143
-
144
  if __name__ == '__main__':
145
  logger.info("Starting Flask app")
146
  app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 8080)))
 
19
  )
20
  logger = logging.getLogger(__name__)
21
 
22
# Shared mutable state for the app.
# The model/tokenizer/device triple starts unset and is populated lazily by
# load_model() on first use, so the server can boot (and pass health checks)
# before the weights are downloaded.
tokenizer = None
model = None
device = None

# Classifier output labels, index-aligned with the model's logits.
labels = ["AI-generated", "Human-written"]

# Serializes access to the model across request threads.
lock = threading.Lock()

# Per-user bookkeeping: session metadata and per-session result queues,
# both keyed by user_id.
sessions = {}
queues = {}
32
@app.route('/')
def index():
    """Root endpoint: confirm the server is up with a plain-text message."""
    logger.info("Index page requested")
    message = "Server is running!"
    return message
36
+
37
@app.route('/health')
def health_check():
    """Liveness probe: log the request and report a healthy status as JSON."""
    logger.info("Health check requested")
    payload = {"status": "healthy"}
    return jsonify(payload), 200
41
 
 
42
def load_model():
    """Lazily initialize the shared tokenizer, model, and device.

    Idempotent: when the globals are already populated this only logs and
    returns. Callers (classify_line) invoke this under `lock`, which is what
    makes the check-then-set on the globals safe across request threads.
    """
    global tokenizer, model, device
    if model is not None and tokenizer is not None:
        # Already initialized on a previous call — nothing to do.
        logger.info("Model already loaded.")
        return
    model_name = "priyabrat/AI.or.Human.text.classification"
    logger.info(f"Loading model and tokenizer from {model_name}")
    # Cache under /app/hf_cache so repeated container starts reuse the download.
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir='/app/hf_cache')
    model = AutoModelForSequenceClassification.from_pretrained(model_name, cache_dir='/app/hf_cache')
    # Prefer GPU when available; eval() disables dropout for inference.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()
    logger.info(f"Model loaded on device: {device}")
 
 
 
 
 
54
 
55
  def classify_line(text):
56
  with lock, torch.no_grad():
57
+ load_model()
58
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
59
  inputs = {k: v.to(device) for k, v in inputs.items()}
60
  outputs = model(**inputs)
 
70
  def background_worker(user_id, text):
71
  logger.info(f"Processing started for user_id={user_id}")
72
  sessions[user_id]['status'] = "processing"
73
+
74
  try:
75
  if '\n' not in text:
76
  lines = sent_tokenize(text)
 
92
  sessions[user_id]['status'] = "done"
93
  logger.info(f"Processing finished for user_id={user_id}")
94
  time.sleep(1)
95
+ sessions.pop(user_id, None)
96
+ queues.pop(user_id, None)
 
 
97
 
98
  @app.route('/start-session', methods=['POST'])
99
  def start_session():
 
143
  logger.info(f"Status request for user_id={user_id}: {status}")
144
  return jsonify({"status": status})
145
 
 
 
 
 
 
146
if __name__ == '__main__':
    # Script entry point: bind to all interfaces; the PORT environment
    # variable overrides the default of 8080.
    logger.info("Starting Flask app")
    port = int(os.environ.get('PORT', 8080))
    app.run(host='0.0.0.0', port=port)