Update app.py
Browse files
app.py
CHANGED
|
@@ -1,148 +1,92 @@
|
|
| 1 |
-
from
|
| 2 |
-
from
|
| 3 |
-
import
|
| 4 |
-
import torch.nn.functional as F
|
| 5 |
-
import threading
|
| 6 |
-
import time
|
| 7 |
-
import queue
|
| 8 |
from nltk.tokenize import sent_tokenize
|
| 9 |
-
import
|
| 10 |
-
import
|
| 11 |
import logging
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
-
app =
|
| 14 |
|
| 15 |
-
#
|
| 16 |
-
logging.basicConfig(
|
| 17 |
-
level=logging.INFO,
|
| 18 |
-
format='%(asctime)s %(levelname)s %(threadName)s %(message)s'
|
| 19 |
-
)
|
| 20 |
logger = logging.getLogger(__name__)
|
| 21 |
|
| 22 |
-
#
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
labels = ["AI-generated", "Human-written"]
|
| 27 |
-
lock = threading.Lock()
|
| 28 |
-
|
| 29 |
-
sessions = {}
|
| 30 |
-
queues = {}
|
| 31 |
-
|
| 32 |
-
@app.route('/')
|
| 33 |
-
def index():
|
| 34 |
-
logger.info("Index page requested")
|
| 35 |
-
return "Server is running!"
|
| 36 |
-
|
| 37 |
-
@app.route('/health')
|
| 38 |
-
def health_check():
|
| 39 |
-
logger.info("Health check requested")
|
| 40 |
-
return jsonify({"status": "healthy"}), 200
|
| 41 |
-
|
| 42 |
-
def load_model():
|
| 43 |
-
global tokenizer, model, device
|
| 44 |
-
if model is None or tokenizer is None:
|
| 45 |
-
model_name = "priyabrat/AI.or.Human.text.classification"
|
| 46 |
-
logger.info(f"Loading model and tokenizer from {model_name}")
|
| 47 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir='/app/hf_cache')
|
| 48 |
-
model = AutoModelForSequenceClassification.from_pretrained(model_name, cache_dir='/app/hf_cache')
|
| 49 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 50 |
-
model.to(device).eval()
|
| 51 |
-
logger.info(f"Model loaded on device: {device}")
|
| 52 |
-
else:
|
| 53 |
-
logger.info("Model already loaded.")
|
| 54 |
-
|
| 55 |
-
def classify_line(text):
|
| 56 |
-
with lock, torch.no_grad():
|
| 57 |
-
load_model()
|
| 58 |
-
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
|
| 59 |
-
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 60 |
-
outputs = model(**inputs)
|
| 61 |
-
probs = F.softmax(outputs.logits, dim=-1)
|
| 62 |
-
pred = torch.argmax(probs, dim=-1).item()
|
| 63 |
-
confidence = probs[0][pred].item()
|
| 64 |
-
return {
|
| 65 |
-
"text": text.strip(),
|
| 66 |
-
"label": labels[pred],
|
| 67 |
-
"confidence": round(confidence * 100, 2)
|
| 68 |
-
}
|
| 69 |
-
|
| 70 |
-
def background_worker(user_id, text):
|
| 71 |
-
logger.info(f"Processing started for user_id={user_id}")
|
| 72 |
-
sessions[user_id]['status'] = "processing"
|
| 73 |
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
else:
|
| 78 |
-
lines = [line.strip() for line in text.strip().split('\n') if line.strip()]
|
| 79 |
-
|
| 80 |
-
for i, line in enumerate(lines, 1):
|
| 81 |
-
result = classify_line(line)
|
| 82 |
-
logger.info(f"user_id={user_id} line={i} classified as {result['label']} ({result['confidence']}%)")
|
| 83 |
-
result["line"] = i
|
| 84 |
-
queues[user_id].put(f"data: {json.dumps(result)}\n\n")
|
| 85 |
-
time.sleep(0.1)
|
| 86 |
-
|
| 87 |
-
queues[user_id].put("event: done\ndata: Session complete\n\n")
|
| 88 |
-
except Exception as e:
|
| 89 |
-
logger.error(f"Error processing user_id={user_id}: {e}")
|
| 90 |
-
queues[user_id].put(f"event: error\ndata: {str(e)}\n\n")
|
| 91 |
-
finally:
|
| 92 |
-
sessions[user_id]['status'] = "done"
|
| 93 |
-
logger.info(f"Processing finished for user_id={user_id}")
|
| 94 |
-
time.sleep(1)
|
| 95 |
-
sessions.pop(user_id, None)
|
| 96 |
-
queues.pop(user_id, None)
|
| 97 |
|
| 98 |
-
@app.
|
| 99 |
-
def start_session():
|
| 100 |
-
data = request.
|
| 101 |
user_id = data.get("user_id")
|
| 102 |
text = data.get("text")
|
| 103 |
|
| 104 |
if not user_id or not text:
|
| 105 |
-
|
| 106 |
-
return jsonify({"error": "user_id and text are required"}), 400
|
| 107 |
|
| 108 |
if user_id in sessions:
|
| 109 |
-
|
| 110 |
-
return jsonify({"message": "Session already exists", "status": sessions[user_id]["status"]}), 409
|
| 111 |
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
queues[user_id] = queue.Queue()
|
| 115 |
-
threading.Thread(target=background_worker, args=(user_id, text), daemon=True).start()
|
| 116 |
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
-
@app.
|
| 120 |
-
def stream(user_id):
|
| 121 |
if user_id not in sessions:
|
| 122 |
-
|
| 123 |
-
return jsonify({"error": "No active session for this user"}), 404
|
| 124 |
|
| 125 |
-
def event_stream():
|
| 126 |
while True:
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
if "event: done" in message or "event: error" in message:
|
| 131 |
-
logger.info(f"Stream ended for user_id={user_id} with message: {message.strip()}")
|
| 132 |
-
break
|
| 133 |
-
except queue.Empty:
|
| 134 |
-
logger.warning(f"Stream timeout for user_id={user_id}")
|
| 135 |
-
yield "event: timeout\ndata: No activity\n\n"
|
| 136 |
break
|
|
|
|
| 137 |
|
| 138 |
-
return
|
| 139 |
|
| 140 |
-
@app.
|
| 141 |
-
def
|
| 142 |
-
status
|
| 143 |
-
logger.info(f"Status request for user_id={user_id}: {status}")
|
| 144 |
-
return jsonify({"status": status})
|
| 145 |
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
|
|
|
| 1 |
+
from fastapi import FastAPI, Request, BackgroundTasks
|
| 2 |
+
from fastapi.responses import StreamingResponse, JSONResponse
|
| 3 |
+
from transformers import pipeline
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
from nltk.tokenize import sent_tokenize
|
| 5 |
+
import asyncio
|
| 6 |
+
import uuid
|
| 7 |
import logging
|
| 8 |
+
import json
|
| 9 |
+
from typing import Dict
|
| 10 |
+
from collections import deque
|
| 11 |
|
| 12 |
+
# FastAPI application instance; all routes below are registered on it.
app = FastAPI()

