File size: 2,886 Bytes
d17c2ae
 
28ff626
5bcd567
d17c2ae
28ff626
d17c2ae
 
 
 
15b499b
28ff626
d17c2ae
9fde7ed
 
28ff626
d17c2ae
 
 
 
 
 
ae41e6f
28ff626
d17c2ae
28ff626
15b499b
28ff626
d17c2ae
28ff626
15b499b
 
 
 
d17c2ae
15b499b
 
d17c2ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15b499b
28ff626
 
15b499b
d17c2ae
15b499b
d17c2ae
15b499b
d17c2ae
 
 
 
 
 
 
15b499b
 
28ff626
5bcd567
28ff626
d17c2ae
28ff626
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse, PlainTextResponse
from transformers import pipeline
from nltk.tokenize import sent_tokenize
import time
import json
import threading
import queue
import logging
import os

app = FastAPI()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face text-classification pipeline, loaded once at import time
# (first run downloads the model weights, which can take a while).
classifier = pipeline("text-classification", model="priyabrat/AI.or.Human.text.classification")
# In-memory per-user state:
#   sessions: user_id -> {"status": "pending" | "processing" | "done"}
#   queues:   user_id -> queue.Queue of SSE-formatted strings for /stream
# NOTE(review): plain dicts mutated from worker threads — adequate for a
# single-process deployment, but not shared across multiple workers/processes.
sessions = {}
queues = {}

@app.get("/")
async def index():
    """Root endpoint: return a plain-text liveness banner."""
    banner = "✅ FastAPI server running on Hugging Face Spaces!"
    return PlainTextResponse(banner)

@app.get("/health")
async def health_check():
    """Liveness probe for orchestrators; always reports healthy."""
    payload = {"status": "healthy"}
    return payload

@app.post("/start-session")
async def start_session(request: Request):
    """Start a background classification session for one user.

    Expects a JSON body with "user_id" and "text". Spawns a daemon thread
    that classifies the text piece-by-piece (split on newlines when present,
    otherwise into sentences via NLTK) and pushes SSE-formatted results onto
    the user's queue, which /stream/{user_id} drains.

    Returns 400 when a required field is missing and 409 when a session for
    this user already exists.
    """
    data = await request.json()
    user_id = data.get("user_id")
    text = data.get("text")

    if not user_id or not text:
        return JSONResponse({"error": "user_id and text required"}, status_code=400)

    if user_id in sessions:
        return JSONResponse({"message": "Session exists", "status": sessions[user_id]["status"]}, status_code=409)

    sessions[user_id] = {"status": "pending"}
    # Bind the queue locally so the worker never re-looks-up queues[user_id]
    # after cleanup may have removed the entry.
    q = queues[user_id] = queue.Queue()

    def worker():
        """Classify each piece, emit SSE events, then tear down session state."""
        try:
            sessions[user_id]["status"] = "processing"
            # Multi-line input: one item per non-blank line.
            # Single-paragraph input: one item per NLTK sentence.
            if '\n' in text:
                lines = [l.strip() for l in text.split('\n') if l.strip()]
            else:
                lines = sent_tokenize(text)
            for i, line in enumerate(lines, 1):
                result = classifier(line)[0]
                payload = {
                    'line': i,
                    'text': line,
                    'label': result['label'],
                    'confidence': round(result['score'] * 100, 2),
                }
                q.put(f"data: {json.dumps(payload)}\n\n")
                time.sleep(0.1)  # throttle so the SSE consumer can keep up
            q.put("event: done\ndata: Session complete\n\n")
        except Exception as e:
            q.put(f"event: error\ndata: {str(e)}\n\n")
        finally:
            # Mark done, give the streaming endpoint a moment to drain, then
            # discard state. pop(..., None) instead of del so cleanup never
            # raises KeyError if an entry was already removed.
            if user_id in sessions:
                sessions[user_id]["status"] = "done"
            time.sleep(1)
            sessions.pop(user_id, None)
            queues.pop(user_id, None)

    threading.Thread(target=worker, daemon=True).start()
    return {"message": "Session started", "status": "pending"}

@app.get("/stream/{user_id}")
async def stream(user_id: str):
    """Stream classification results for *user_id* as Server-Sent Events.

    Yields queued SSE messages until a done/error event arrives; emits a
    timeout event and stops after 30 s of inactivity. Returns 404 when no
    active session exists for the user.
    """
    if user_id not in sessions:
        return JSONResponse({"error": "No active session"}, status_code=404)

    # Capture the queue once: the worker thread deletes queues[user_id]
    # shortly after finishing, so a dict lookup inside the generator loop
    # could raise KeyError mid-stream for a slow consumer.
    q = queues.get(user_id)
    if q is None:
        return JSONResponse({"error": "No active session"}, status_code=404)

    def event_stream():
        while True:
            try:
                msg = q.get(timeout=30)
            except queue.Empty:
                yield "event: timeout\ndata: No activity\n\n"
                break
            yield msg
            if "event: done" in msg or "event: error" in msg:
                break

    return StreamingResponse(event_stream(), media_type="text/event-stream")

@app.get("/status/{user_id}")
async def session_status(user_id: str):
    """Report a user's session status ("no_session" when none exists)."""
    session = sessions.get(user_id)
    if session is None:
        return {"status": "no_session"}
    return {"status": session.get("status", "no_session")}