Jobsforce committed on
Commit
9fde7ed
·
verified ·
1 Parent(s): 7cc7849

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -10
app.py CHANGED
@@ -7,25 +7,37 @@ import time
7
  import queue
8
  from nltk.tokenize import sent_tokenize
9
  import os
 
 
10
 
11
  app = Flask(__name__)
12
 
 
 
 
 
 
 
 
13
  # Health check endpoint
14
  @app.route('/health')
15
  def health_check():
 
16
  return jsonify({"status": "healthy"}), 200
17
 
18
  # Initialize model only when needed
19
  def load_model():
20
  model_name = "priyabrat/AI.or.Human.text.classification"
21
- tokenizer = AutoTokenizer.from_pretrained('priyabrat/AI.or.Human.text.classification', cache_dir='/app/hf_cache')
22
- model = AutoModelForSequenceClassification.from_pretrained('priyabrat/AI.or.Human.text.classification', cache_dir='/app/hf_cache')
 
23
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
  model.to(device).eval()
25
- return tokenizer, model
 
26
 
27
 
28
- tokenizer, model = load_model()
29
  labels = ["AI-generated", "Human-written"]
30
  lock = threading.Lock()
31
 
@@ -34,7 +46,7 @@ queues = {}
34
 
35
  def classify_line(text):
36
  with lock, torch.no_grad():
37
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512) # Reduced max_length
38
  inputs = {k: v.to(device) for k, v in inputs.items()}
39
  outputs = model(**inputs)
40
  probs = F.softmax(outputs.logits, dim=-1)
@@ -47,25 +59,29 @@ def classify_line(text):
47
  }
48
 
49
  def background_worker(user_id, text):
 
50
  sessions[user_id]['status'] = "processing"
51
 
52
  try:
53
  if '\n' not in text:
54
- lines = sent_tokenize(text)
55
  else:
56
  lines = [line.strip() for line in text.strip().split('\n') if line.strip()]
57
 
58
  for i, line in enumerate(lines, 1):
59
  result = classify_line(line)
 
60
  result["line"] = i
61
  queues[user_id].put(f"data: {json.dumps(result)}\n\n")
62
- time.sleep(0.1) # Reduced delay
63
 
64
  queues[user_id].put("event: done\ndata: Session complete\n\n")
65
  except Exception as e:
 
66
  queues[user_id].put(f"event: error\ndata: {str(e)}\n\n")
67
  finally:
68
  sessions[user_id]['status'] = "done"
 
69
  time.sleep(1)
70
  if user_id in sessions:
71
  del sessions[user_id]
@@ -79,11 +95,14 @@ def start_session():
79
  text = data.get("text")
80
 
81
  if not user_id or not text:
 
82
  return jsonify({"error": "user_id and text are required"}), 400
83
 
84
  if user_id in sessions:
 
85
  return jsonify({"message": "Session already exists", "status": sessions[user_id]["status"]}), 409
86
 
 
87
  sessions[user_id] = {"status": "pending"}
88
  queues[user_id] = queue.Queue()
89
  threading.Thread(target=background_worker, args=(user_id, text), daemon=True).start()
@@ -93,16 +112,19 @@ def start_session():
93
  @app.route('/stream/<user_id>')
94
  def stream(user_id):
95
  if user_id not in sessions:
 
96
  return jsonify({"error": "No active session for this user"}), 404
97
 
98
  def event_stream():
99
  while True:
100
  try:
101
- message = queues[user_id].get(timeout=30) # Reduced timeout
102
  yield message
103
  if "event: done" in message or "event: error" in message:
 
104
  break
105
  except queue.Empty:
 
106
  yield "event: timeout\ndata: No activity\n\n"
107
  break
108
 
@@ -110,11 +132,15 @@ def stream(user_id):
110
 
111
  @app.route('/status/<user_id>')
112
  def session_status(user_id):
113
- return jsonify({"status": sessions.get(user_id, {}).get("status", "no_session")})
 
 
114
 
115
  @app.route('/')
116
  def index():
 
117
  return "Server is running!"
118
 
119
  if __name__ == '__main__':
120
- app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 8080)))
 
 
7
  import queue
8
  from nltk.tokenize import sent_tokenize
9
  import os
10
+ import json
11
+ import logging
12
 
13
  app = Flask(__name__)
14
 
