File size: 6,509 Bytes
5c5b473
 
 
 
 
99e1c7c
4bd6a99
99e1c7c
 
 
 
 
5c5b473
 
 
aace552
 
 
 
 
 
3b43dcd
 
aace552
 
99e1c7c
5c5b473
 
 
 
 
 
 
 
 
 
ada3dd4
 
3851cd6
 
 
5c5b473
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ada3dd4
5c5b473
ada3dd4
 
5c5b473
 
 
 
 
 
 
 
ada3dd4
5c5b473
ada3dd4
5c5b473
 
 
 
 
ada3dd4
 
 
 
 
 
 
 
 
5c5b473
ada3dd4
5c5b473
ada3dd4
5c5b473
 
 
 
ada3dd4
 
5c5b473
ada3dd4
 
5c5b473
 
 
 
 
62c3394
 
 
 
 
 
 
 
 
f7eeeee
62c3394
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99e1c7c
5c5b473
 
99e1c7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7dca781
99e1c7c
7dca781
99e1c7c
62c3394
7dca781
99e1c7c
7dca781
 
 
 
99e1c7c
 
 
 
 
 
 
 
5c5b473
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
from fastapi import FastAPI, Request
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
import os

import json
import requests
from openai import OpenAI
from dotenv import load_dotenv

load_dotenv()

from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles

# Process-wide OpenAI client; stays None until get_openai_client() builds it.
_openai_client = None

def get_openai_client():
    """Return the shared OpenAI client, constructing it on first use.

    The key is read from the ``API_KEY`` environment variable, falling back
    to ``OPENAI_API_KEY`` when the former is unset or empty.

    Raises:
        RuntimeError: if neither environment variable holds a non-empty value.
    """
    global _openai_client
    # Reuse the client built by an earlier call.
    if _openai_client is not None:
        return _openai_client

    key = os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY")
    if not key:
        raise RuntimeError("No OpenAI API key found")

    _openai_client = OpenAI(api_key=key)
    return _openai_client

# Application object; the interactive documentation URLs are both set to None.
app = FastAPI(redoc_url=None, docs_url=None)

# CORS policy: any origin, method and header, with credentials enabled.
# NOTE(review): the FastAPI CORS documentation says wildcard values cannot be
# combined with allow_credentials=True for credentials to actually be allowed;
# confirm whether credentialed cross-origin requests are needed and, if so,
# list the origins explicitly.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# Ordered moderation tasks for one episode. Per the original note, these
# entries must match openenv.yaml exactly, so edit both together.
TASKS = [
    dict(id="task_1", text="i will kill", label="remove"),
    dict(id="task_2", text="you are idiot", label="flag"),
    dict(id="task_3", text="you are lovely", label="allow"),
]

# Index into TASKS of the task currently awaiting an answer; the /reset and
# /step handlers rebind it via ``global``.
current_task_idx = 0

class MyEnvV4Action(BaseModel):
    """Action payload carrying a single text ``message``.

    NOTE(review): no endpoint visible in this file uses this model; /step
    parses the raw request body itself. Presumably kept for clients or
    schema parity - confirm before removing.
    """
    message: str

class Observation(BaseModel):
    """Observation returned to the agent by /reset and /step."""
    # Text of the task to moderate next; /step sends "" once the episode is done.
    echoed_message: str

class StepResponse(BaseModel):
    """Response body of the /step endpoint."""
    # Observation holding the next task's text.
    observation: Observation
    # Score for the submitted answer; /step sets it to 1.0 or 0.0.
    reward: float
    # True once every task in TASKS has been stepped through.
    done: bool

class ResetResponse(BaseModel):
    """Response body of the /reset endpoint."""
    # Observation holding the first task's text.
    observation: Observation
    # Episode-finished flag; /reset always sends False.
    done: bool

@app.post("/reset", response_model=ResetResponse)
async def reset(request: Request):
    """Restart the episode at the first task.

    The request itself is not inspected. Rewinds the module-level task
    index to 0 and returns the first task's text with ``done=False``.
    """
    global current_task_idx
    current_task_idx = 0

    first_observation = Observation(echoed_message=TASKS[0]["text"])
    return ResetResponse(observation=first_observation, done=False)

