Jobsforce commited on
Commit
5bcd567
·
verified ·
1 Parent(s): b4a7976

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -75
app.py CHANGED
@@ -5,33 +5,35 @@ import torch.nn.functional as F
5
  import threading
6
  import time
7
  import queue
8
- from nltk.tokenize import sent_tokenize
9
- # import nltk
10
- # try:
11
- # nltk.data.find('tokenizers/punkt')
12
- # except LookupError:
13
- # nltk.download('punkt')
14
-
15
 
16
  app = Flask(__name__)
17
 
18
-
19
- model_name = "priyabrat/AI.or.Human.text.classification"
20
- tokenizer = AutoTokenizer.from_pretrained(model_name)
21
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
22
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
- model.to(device).eval()
24
-
 
 
 
 
 
 
 
 
25
  labels = ["AI-generated", "Human-written"]
26
  lock = threading.Lock()
27
 
28
-
29
  sessions = {}
30
  queues = {}
31
 
32
  def classify_line(text):
33
  with lock, torch.no_grad():
34
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=10000)
35
  inputs = {k: v.to(device) for k, v in inputs.items()}
36
  outputs = model(**inputs)
37
  probs = F.softmax(outputs.logits, dim=-1)
@@ -43,47 +45,31 @@ def classify_line(text):
43
  "confidence": round(confidence * 100, 2)
44
  }
45
 
46
-
47
-
48
  def background_worker(user_id, text):
49
  sessions[user_id]['status'] = "processing"
50
- if '\n' not in text:
51
- lines = sent_tokenize(text)
52
- else:
53
- lines = [line.strip() for line in text.strip().split('\n') if line.strip()]
54
-
55
- result_count = 0
56
-
57
- for i, line in enumerate(lines, 1):
58
- result = classify_line(line)
59
- result["line"] = i
60
- queues[user_id].put(f"data: {result}\n\n")
61
- result_count += 1
62
- time.sleep(0.2)
63
-
64
- queues[user_id].put("event: done\ndata: Session complete\n\n")
65
- sessions[user_id]['status'] = "done"
66
-
67
- time.sleep(2)
68
- del sessions[user_id]
69
- del queues[user_id]
70
-
71
- sessions[user_id]['status'] = "processing"
72
- lines = [line.strip() for line in text.strip().split('\n') if line.strip()]
73
- result_count = 0
74
-
75
- for i, line in enumerate(lines, 1):
76
- result = classify_line(line)
77
- result["line"] = i
78
- queues[user_id].put(f"data: {result}\n\n")
79
- result_count += 1
80
- time.sleep(0.2)
81
-
82
- queues[user_id].put("event: done\ndata: Session complete\n\n")
83
- sessions[user_id]['status'] = "done"
84
- time.sleep(2)
85
- del sessions[user_id]
86
- del queues[user_id]
87
 
88
  @app.route('/start-session', methods=['POST'])
89
  def start_session():
@@ -95,8 +81,7 @@ def start_session():
95
  return jsonify({"error": "user_id and text are required"}), 400
96
 
97
  if user_id in sessions:
98
- status = sessions[user_id]["status"]
99
- return jsonify({"message": f"Session already exists", "status": status}), 409
100
 
101
  sessions[user_id] = {"status": "pending"}
102
  queues[user_id] = queue.Queue()
@@ -112,34 +97,23 @@ def stream(user_id):
112
  def event_stream():
113
  while True:
114
  try:
115
- message = queues[user_id].get(timeout=60)
116
  yield message
117
- if "event: done" in message:
118
  break
119
  except queue.Empty:
120
  yield "event: timeout\ndata: No activity\n\n"
121
  break
122
 
123
- return Response(
124
- event_stream(),
125
- mimetype="text/event-stream",
126
- headers={
127
- "Cache-Control": "no-cache",
128
- "Connection": "keep-alive",
129
- "Access-Control-Allow-Origin": "*"
130
- }
131
- )
132
  @app.route('/status/<user_id>')