15
+ # Configure logging
16
+ logging.basicConfig(
17
+ level=logging.INFO,
18
+ format='%(asctime)s %(levelname)s %(threadName)s %(message)s'
19
+ )
20
+ logger = logging.getLogger(__name__)
21
+
22
  # Health check endpoint
23
  @app.route('/health')
24
  def health_check():
25
+ logger.info("Health check requested")
26
  return jsonify({"status": "healthy"}), 200
27
 
28
  # Initialize model only when needed
29
  def load_model():
30
  model_name = "priyabrat/AI.or.Human.text.classification"
31
+ logger.info(f"Loading model and tokenizer from {model_name}")
32
+ tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir='/app/hf_cache')
33
+ model = AutoModelForSequenceClassification.from_pretrained(model_name, cache_dir='/app/hf_cache')
34
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
  model.to(device).eval()
36
+ logger.info(f"Model loaded on device: {device}")
37
+ return tokenizer, model, device
38
 
39
 
40
+ tokenizer, model, device = load_model()
41
  labels = ["AI-generated", "Human-written"]
42
  lock = threading.Lock()
43
 
 
46
 
47
  def classify_line(text):
48
  with lock, torch.no_grad():
49
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
50
  inputs = {k: v.to(device) for k, v in inputs.items()}
51
  outputs = model(**inputs)
52
  probs = F.softmax(outputs.logits, dim=-1)
 
59
  }
60
 
61
  def background_worker(user_id, text):
62
+ logger.info(f"Processing started for user_id={user_id}")
63
  sessions[user_id]['status'] = "processing"
64
 
65
  try:
66
  if '\n' not in text:
67
+ lines = sent_tokenize(text)
68
  else:
69
  lines = [line.strip() for line in text.strip().split('\n') if line.strip()]
70
 
71
  for i, line in enumerate(lines, 1):
72
  result = classify_line(line)
73
+ logger.info(f"user_id={user_id} line={i} classified as {result['label']} ({result['confidence']}%)")
74
  result["line"] = i
75
  queues[user_id].put(f"data: {json.dumps(result)}\n\n")
76
+ time.sleep(0.1)
77
 
78
  queues[user_id].put("event: done\ndata: Session complete\n\n")
79
  except Exception as e:
80
+ logger.error(f"Error processing user_id={user_id}: {e}")
81
  queues[user_id].put(f"event: error\ndata: {str(e)}\n\n")
82
  finally:
83
  sessions[user_id]['status'] = "done"
84
+ logger.info(f"Processing finished for user_id={user_id}")
85
  time.sleep(1)
86
  if user_id in sessions:
87
  del sessions[user_id]
 
95
  text = data.get("text")
96
 
97
  if not user_id or not text:
98
+ logger.warning("Missing user_id or text in start-session request")
99
  return jsonify({"error": "user_id and text are required"}), 400
100
 
101
  if user_id in sessions:
102
+ logger.warning(f"Session already exists for user_id={user_id}")
103
  return jsonify({"message": "Session already exists", "status": sessions[user_id]["status"]}), 409
104
 
105
+ logger.info(f"Starting session for user_id={user_id}")
106
  sessions[user_id] = {"status": "pending"}
107
  queues[user_id] = queue.Queue()
108
  threading.Thread(target=background_worker, args=(user_id, text), daemon=True).start()
 
112
  @app.route('/stream/<user_id>')
113
  def stream(user_id):
114
  if user_id not in sessions:
115
+ logger.warning(f"No active session for user_id={user_id} in stream request")
116
  return jsonify({"error": "No active session for this user"}), 404
117
 
118
  def event_stream():
119
  while True:
120
  try:
121
+ message = queues[user_id].get(timeout=30)
122
  yield message
123
  if "event: done" in message or "event: error" in message:
124
+ logger.info(f"Stream ended for user_id={user_id} with message: {message.strip()}")
125
  break
126
  except queue.Empty:
127
+ logger.warning(f"Stream timeout for user_id={user_id}")
128
  yield "event: timeout\ndata: No activity\n\n"
129
  break
130
 
 
132
 
133
  @app.route('/status/<user_id>')
134
  def session_status(user_id):
135
+ status = sessions.get(user_id, {}).get("status", "no_session")
136
+ logger.info(f"Status request for user_id={user_id}: {status}")
137
+ return jsonify({"status": status})
138
 
139
  @app.route('/')
140
  def index():
141
+ logger.info("Index page requested")
142
  return "Server is running!"
143
 
144
  if __name__ == '__main__':
145
+ logger.info("Starting Flask app")
146
+ app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 8080)))