# plagcheck / app.py (Hugging Face repository by Jobsforce)
# Last commit: ae41e6f (verified) "Update app.py" - file size 5.28 kB
from flask import Flask, request, jsonify, Response
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import torch.nn.functional as F
import threading
import time
import queue
from nltk.tokenize import sent_tokenize
import os
import json
import logging
app = Flask(__name__)

# Configure logging: INFO and above, tagged with the emitting thread's name
# so the per-session background workers can be told apart in the output.
logging.basicConfig(
    format='%(asctime)s %(levelname)s %(threadName)s %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Lazy-loaded shared state: stays None until load_model() first populates it.
model = tokenizer = device = None
# Human-readable names indexed by the classifier's predicted class id.
labels = ["AI-generated", "Human-written"]
# Serializes model loading and inference across worker threads.
lock = threading.Lock()
# Per-user session metadata ({"status": ...}) and SSE message queues,
# both keyed by user_id.
sessions = {}
queues = {}
@app.route('/')
def index():
    """Plain-text liveness endpoint: logs the hit and returns a fixed banner."""
    logger.info("Index page requested")
    banner = "Server is running!"
    return banner
@app.route('/health')
def health_check():
    """Report service health as ``{"status": "healthy"}`` with HTTP 200."""
    logger.info("Health check requested")
    return jsonify(status="healthy"), 200
def load_model():
    """Lazily load the classifier model and tokenizer into module globals.

    On the first call (or whenever either global is still None) this loads
    ``priyabrat/AI.or.Human.text.classification`` via ``from_pretrained``
    with ``/app/hf_cache`` as the cache directory. It then moves the model
    to CUDA when available (CPU otherwise) and switches it to eval mode.
    Later calls only log.

    This function takes no lock itself; ``classify_line`` calls it while
    holding the module-level ``lock``.
    """
    global tokenizer, model, device
    if model is not None and tokenizer is not None:
        logger.info("Model already loaded.")
        return

    checkpoint = "priyabrat/AI.or.Human.text.classification"
    cache_dir = '/app/hf_cache'
    logger.info(f"Loading model and tokenizer from {checkpoint}")
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, cache_dir=cache_dir)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint, cache_dir=cache_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()
    logger.info(f"Model loaded on device: {device}")
def classify_line(text):
    """Classify a single piece of text with the shared model.

    Model loading and inference run while holding the module ``lock`` and
    under ``torch.no_grad()``. Input is truncated to 512 tokens.

    Returns:
        dict with ``text`` (the input, stripped), ``label`` (an entry of
        ``labels`` picked by the highest-probability class) and ``confidence``
        (that class's softmax probability as a percentage, rounded to 2
        decimals).

    NOTE(review): assumes the model's class 0 is "AI-generated" and class 1 is
    "Human-written", matching ``labels`` - confirm against the model's
    ``config.id2label``.
    """
    with lock:
        with torch.no_grad():
            load_model()
            encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
            batch = {name: tensor.to(device) for name, tensor in encoded.items()}
            logits = model(**batch).logits
            probabilities = F.softmax(logits, dim=-1)
            best = probabilities.argmax(dim=-1).item()
            score = probabilities[0, best].item()
    return {
        "text": text.strip(),
        "label": labels[best],
        "confidence": round(score * 100, 2),
    }
def _split_into_lines(text):
    """Return the pieces of *text* to classify, in order.

    Text containing a newline is split on newline characters, with
    surrounding whitespace and blank lines dropped. Otherwise NLTK's
    ``sent_tokenize`` splits it into sentences.
    """
    if '\n' not in text:
        return sent_tokenize(text)
    return [line.strip() for line in text.strip().split('\n') if line.strip()]


def _format_sse_error(exc):
    """Format *exc* as an SSE ``error`` event.

    Each line of the exception message gets its own ``data:`` field. A raw
    newline inside one ``data:`` line would end that field early, and the
    rest of the message would be misparsed by the client. Clients rejoin
    multiple ``data:`` fields with newlines, so the text arrives intact.
    """
    message_lines = str(exc).splitlines() or [""]
    data = "\n".join(f"data: {part}" for part in message_lines)
    return f"event: error\n{data}\n\n"


