ravi848101 commited on
Commit
15b499b
·
1 Parent(s): d0b75fd

Add initial implementation of Flask app for AI text classification and requirements file

Browse files
Files changed (2) hide show
  1. app.py +145 -0
  2. requirenments.txt +31 -0
app.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify, Response
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import threading
6
+ import time
7
+ import queue
8
+ from nltk.tokenize import sent_tokenize
9
+ import nltk
10
+ try:
11
+ nltk.data.find('tokenizers/punkt')
12
+ except LookupError:
13
+ nltk.download('punkt')
14
+
15
+
16
+ app = Flask(__name__)
17
+
18
+
19
+ model_name = "priyabrat/AI.or.Human.text.classification"
20
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
21
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ model.to(device).eval()
24
+
25
+ labels = ["AI-generated", "Human-written"]
26
+ lock = threading.Lock()
27
+
28
+
29
+ sessions = {}
30
+ queues = {}
31
+
32
+ def classify_line(text):
33
+ with lock, torch.no_grad():
34
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=10000)
35
+ inputs = {k: v.to(device) for k, v in inputs.items()}
36
+ outputs = model(**inputs)
37
+ probs = F.softmax(outputs.logits, dim=-1)
38
+ pred = torch.argmax(probs, dim=-1).item()
39
+ confidence = probs[0][pred].item()
40
+ return {
41
+ "text": text.strip(),
42
+ "label": labels[pred],
43
+ "confidence": round(confidence * 100, 2)
44
+ }
45
+
46
+
47
+
48
+ def background_worker(user_id, text):
49
+ sessions[user_id]['status'] = "processing"
50
+ if '\n' not in text:
51
+ lines = sent_tokenize(text)
52
+ else:
53
+ lines = [line.strip() for line in text.strip().split('\n') if line.strip()]
54
+
55
+ result_count = 0
56
+
57
+ for i, line in enumerate(lines, 1):
58
+ result = classify_line(line)
59
+ result["line"] = i
60
+ queues[user_id].put(f"data: {result}\n\n")
61
+ result_count += 1
62
+ time.sleep(0.2)
63
+
64
+ queues[user_id].put("event: done\ndata: Session complete\n\n")
65
+ sessions[user_id]['status'] = "done"
66
+
67
+ time.sleep(2)
68
+ del sessions[user_id]
69
+ del queues[user_id]
70
+
71
+ sessions[user_id]['status'] = "processing"
72
+ lines = [line.strip() for line in text.strip().split('\n') if line.strip()]
73
+ result_count = 0
74
+
75
+ for i, line in enumerate(lines, 1):
76
+ result = classify_line(line)
77
+ result["line"] = i
78
+ queues[user_id].put(f"data: {result}\n\n")
79
+ result_count += 1
80
+ time.sleep(0.2)
81
+
82
+ queues[user_id].put("event: done\ndata: Session complete\n\n")
83
+ sessions[user_id]['status'] = "done"
84
+ time.sleep(2)
85
+ del sessions[user_id]
86
+ del queues[user_id]
87
+
88
+ @app.route('/start-session', methods=['POST'])
89
+ def start_session():
90
+ data = request.get_json()
91
+ user_id = data.get("user_id")
92
+ text = data.get("text")
93
+
94
+ if not user_id or not text:
95
+ return jsonify({"error": "user_id and text are required"}), 400
96
+
97
+ if user_id in sessions:
98
+ status = sessions[user_id]["status"]
99
+ return jsonify({"message": f"Session already exists", "status": status}), 409
100
+
101
+ sessions[user_id] = {"status": "pending"}
102
+ queues[user_id] = queue.Queue()
103
+ threading.Thread(target=background_worker, args=(user_id, text), daemon=True).start()
104
+
105
+ return jsonify({"message": "Session started", "status": "pending"}), 202
106
+
107
+ @app.route('/stream/<user_id>')
108
+ def stream(user_id):
109
+ if user_id not in sessions:
110
+ return jsonify({"error": "No active session for this user"}), 404
111
+
112
+ def event_stream():
113
+ while True:
114
+ try:
115
+ message = queues[user_id].get(timeout=60)
116
+ yield message
117
+ if "event: done" in message:
118
+ break
119
+ except queue.Empty:
120
+ yield "event: timeout\ndata: No activity\n\n"
121
+ break
122
+
123
+ return Response(
124
+ event_stream(),
125
+ mimetype="text/event-stream",
126
+ headers={
127
+ "Cache-Control": "no-cache",
128
+ "Connection": "keep-alive",
129
+ "Access-Control-Allow-Origin": "*"
130
+ }
131
+ )
132
+ @app.route('/status/<user_id>')
133
+ def session_status(user_id):
134
+ if user_id not in sessions:
135
+ return jsonify({"status": "no_session"})
136
+ return jsonify({
137
+ "status": sessions[user_id]["status"]
138
+ })
139
+
140
+ @app.route('/')
141
+ def index():
142
+ return "alive yet !"
143
+
144
+ if __name__ == '__main__':
145
+ app.run(debug=True, threaded=True,host='0.0.0.0', port=5000)
requirenments.txt ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ blinker==1.9.0
2
+ certifi==2025.4.26
3
+ charset-normalizer==3.4.2
4
+ click==8.2.1
5
+ colorama==0.4.6
6
+ filelock==3.18.0
7
+ Flask==3.1.1
8
+ fsspec==2025.5.1
9
+ huggingface-hub==0.32.3
10
+ idna==3.10
11
+ itsdangerous==2.2.0
12
+ Jinja2==3.1.6
13
+ joblib==1.5.1
14
+ MarkupSafe==3.0.2
15
+ mpmath==1.3.0
16
+ networkx==3.4.2
17
+ nltk==3.9.1
18
+ numpy==2.2.6
19
+ packaging==25.0
20
+ PyYAML==6.0.2
21
+ regex==2024.11.6
22
+ requests==2.32.3
23
+ safetensors==0.5.3
24
+ sympy==1.14.0
25
+ tokenizers==0.21.1
26
+ torch==2.7.0
27
+ tqdm==4.67.1
28
+ transformers==4.52.4
29
+ typing_extensions==4.13.2
30
+ urllib3==2.4.0
31
+ Werkzeug==3.1.3