Jobsforce committed on
Commit
28ff626
·
verified ·
1 Parent(s): ae41e6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -124
app.py CHANGED
@@ -1,148 +1,92 @@
1
- from flask import Flask, request, jsonify, Response
2
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
- import torch
4
- import torch.nn.functional as F
5
- import threading
6
- import time
7
- import queue
8
  from nltk.tokenize import sent_tokenize
9
- import os
10
- import json
11
  import logging
 
 
 
12
 
13
- app = Flask(__name__)
14
 
15
- # Configure logging
16
- logging.basicConfig(
17
- level=logging.INFO,
18
- format='%(asctime)s %(levelname)s %(threadName)s %(message)s'
19
- )
20
  logger = logging.getLogger(__name__)
21
 
22
- # Lazy-loaded shared state
23
- model = None
24
- tokenizer = None
25
- device = None
26
- labels = ["AI-generated", "Human-written"]
27
- lock = threading.Lock()
28
-
29
- sessions = {}
30
- queues = {}
31
-
32
- @app.route('/')
33
- def index():
34
- logger.info("Index page requested")
35
- return "Server is running!"
36
-
37
- @app.route('/health')
38
- def health_check():
39
- logger.info("Health check requested")
40
- return jsonify({"status": "healthy"}), 200
41
-
42
- def load_model():
43
- global tokenizer, model, device
44
- if model is None or tokenizer is None:
45
- model_name = "priyabrat/AI.or.Human.text.classification"
46
- logger.info(f"Loading model and tokenizer from {model_name}")
47
- tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir='/app/hf_cache')
48
- model = AutoModelForSequenceClassification.from_pretrained(model_name, cache_dir='/app/hf_cache')
49
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
- model.to(device).eval()
51
- logger.info(f"Model loaded on device: {device}")
52
- else:
53
- logger.info("Model already loaded.")
54
-
55
- def classify_line(text):
56
- with lock, torch.no_grad():
57
- load_model()
58
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
59
- inputs = {k: v.to(device) for k, v in inputs.items()}
60
- outputs = model(**inputs)
61
- probs = F.softmax(outputs.logits, dim=-1)
62
- pred = torch.argmax(probs, dim=-1).item()
63
- confidence = probs[0][pred].item()
64
- return {
65
- "text": text.strip(),
66
- "label": labels[pred],
67
- "confidence": round(confidence * 100, 2)
68
- }
69
-
70
- def background_worker(user_id, text):
71
- logger.info(f"Processing started for user_id={user_id}")
72
- sessions[user_id]['status'] = "processing"
73
 
74
- try:
75
- if '\n' not in text:
76
- lines = sent_tokenize(text)
77
- else:
78
- lines = [line.strip() for line in text.strip().split('\n') if line.strip()]
79
-
80
- for i, line in enumerate(lines, 1):
81
- result = classify_line(line)
82
- logger.info(f"user_id={user_id} line={i} classified as {result['label']} ({result['confidence']}%)")
83
- result["line"] = i
84
- queues[user_id].put(f"data: {json.dumps(result)}\n\n")
85
- time.sleep(0.1)
86
-
87
- queues[user_id].put("event: done\ndata: Session complete\n\n")
88
- except Exception as e:
89
- logger.error(f"Error processing user_id={user_id}: {e}")
90
- queues[user_id].put(f"event: error\ndata: {str(e)}\n\n")
91
- finally:
92
- sessions[user_id]['status'] = "done"
93
- logger.info(f"Processing finished for user_id={user_id}")
94
- time.sleep(1)
95
- sessions.pop(user_id, None)
96
- queues.pop(user_id, None)
97
 
98
- @app.route('/start-session', methods=['POST'])
99
- def start_session():
100
- data = request.get_json()
101
  user_id = data.get("user_id")
102
  text = data.get("text")
103
 
104
  if not user_id or not text:
105
- logger.warning("Missing user_id or text in start-session request")
106
- return jsonify({"error": "user_id and text are required"}), 400
107
 
108
  if user_id in sessions:
109
- logger.warning(f"Session already exists for user_id={user_id}")
110
- return jsonify({"message": "Session already exists", "status": sessions[user_id]["status"]}), 409
111
 
112
- logger.info(f"Starting session for user_id={user_id}")
113
- sessions[user_id] = {"status": "pending"}
114
- queues[user_id] = queue.Queue()
115
- threading.Thread(target=background_worker, args=(user_id, text), daemon=True).start()
116
 
117
- return jsonify({"message": "Session started", "status": "pending"}), 202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
- @app.route('/stream/<user_id>')
120
- def stream(user_id):
121
  if user_id not in sessions:
122
- logger.warning(f"No active session for user_id={user_id} in stream request")
123
- return jsonify({"error": "No active session for this user"}), 404
124
 
