| from flask import Flask, request, jsonify, Response |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| import torch |
| import torch.nn.functional as F |
| import threading |
| import time |
| import queue |
| from nltk.tokenize import sent_tokenize |
| import os |
| import json |
| import logging |
|
|
# Flask application serving the AI-vs-human text-classification API.
app = Flask(__name__)


# Include thread names in log records so interleaved request/worker
# activity can be traced across the background threads started below.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(threadName)s %(message)s'
)
logger = logging.getLogger(__name__)


# Shared model state, loaded lazily on first classification (see load_model).
model = None
tokenizer = None
device = None
# Class names indexed by predicted logit index.
# NOTE(review): assumes the model emits logit 0 = AI-generated,
# logit 1 = Human-written — confirm against the model card.
labels = ["AI-generated", "Human-written"]
# Serializes model loading and inference across worker threads.
lock = threading.Lock()

# Per-user session metadata and SSE message queues, keyed by user_id.
# NOTE(review): plain dicts mutated from request and worker threads;
# single get/set/pop operations are atomic in CPython, but there is no
# cross-key consistency guarantee between `sessions` and `queues`.
sessions = {}
queues = {}
|
|
@app.route('/')
def index():
    """Root endpoint: plain-text confirmation that the server is up."""
    logger.info("Index page requested")
    return "Server is running!"
|
|
@app.route('/health')
def health_check():
    """Liveness probe: always reports healthy with HTTP 200."""
    logger.info("Health check requested")
    body = jsonify({"status": "healthy"})
    return body, 200