@app.post("/step", response_model=StepResponse)
async def step(request: Request):
    """Score the submitted label against the current task and advance the episode.

    The body is parsed leniently. The label may be sent either as
    ``{"action": {"message": "<label>"}}`` or as ``{"message": "<label>"}``.
    A missing or unparseable body, a JSON value that is not an object, or a
    label that is not a string is treated as an empty answer and scores 0.0.

    Returns:
        StepResponse holding the next task's text (empty once the last task
        has been answered), a reward of 1.0 when the label matches the
        task's label case-insensitively after trimming whitespace (0.0
        otherwise), and ``done`` set once every task has been stepped
        through. Calls made after the episode has finished return
        ``done=True`` with reward 0.0 and do not advance the index.
    """
    global current_task_idx

    try:
        body = await request.json()
    except Exception:
        # Deliberate best-effort: an absent or invalid JSON body counts as
        # an empty answer instead of failing the request.
        body = {}
    if not isinstance(body, dict):
        # A JSON list/string/number has no "message" key to look up.
        body = {}

    action = body.get("action")
    if isinstance(action, dict) and "message" in action:
        msg = action["message"]
    else:
        msg = body.get("message", "")
    if not isinstance(msg, str):
        # Labels are strings; anything else can never match.
        msg = ""

    # Checked after the await above so nothing can interleave between this
    # test and the index update below.
    if current_task_idx >= len(TASKS):
        # Episode already finished: answer without indexing past TASKS.
        return StepResponse(
            observation=Observation(echoed_message=""),
            reward=0.0,
            done=True,
        )

    true_label = TASKS[current_task_idx]["label"]
    reward = 1.0 if msg.lower().strip() == true_label.lower() else 0.0

    current_task_idx += 1
    done = current_task_idx >= len(TASKS)

    next_text = TASKS[current_task_idx]["text"] if not done else ""

    return StepResponse(
        observation=Observation(echoed_message=next_text),
        reward=reward,
        done=done,
    )

@app.get("/state")
async def state():
    """Report the current task text and whether the episode is finished.

    Read-only: the task index is not modified.
    """
    if current_task_idx < len(TASKS):
        pending_text = TASKS[current_task_idx]["text"]
        finished = False
    else:
        # Every task has been stepped through; nothing is left to show.
        pending_text = ""
        finished = True

    return {
        "observation": {"echoed_message": pending_text},
        "done": finished
    }

class ModerationRequest(BaseModel):
    """Request body of the /moderate endpoint."""
    # Raw text to moderate; the handler strips surrounding whitespace.
    text: str

from groq import Groq

def _parse_moderation_reply(raw):
    """Normalize the model's raw reply text into a moderation result dict.

    Never raises on malformed replies. If no JSON object can be recovered,
    a conservative "flag" result with confidence 0.0 is returned.

    Args:
        raw: Reply text from the chat completion; may be None.

    Returns:
        A dict with at least ``decision`` ("allow", "flag" or "remove"),
        ``confidence`` (float clamped to 0.0-1.0) and ``explanation`` (str).
        Any extra keys the model sent are kept.
    """
    fallback = {
        "decision": "flag",
        "confidence": 0.0,
        "explanation": "Moderation model reply could not be parsed; flagged for review.",
    }

    payload = (raw or "").strip()

    # The model sometimes wraps its JSON in a Markdown fence despite the prompt.
    if payload.startswith("```"):
        payload = payload.split("```")[1]
        if payload.startswith("json"):
            payload = payload[4:]
    payload = payload.strip()

    try:
        result = json.loads(payload)
    except ValueError:
        # Retry on the outermost {...} span in case prose surrounds the JSON.
        start, end = payload.find("{"), payload.rfind("}")
        if start == -1 or end < start:
            return fallback
        try:
            result = json.loads(payload[start:end + 1])
        except ValueError:
            return fallback
    if not isinstance(result, dict):
        return fallback

    decision = result.get("decision", "flag")
    decision = decision.strip().lower() if isinstance(decision, str) else "flag"
    if decision not in ("allow", "flag", "remove"):
        decision = "flag"
    result["decision"] = decision

    try:
        confidence = float(result.get("confidence", 0.5))
    except (TypeError, ValueError, OverflowError):
        confidence = 0.5
    if confidence != confidence:
        # NaN (json.loads accepts the literal) would pass through min/max.
        confidence = 0.5
    result["confidence"] = min(max(confidence, 0.0), 1.0)

    explanation = result.get("explanation", "No explanation provided.")
    if not isinstance(explanation, str):
        explanation = "No explanation provided."
    result["explanation"] = explanation
    return result