def background_worker(user_id, text):
    """Classify *text* piece by piece and publish each result to the session queue.

    ``start_session`` runs this on a daemon thread. Messages it publishes:

    * one ``data: <json>`` SSE message per piece, holding the
      ``classify_line`` result plus a 1-based ``line`` number;
    * then a final ``event: done`` message, or an ``event: error`` message
      if anything raises.

    The session status moves from ``"processing"`` to ``"done"``. About one
    second after finishing, the session and queue entries are removed so the
    ``user_id`` can start a new session.

    Args:
        user_id: Key of the entries ``start_session`` created in ``sessions``
            and ``queues``.
        text: Raw input text; see ``_split_into_lines`` for how it is split.
    """
    logger.info(f"Processing started for user_id={user_id}")
    # Keep direct references so this worker only ever writes to, and later
    # deletes, its own session's entries - never a newer session's.
    session = sessions[user_id]
    out = queues[user_id]
    session['status'] = "processing"
    try:
        for i, line in enumerate(_split_into_lines(text), 1):
            result = classify_line(line)
            logger.info(f"user_id={user_id} line={i} classified as {result['label']} ({result['confidence']}%)")
            result["line"] = i
            out.put(f"data: {json.dumps(result)}\n\n")
            # Presumably paces delivery so results trickle out incrementally.
            time.sleep(0.1)
        out.put("event: done\ndata: Session complete\n\n")
    except Exception as e:
        # Top-level boundary of this thread: keep the traceback in the log and
        # report the failure to the stream client.
        logger.exception(f"Error processing user_id={user_id}: {e}")
        out.put(_format_sse_error(e))
    finally:
        session['status'] = "done"
        logger.info(f"Processing finished for user_id={user_id}")
        # Grace period before cleanup, presumably so a late stream client can
        # still connect and drain the queue.
        time.sleep(1)
        # Remove the queue BEFORE the session. While our session entry exists,
        # start_session rejects this user_id, so the queue cannot have been
        # replaced yet. Popping the session first would let a new session
        # register and then lose its fresh queue here.
        if queues.get(user_id) is out:
            queues.pop(user_id, None)
        if sessions.get(user_id) is session:
            sessions.pop(user_id, None)
@app.route('/start-session', methods=['POST'])
def start_session():
    """Create a classification session and start its background worker.

    Expects a JSON object body with non-empty string ``user_id`` and ``text``.
    A non-zero integer ``user_id`` is accepted and converted to ``str``.

    Returns:
        202 ``{"message": "Session started", "status": "pending"}`` on success.
        400 ``{"error": ...}`` if the body is not a JSON object, or either
        field is missing, empty, or of the wrong type.
        409 ``{"message": "Session already exists", "status": ...}`` if a
        session for ``user_id`` is still active.
    """
    # silent=True: a missing, non-JSON or malformed body yields None (-> 400
    # below) instead of an unhandled error / AttributeError on data.get.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    user_id = data.get("user_id")
    text = data.get("text")
    # Route parameters (/stream/<user_id>, /status/<user_id>) are always
    # strings, so a numeric id must be stored as str to be found again.
    if isinstance(user_id, int) and not isinstance(user_id, bool) and user_id:
        user_id = str(user_id)
    if not (isinstance(user_id, str) and user_id and isinstance(text, str) and text):
        logger.warning("Missing user_id or text in start-session request")
        return jsonify({"error": "user_id and text are required"}), 400

    new_session = {"status": "pending"}
    # Check-and-insert in a single dict call (atomic under CPython's GIL for
    # str keys), so two concurrent requests for one user_id cannot both
    # start a worker.
    existing = sessions.setdefault(user_id, new_session)
    if existing is not new_session:
        logger.warning(f"Session already exists for user_id={user_id}")
        return jsonify({"message": "Session already exists", "status": existing["status"]}), 409

    logger.info(f"Starting session for user_id={user_id}")
    queues[user_id] = queue.Queue()
    threading.Thread(target=background_worker, args=(user_id, text), daemon=True).start()
    return jsonify({"message": "Session started", "status": "pending"}), 202
@app.route('/stream/<user_id>')
def stream(user_id):
    """Stream a session's results to the client as Server-Sent Events.

    Messages produced by ``background_worker`` are forwarded verbatim. The
    stream ends after the terminal ``event: done`` / ``event: error``
    message, or with an ``event: timeout`` message if nothing arrives for
    30 seconds.

    Returns:
        A ``text/event-stream`` response, or a 404 JSON error if there is no
        active session for ``user_id``.

    NOTE: reading removes messages from the queue, so two concurrent streams
    for one user_id would each receive only part of the results.
    """
    # Resolve the queue once, up front. The worker deletes queues[user_id]
    # shortly after finishing, and indexing the dict inside the generator
    # could then raise KeyError halfway through a response.
    messages = queues.get(user_id)
    if user_id not in sessions or messages is None:
        logger.warning(f"No active session for user_id={user_id} in stream request")
        return jsonify({"error": "No active session for this user"}), 404

    def event_stream():
        while True:
            try:
                message = messages.get(timeout=30)
            except queue.Empty:
                logger.warning(f"Stream timeout for user_id={user_id}")
                yield "event: timeout\ndata: No activity\n\n"
                break
            yield message
            # Terminal messages *begin* with their event line. A substring test
            # would also match user text echoed inside a ``data:`` JSON payload
            # (e.g. an input line reading "event: done") and end the stream early.
            if message.startswith(("event: done", "event: error")):
                logger.info(f"Stream ended for user_id={user_id} with message: {message.strip()}")
                break

    return Response(event_stream(), mimetype="text/event-stream")
@app.route('/status/<user_id>')
def session_status(user_id):
    """Return ``{"status": ...}`` for ``user_id``'s session, or ``"no_session"``."""
    entry = sessions.get(user_id, {})
    current = entry.get("status", "no_session")
    logger.info(f"Status request for user_id={user_id}: {current}")
    return jsonify({"status": current})
if __name__ == '__main__':
    # Development-server entry point; the PORT env var overrides the default 8080.
    logger.info("Starting Flask app")
    listen_port = int(os.environ.get('PORT', 8080))
    app.run(host='0.0.0.0', port=listen_port)