133
  def session_status(user_id):
134
- if user_id not in sessions:
135
- return jsonify({"status": "no_session"})
136
- return jsonify({
137
- "status": sessions[user_id]["status"]
138
- })
139
 
140
  @app.route('/')
141
  def index():
142
- return "alive yet !"
143
 
144
  if __name__ == '__main__':
145
- app.run(threaded=True,host='0.0.0.0', port=8080)
 
5
  import threading
6
  import time
7
  import queue
8
+ from nltk.tokenize import sent_tokenize
9
+ import os
 
 
 
 
 
10
 
11
  app = Flask(__name__)
12
 
13
+ # Health check endpoint
14
+ @app.route('/health')
15
+ def health_check():
16
+ return jsonify({"status": "healthy"}), 200
17
+
18
+ # Initialize model only when needed
19
+ def load_model():
20
+ model_name = "priyabrat/AI.or.Human.text.classification"
21
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
22
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
23
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+ model.to(device).eval()
25
+ return tokenizer, model
26
+
27
+ tokenizer, model = load_model()
28
  labels = ["AI-generated", "Human-written"]
29
  lock = threading.Lock()
30
 
 
31
  sessions = {}
32
  queues = {}
33
 
34
  def classify_line(text):
35
  with lock, torch.no_grad():
36
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512) # Reduced max_length
37
  inputs = {k: v.to(device) for k, v in inputs.items()}
38
  outputs = model(**inputs)
39
  probs = F.softmax(outputs.logits, dim=-1)
 
45
  "confidence": round(confidence * 100, 2)
46
  }
47
 
 
 
48
  def background_worker(user_id, text):
49
  sessions[user_id]['status'] = "processing"
50
+
51
+ try:
52
+ if '\n' not in text:
53
+ lines = sent_tokenize(text)
54
+ else:
55
+ lines = [line.strip() for line in text.strip().split('\n') if line.strip()]
56
+
57
+ for i, line in enumerate(lines, 1):
58
+ result = classify_line(line)
59
+ result["line"] = i
60
+ queues[user_id].put(f"data: {json.dumps(result)}\n\n")
61
+ time.sleep(0.1) # Reduced delay
62
+
63
+ queues[user_id].put("event: done\ndata: Session complete\n\n")
64
+ except Exception as e:
65
+ queues[user_id].put(f"event: error\ndata: {str(e)}\n\n")
66
+ finally:
67
+ sessions[user_id]['status'] = "done"
68
+ time.sleep(1)
69
+ if user_id in sessions:
70
+ del sessions[user_id]
71
+ if user_id in queues:
72
+ del queues[user_id]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  @app.route('/start-session', methods=['POST'])
75
  def start_session():
 
81
  return jsonify({"error": "user_id and text are required"}), 400
82
 
83
  if user_id in sessions:
84
+ return jsonify({"message": "Session already exists", "status": sessions[user_id]["status"]}), 409
 
85
 
86
  sessions[user_id] = {"status": "pending"}
87
  queues[user_id] = queue.Queue()
 
97
  def event_stream():
98
  while True:
99
  try:
100
+ message = queues[user_id].get(timeout=30) # Reduced timeout
101
  yield message
102
+ if "event: done" in message or "event: error" in message:
103
  break
104
  except queue.Empty:
105
  yield "event: timeout\ndata: No activity\n\n"
106
  break
107
 
108
+ return Response(event_stream(), mimetype="text/event-stream")
109
+
 
 
 
 
 
 
 
110
  @app.route('/status/<user_id>')
111
  def session_status(user_id):
112
+ return jsonify({"status": sessions.get(user_id, {}).get("status", "no_session")})
 
 
 
 
113
 
114
  @app.route('/')
115
  def index():
116
+ return "Server is running!"
117
 
118
  if __name__ == '__main__':
119
+ app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 8080)))