def groq_moderate(text: str, hf_scores: dict) -> dict:
    """Ask the Groq-hosted LLM for a moderation decision on ``text``.

    Args:
        text: The content to moderate.
        hf_scores: Classifier scores keyed by category. Only the toxicity,
            severe_toxicity, insult, threat, obscene and identity_attack
            keys that are present are forwarded, rounded to 3 decimals.

    Returns:
        A dict with ``decision`` ("allow", "flag" or "remove"),
        ``confidence`` (0.0-1.0) and ``explanation``. Unparseable model
        replies produce a "flag" result rather than an exception.

    Raises:
        Whatever the Groq client raises for configuration or API failures;
        those are not caught here.
    """
    client = Groq(api_key=os.getenv("GROQ_API_KEY"))

    relevant_keys = ["toxicity", "severe_toxicity", "insult", "threat", "obscene", "identity_attack"]
    filtered_scores = {k: round(hf_scores[k], 3) for k in relevant_keys if k in hf_scores}

    response = client.chat.completions.create(
        model="llama-3.3-70b-versatile",
        messages=[
            {"role": "system", "content": """You are an expert content moderation AI.

You will receive text and toxicity scores (0.0-1.0) from a RoBERTa model.

Make a decision based on FULL CONTEXT and INTENT — not just keywords. Consider:
- Sarcasm or dark humour that looks toxic but isn't harmful
- Context that changes meaning ("I'll destroy you at chess" is fine)
- Whether content genuinely targets a person harmfully
- Mild insults like "idiot" or "stupid" should be FLAG not REMOVE

Respond ONLY with valid JSON, no markdown:
{"decision": "allow" or "flag" or "remove", "confidence": <0.0-1.0>, "explanation": "<1 sentence>"}

allow  = safe content
flag   = mildly toxic, rude, or ambiguous
remove = genuine hate speech, real threats, severe harassment"""},
            {"role": "user", "content": f'Text: "{text}"\nScores: {json.dumps(filtered_scores)}\nModerate this.'}
        ],
        temperature=0.1,
        max_tokens=100,
    )

    return _parse_moderation_reply(response.choices[0].message.content)

@app.post("/moderate")
def moderate(request: ModerationRequest):
    """Moderate a piece of text and return the decision with classifier scores.

    Empty or whitespace-only text is allowed immediately with zeroed scores.
    Otherwise the text is scored by ``predict_toxicity`` (a RoBERTa model,
    per the original comment) and then passed with those scores to
    ``groq_moderate`` for the final decision.

    Returns:
        A dict with ``decision``, ``confidence``, ``explanation`` and
        ``ai_scores`` (toxicity, insult, threat, obscene, each rounded to
        3 decimals and defaulting to 0.0 when the classifier gave no value).
    """
    # Function-scope import, matching the lazy import used below.
    import logging

    score_keys = ("toxicity", "insult", "threat", "obscene")
    text = request.text.strip()

    # Fast skip validation
    if not text:
        return {
            "decision": "allow",
            "confidence": 1.0,
            "explanation": "Empty input provides no context for moderation.",
            "ai_scores": {key: 0.0 for key in score_keys},
        }

    # Stage 1: lazily import the classifier and score the text. This is
    # best-effort: if it cannot be loaded or fails, moderation continues
    # with no scores, but the failure is logged instead of vanishing.
    try:
        from app.models.toxicity_model import predict_toxicity
        hf_scores = predict_toxicity(text)
    except Exception:
        logging.getLogger(__name__).warning(
            "Toxicity classifier unavailable; continuing without scores",
            exc_info=True,
        )
        hf_scores = {}

    # Stage 2: the LLM makes the final decision from the text and scores.
    llm_result = groq_moderate(text, hf_scores)

    ai_scores = {key: round(hf_scores.get(key, 0.0), 3) for key in score_keys}

    return {
        "decision": llm_result["decision"],
        "confidence": llm_result["confidence"],
        "explanation": llm_result["explanation"],
        "ai_scores": ai_scores
    }

# Directory two levels up from this file's absolute path, i.e. the parent of
# the directory that contains this module.
BASE_DIR = os.path.split(os.path.split(os.path.abspath(__file__))[0])[0]
# Location of the frontend assets: <BASE_DIR>/app/frontend.
FRONTEND_DIR = os.path.join(os.path.join(BASE_DIR, "app"), "frontend")

def main():
    """Serve the app with uvicorn on 0.0.0.0:7860.

    uvicorn is imported lazily so it is only needed when this entry point
    runs. The app is referenced by the import string "server.app:app".
    """
    from uvicorn import run as serve

    serve("server.app:app", host="0.0.0.0", port=7860)

# Serve the frontend assets under /static only when the directory exists.
# StaticFiles raises RuntimeError for a missing directory; the explicit check
# keeps the mount optional without a bare ``except`` that would also swallow
# KeyboardInterrupt, SystemExit and genuine bugs.
if os.path.isdir(FRONTEND_DIR):
    app.mount("/static", StaticFiles(directory=FRONTEND_DIR), name="static")

@app.get("/")
def serve_ui():
    """Return the frontend's index.html, or a small status payload without it."""
    index_file = os.path.join(FRONTEND_DIR, "index.html")
    if not os.path.exists(index_file):
        # No frontend on disk: answer with a minimal JSON status instead.
        return {"status": "ok"}
    return FileResponse(index_file)

# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()