|
|
from fastapi import FastAPI, Request |
|
|
from fastapi.responses import JSONResponse, StreamingResponse, PlainTextResponse |
|
|
from transformers import pipeline |
|
|
from nltk.tokenize import sent_tokenize |
|
|
import time |
|
|
import json |
|
|
import threading |
|
|
import queue |
|
|
import logging |
|
|
import os |
|
|
|
|
|
app = FastAPI()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face model used to label each line/sentence; loaded once at import time.
MODEL_NAME = "priyabrat/AI.or.Human.text.classification"
classifier = pipeline(task="text-classification", model=MODEL_NAME)

# In-process session state, keyed by user_id:
#   sessions[user_id] -> {"status": "pending" | "processing" | "done"}
#   queues[user_id]   -> queue.Queue of preformatted SSE event strings
# Both are written by request handlers and by background worker threads.
sessions = {}
queues = {}
|
|
|
|
|
@app.get("/")
async def index():
    """Serve a plain-text banner confirming the server is up."""
    banner = "✅ FastAPI server running on Hugging Face Spaces!"
    return PlainTextResponse(banner)
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Health probe: unconditionally reports the service as healthy."""
    payload = dict(status="healthy")
    return payload
|
|
|
|
|
def _split_lines(text):
    """Split *text* into the units that are classified one at a time.

    If *text* contains a newline it is split on newlines, each line stripped
    and blank lines dropped; otherwise NLTK's ``sent_tokenize`` splits it
    into sentences.
    """
    if "\n" in text:
        return [line.strip() for line in text.split("\n") if line.strip()]
    return sent_tokenize(text)


def _sse_error_event(message):
    """Format *message* as an SSE ``error`` event string.

    SSE treats CR, LF and CRLF as line terminators, so each line of the
    message is emitted as its own ``data:`` field. Otherwise a multi-line
    exception message would end the field early, and its remaining lines
    would be parsed as unknown fields. A single-line message yields exactly
    ``"event: error\\ndata: <message>\\n\\n"``.
    """
    normalized = message.replace("\r\n", "\n").replace("\r", "\n")
    data = "\n".join(f"data: {part}" for part in normalized.split("\n"))
    return f"event: error\n{data}\n\n"


@app.post("/start-session")
async def start_session(request: Request):
    """Start a background classification session for one user.

    Expects a JSON object body with non-empty string fields ``user_id`` and
    ``text``. The text is split by ``_split_lines``, and each piece is run
    through ``classifier`` on a daemon thread. Every result is queued as a
    preformatted SSE ``data:`` event for ``/stream/{user_id}``, followed by
    an ``event: done`` (or ``event: error``) event. The session and its queue
    are removed about one second after the worker finishes.

    Returns:
        ``{"message": "Session started", "status": "pending"}`` on success;
        a 400 JSONResponse for a malformed body or missing/non-string fields;
        a 409 JSONResponse (with the current status) if the user already has
        an active session.
    """
    try:
        data = await request.json()
    except ValueError:
        # The body is decoded with json.loads; JSONDecodeError and
        # UnicodeDecodeError are both ValueError subclasses.
        return JSONResponse({"error": "Invalid JSON body"}, status_code=400)
    if not isinstance(data, dict):
        return JSONResponse(
            {"error": "Request body must be a JSON object"}, status_code=400
        )

    user_id = data.get("user_id")
    text = data.get("text")

    if not user_id or not text:
        return JSONResponse({"error": "user_id and text required"}, status_code=400)
    # A non-string id could never match the str path parameter of /stream and
    # /status, and non-string text would only fail later inside the worker.
    if not isinstance(user_id, str) or not isinstance(text, str):
        return JSONResponse(
            {"error": "user_id and text must be strings"}, status_code=400
        )

    # Single lookup: a worker thread may delete the entry at any moment, so
    # an `in` check followed by indexing could raise KeyError.
    existing = sessions.get(user_id)
    if existing is not None:
        return JSONResponse(
            {"message": "Session exists", "status": existing["status"]},
            status_code=409,
        )

    session = {"status": "pending"}
    q = queue.Queue()
    sessions[user_id] = session
    queues[user_id] = q

    def worker():
        """Classify each piece of *text* and push SSE events onto *q*."""
        try:
            session["status"] = "processing"
            for i, line in enumerate(_split_lines(text), 1):
                result = classifier(line)[0]
                payload = {
                    "line": i,
                    "text": line,
                    "label": result["label"],
                    "confidence": round(result["score"] * 100, 2),
                }
                q.put(f"data: {json.dumps(payload)}\n\n")
                time.sleep(0.1)
            q.put("event: done\ndata: Session complete\n\n")
        except Exception as e:
            # Top-level boundary of a background thread: record the traceback
            # and tell the client, instead of failing silently.
            logger.exception("Classification failed for user_id=%s", user_id)
            q.put(_sse_error_event(str(e)))
        finally:
            session["status"] = "done"
            # Presumably a grace period so a client can still connect to
            # /stream or poll /status after completion; TODO confirm.
            time.sleep(1)
            # Remove the queue before the session. While the session entry
            # still exists, start_session answers 409, so a new session for
            # this user cannot be created in between and lose its fresh queue.
            queues.pop(user_id, None)
            sessions.pop(user_id, None)

    threading.Thread(target=worker, daemon=True).start()
    return {"message": "Session started", "status": "pending"}
|
|
|
|
|
@app.get("/stream/{user_id}")
async def stream(user_id: str):
    """Stream a session's queued events to the client as Server-Sent Events.

    Returns a 404 JSONResponse if *user_id* has no active session. Otherwise
    the stream forwards each queued event and ends after forwarding an
    ``event: done`` or ``event: error`` event. If no event arrives for 30
    seconds, it emits ``event: timeout`` and ends.
    """
    # Take the queue object once. The worker removes queues[user_id] about a
    # second after finishing, so re-looking it up on each iteration would
    # raise KeyError mid-stream for a slow or late consumer.
    q = queues.get(user_id) if user_id in sessions else None
    if q is None:
        return JSONResponse({"error": "No active session"}, status_code=404)

    def event_stream():
        while True:
            try:
                msg = q.get(timeout=30)
            except queue.Empty:
                yield "event: timeout\ndata: No activity\n\n"
                return
            yield msg
            # Match the event prefix only. A data event's JSON payload holds
            # the user's text, which may itself contain "event: done".
            if msg.startswith(("event: done", "event: error")):
                return

    return StreamingResponse(event_stream(), media_type="text/event-stream")
|
|
|
|
|
@app.get("/status/{user_id}")
async def session_status(user_id: str):
    """Report the current status of *user_id*'s session.

    Yields ``{"status": "no_session"}`` when the user has no active session.
    """
    entry = sessions.get(user_id, {})
    current = entry.get("status", "no_session")
    return {"status": current}
|
|
|