|
|
def load_model():
    """Lazily initialize the shared tokenizer, model, and device.

    No-op after the first successful load. Does no locking itself; the
    visible caller (classify_line) invokes it while holding `lock`.
    """
    global tokenizer, model, device
    # Guard clause: nothing to do once both pieces are in place.
    if model is not None and tokenizer is not None:
        logger.info("Model already loaded.")
        return
    model_name = "priyabrat/AI.or.Human.text.classification"
    logger.info(f"Loading model and tokenizer from {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir='/app/hf_cache')
    model = AutoModelForSequenceClassification.from_pretrained(model_name, cache_dir='/app/hf_cache')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()
    logger.info(f"Model loaded on device: {device}")
|
|
def classify_line(text):
    """Classify one piece of text as AI-generated or human-written.

    Returns a dict with the stripped text, the predicted label, and the
    confidence as a percentage rounded to two decimal places.
    """
    # Hold the lock for the whole forward pass so concurrent workers
    # never share the model, and skip gradient tracking for inference.
    with lock, torch.no_grad():
        load_model()
        encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
        logits = model(**encoded).logits
        probabilities = F.softmax(logits, dim=-1)
        predicted = torch.argmax(probabilities, dim=-1).item()
        score = probabilities[0, predicted].item()
    return {
        "text": text.strip(),
        "label": labels[predicted],
        "confidence": round(score * 100, 2),
    }
|
|
def background_worker(user_id, text):
    """Classify `text` piece-by-piece and push SSE-formatted results.

    Splits on newlines when present, otherwise on sentence boundaries.
    Each result is put on the user's queue as an SSE `data:` frame,
    followed by a terminal `done` (or `error`) event; the session and
    queue entries are removed shortly after completion.
    """
    logger.info(f"Processing started for user_id={user_id}")
    # Capture local references once: this function pops the module-level
    # entries in `finally`, and re-indexing sessions[user_id] /
    # queues[user_id] later (e.g. in the except handler) could raise
    # KeyError if cleanup races with another code path.
    session = sessions.get(user_id)
    out_queue = queues.get(user_id)
    if session is None or out_queue is None:
        logger.warning(f"Missing session or queue for user_id={user_id}; worker aborting")
        return
    session['status'] = "processing"

    try:
        if '\n' not in text:
            # Single paragraph: fall back to sentence-level units.
            lines = sent_tokenize(text)
        else:
            lines = [line.strip() for line in text.strip().split('\n') if line.strip()]

        for i, line in enumerate(lines, 1):
            result = classify_line(line)
            logger.info(f"user_id={user_id} line={i} classified as {result['label']} ({result['confidence']}%)")
            result["line"] = i
            out_queue.put(f"data: {json.dumps(result)}\n\n")
            # Brief pause so the SSE stream paces results for the client.
            time.sleep(0.1)

        out_queue.put("event: done\ndata: Session complete\n\n")
    except Exception as e:
        logger.error(f"Error processing user_id={user_id}: {e}")
        out_queue.put(f"event: error\ndata: {str(e)}\n\n")
    finally:
        session['status'] = "done"
        logger.info(f"Processing finished for user_id={user_id}")
        # Give the SSE stream a moment to drain the final event before
        # the session and queue entries are discarded.
        time.sleep(1)
        sessions.pop(user_id, None)
        queues.pop(user_id, None)
|
|
@app.route('/start-session', methods=['POST'])
def start_session():
    """Create a classification session and start a background worker.

    Expects a JSON body {"user_id": ..., "text": ...}. Returns 202 on
    success, 400 when fields are missing or the body is not valid JSON,
    and 409 when a session already exists for this user.
    """
    # silent=True makes get_json return None (instead of raising/aborting)
    # on a missing or malformed JSON body, so the clean 400 below fires
    # rather than an AttributeError on None.get(...).
    data = request.get_json(silent=True) or {}
    user_id = data.get("user_id")
    text = data.get("text")

    if not user_id or not text:
        logger.warning("Missing user_id or text in start-session request")
        return jsonify({"error": "user_id and text are required"}), 400

    if user_id in sessions:
        logger.warning(f"Session already exists for user_id={user_id}")
        return jsonify({"message": "Session already exists", "status": sessions[user_id]["status"]}), 409

    logger.info(f"Starting session for user_id={user_id}")
    # Create the queue before starting the worker so the worker's first
    # put() cannot find it missing.
    sessions[user_id] = {"status": "pending"}
    queues[user_id] = queue.Queue()
    threading.Thread(target=background_worker, args=(user_id, text), daemon=True).start()

    return jsonify({"message": "Session started", "status": "pending"}), 202
|
|
@app.route('/stream/<user_id>')
def stream(user_id):
    """SSE endpoint streaming per-line classification results for a user.

    Returns 404 when no session/queue exists; otherwise streams queued
    messages until a `done`/`error` event or a 30s inactivity timeout.
    """
    if user_id not in sessions:
        logger.warning(f"No active session for user_id={user_id} in stream request")
        return jsonify({"error": "No active session for this user"}), 404

    # Capture the queue reference once: the worker pops queues[user_id]
    # about a second after finishing, so re-indexing the dict inside the
    # generator loop could raise KeyError mid-stream.
    user_queue = queues.get(user_id)
    if user_queue is None:
        logger.warning(f"No queue for user_id={user_id} in stream request")
        return jsonify({"error": "No active session for this user"}), 404

    def event_stream():
        while True:
            try:
                message = user_queue.get(timeout=30)
            except queue.Empty:
                logger.warning(f"Stream timeout for user_id={user_id}")
                yield "event: timeout\ndata: No activity\n\n"
                break
            yield message
            if "event: done" in message or "event: error" in message:
                logger.info(f"Stream ended for user_id={user_id} with message: {message.strip()}")
                break

    return Response(event_stream(), mimetype="text/event-stream")
|
|
@app.route('/status/<user_id>')
def session_status(user_id):
    """Report the user's session status, or "no_session" when absent."""
    session = sessions.get(user_id, {})
    status = session.get("status", "no_session")
    logger.info(f"Status request for user_id={user_id}: {status}")
    return jsonify({"status": status})
|
|
if __name__ == '__main__':
    # Entry point: bind on all interfaces; port from $PORT (default 8080).
    logger.info("Starting Flask app")
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 8080)))
|
|