125
- def event_stream():
126
  while True:
127
- try:
128
- message = queues[user_id].get(timeout=30)
129
- yield message
130
- if "event: done" in message or "event: error" in message:
131
- logger.info(f"Stream ended for user_id={user_id} with message: {message.strip()}")
132
- break
133
- except queue.Empty:
134
- logger.warning(f"Stream timeout for user_id={user_id}")
135
- yield "event: timeout\ndata: No activity\n\n"
136
  break
 
137
 
138
- return Response(event_stream(), mimetype="text/event-stream")
139
 
140
- @app.route('/status/<user_id>')
141
- def session_status(user_id):
142
- status = sessions.get(user_id, {}).get("status", "no_session")
143
- logger.info(f"Status request for user_id={user_id}: {status}")
144
- return jsonify({"status": status})
145
 
146
- if __name__ == '__main__':
147
- logger.info("Starting Flask app")
148
- app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 8080)))
 
1
+ from fastapi import FastAPI, Request, BackgroundTasks
2
+ from fastapi.responses import StreamingResponse, JSONResponse
3
+ from transformers import pipeline
 
 
 
 
4
  from nltk.tokenize import sent_tokenize
5
+ import asyncio
6
+ import uuid
7
  import logging
8
+ import json
9
+ from typing import Dict
10
+ from collections import deque
11
 
12
+ app = FastAPI()
13
 
14
+ # Logging
15
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
 
 
 
16
  logger = logging.getLogger(__name__)
17
 
18
+ # Load pipeline once at startup
19
+ classifier = pipeline("text-classification", model="priyabrat/AI.or.Human.text.classification")
20
+ sessions: Dict[str, Dict] = {}
21
+ queues: Dict[str, deque] = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ @app.get("/health")
24
+ async def health():
25
+ return {"status": "healthy"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ @app.post("/start-session")
28
+ async def start_session(request: Request, background_tasks: BackgroundTasks):
29
+ data = await request.json()
30
  user_id = data.get("user_id")
31
  text = data.get("text")
32
 
33
  if not user_id or not text:
34
+ return JSONResponse(content={"error": "user_id and text are required"}, status_code=400)
 
35
 
36
  if user_id in sessions:
37
+ return JSONResponse(content={"message": "Session already exists", "status": sessions[user_id]["status"]}, status_code=409)
 
38
 
39
+ sessions[user_id] = {"status": "processing"}
40
+ queues[user_id] = deque()
 
 
41
 
42
+ background_tasks.add_task(process_text, user_id, text)
43
+ return {"message": "Session started", "status": "processing"}
44
+
45
+ async def process_text(user_id: str, text: str):
46
+ try:
47
+ lines = sent_tokenize(text) if '\n' not in text else [l.strip() for l in text.strip().split('\n') if l.strip()]
48
+ for idx, line in enumerate(lines, 1):
49
+ result = classifier(line)[0]
50
+ label = result['label']
51
+ confidence = round(result['score'] * 100, 2)
52
+ payload = {
53
+ "line": idx,
54
+ "text": line,
55
+ "label": "AI-generated" if label == "LABEL_0" else "Human-written",
56
+ "confidence": confidence
57
+ }
58
+ queues[user_id].append(f"data: {json.dumps(payload)}\n\n")
59
+ await asyncio.sleep(0.1)
60
+
61
+ queues[user_id].append("event: done\ndata: Session complete\n\n")
62
+ except Exception as e:
63
+ logger.error(f"Error: {e}")
64
+ queues[user_id].append(f"event: error\ndata: {str(e)}\n\n")
65
+ finally:
66
+ sessions[user_id]['status'] = "done"
67
+ await asyncio.sleep(1)
68
+ sessions.pop(user_id, None)
69
+ queues.pop(user_id, None)
70
 
71
+ @app.get("/stream/{user_id}")
72
+ async def stream(user_id: str):
73
  if user_id not in sessions:
74
+ return JSONResponse(content={"error": "No active session"}, status_code=404)
 
75
 
76
+ async def event_stream():
77
  while True:
78
+ if queues[user_id]:
79
+ yield queues[user_id].popleft()
80
+ elif sessions[user_id]['status'] == 'done':
 
 
 
 
 
 
81
  break
82
+ await asyncio.sleep(0.1)
83
 
84
+ return StreamingResponse(event_stream(), media_type="text/event-stream")
85
 
86
+ @app.get("/status/{user_id}")
87
+ async def status(user_id: str):
88
+ return {"status": sessions.get(user_id, {}).get("status", "no_session")}
 
 
89
 
90
+ @app.get("/")
91
+ async def index():
92
+ return {"message": "Server is running"}