# Logging
# Root-logger configuration; module loggers inherit this level/format.
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
logger = logging.getLogger(__name__)

# Load pipeline once at startup
# NOTE(review): pipeline() loads (and may download) the model weights at
# import time, so startup blocks until the model is available.
classifier = pipeline("text-classification", model="priyabrat/AI.or.Human.text.classification")
# Per-user session state: user_id -> {"status": "processing" | "done"}.
sessions: Dict[str, Dict] = {}
# Per-user SSE frame buffers: the worker appends, /stream pops from the left.
queues: Dict[str, deque] = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
+
@app.get("/health")
|
| 24 |
+
async def health():
|
| 25 |
+
return {"status": "healthy"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
+
@app.post("/start-session")
|
| 28 |
+
async def start_session(request: Request, background_tasks: BackgroundTasks):
|
| 29 |
+
data = await request.json()
|
| 30 |
user_id = data.get("user_id")
|
| 31 |
text = data.get("text")
|
| 32 |
|
| 33 |
if not user_id or not text:
|
| 34 |
+
return JSONResponse(content={"error": "user_id and text are required"}, status_code=400)
|
|
|
|
| 35 |
|
| 36 |
if user_id in sessions:
|
| 37 |
+
return JSONResponse(content={"message": "Session already exists", "status": sessions[user_id]["status"]}, status_code=409)
|
|
|
|
| 38 |
|
| 39 |
+
sessions[user_id] = {"status": "processing"}
|
| 40 |
+
queues[user_id] = deque()
|
|
|
|
|
|
|
| 41 |
|
| 42 |
+
background_tasks.add_task(process_text, user_id, text)
|
| 43 |
+
return {"message": "Session started", "status": "processing"}
|
| 44 |
+
|
| 45 |
+
async def process_text(user_id: str, text: str):
    """Classify *text* line-by-line (or sentence-by-sentence) for a session.

    Each result is pushed onto the user's queue as a pre-formatted SSE
    frame; a terminal "done" or "error" event is always emitted, then the
    session is torn down after a short grace period so the stream reader
    can drain.
    """
    try:
        # Multi-line input splits on newlines; single-paragraph input is
        # split into sentences with NLTK.
        if '\n' in text:
            lines = [l.strip() for l in text.strip().split('\n') if l.strip()]
        else:
            lines = sent_tokenize(text)

        for idx, line in enumerate(lines, 1):
            # The HF pipeline is synchronous and CPU-bound: run it in a
            # worker thread so it doesn't block the event loop (which
            # would stall /stream and every other request while inferring).
            result = (await asyncio.to_thread(classifier, line))[0]
            label = result['label']
            confidence = round(result['score'] * 100, 2)
            payload = {
                "line": idx,
                "text": line,
                # NOTE(review): assumes LABEL_0 == AI-generated in this
                # model's id2label mapping — confirm against the model config.
                "label": "AI-generated" if label == "LABEL_0" else "Human-written",
                "confidence": confidence
            }
            queues[user_id].append(f"data: {json.dumps(payload)}\n\n")
            await asyncio.sleep(0.1)  # lightly pace the stream

        queues[user_id].append("event: done\ndata: Session complete\n\n")
    except Exception as e:
        logger.error(f"Error: {e}")
        # The queue may already be gone if cleanup raced us; don't mask the
        # original error with a KeyError.
        q = queues.get(user_id)
        if q is not None:
            q.append(f"event: error\ndata: {str(e)}\n\n")
    finally:
        # Mark done, give the stream reader a moment to drain, then clean up.
        if user_id in sessions:
            sessions[user_id]['status'] = "done"
        await asyncio.sleep(1)
        sessions.pop(user_id, None)
        queues.pop(user_id, None)
|
| 70 |
|
| 71 |
+
@app.get("/stream/{user_id}")
|
| 72 |
+
async def stream(user_id: str):
|
| 73 |
if user_id not in sessions:
|
| 74 |
+
return JSONResponse(content={"error": "No active session"}, status_code=404)
|
|
|
|
| 75 |
|
| 76 |
+
async def event_stream():
|
| 77 |
while True:
|
| 78 |
+
if queues[user_id]:
|
| 79 |
+
yield queues[user_id].popleft()
|
| 80 |
+
elif sessions[user_id]['status'] == 'done':
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
break
|
| 82 |
+
await asyncio.sleep(0.1)
|
| 83 |
|
| 84 |
+
return StreamingResponse(event_stream(), media_type="text/event-stream")
|
| 85 |
|
| 86 |
+
@app.get("/status/{user_id}")
|
| 87 |
+
async def status(user_id: str):
|
| 88 |
+
return {"status": sessions.get(user_id, {}).get("status", "no_session")}
|
|
|
|
|
|
|
| 89 |
|
| 90 |
+
@app.get("/")
|
| 91 |
+
async def index():
|
| 92 |
+
return {"message": "Server is running"}
|