Update app.py
Browse files
app.py
CHANGED
|
@@ -1,33 +1,27 @@
|
|
| 1 |
from fastapi import FastAPI, Request
|
| 2 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
|
|
|
| 4 |
app = FastAPI()
|
| 5 |
|
| 6 |
-
# --- Hackathon Grader
|
| 7 |
@app.post("/reset")
|
| 8 |
async def reset(request: Request):
|
| 9 |
-
# Tells the grader that the environment has been reset
|
| 10 |
return {"status": "success", "message": "Environment reset"}
|
| 11 |
|
| 12 |
@app.post("/step")
|
| 13 |
async def step(request: Request):
|
| 14 |
-
# The grader sends its actions to this endpoint
|
| 15 |
return {"status": "success"}
|
| 16 |
|
| 17 |
-
#
|
| 18 |
-
app = gr.mount_gradio_app(app, demo, path="/")
|
| 19 |
-
|
| 20 |
-
import gradio as gr
|
| 21 |
-
import numpy as np
|
| 22 |
-
from env import (
|
| 23 |
-
EmailTriageEnv, TASK_SPLITS,
|
| 24 |
-
URGENCY_LABELS, ROUTING_LABELS, RESOLUTION_LABELS,
|
| 25 |
-
)
|
| 26 |
-
|
| 27 |
_LEGAL_SECURITY_KW = {"lawsuit", "attorney", "sue", "ransomware", "extortion"}
|
| 28 |
_BILLING_ESCALATE_KW = {"refund"}
|
| 29 |
|
| 30 |
-
|
| 31 |
def _classify(email: dict) -> np.ndarray:
|
| 32 |
kw = set(email.get("keywords", []))
|
| 33 |
context = email.get("context", "").lower()
|
|
@@ -44,56 +38,45 @@ def _classify(email: dict) -> np.ndarray:
|
|
| 44 |
return np.array([0, 1, 1], dtype=np.int64)
|
| 45 |
return np.array([0, 0, 0], dtype=np.int64)
|
| 46 |
|
| 47 |
-
|
| 48 |
def run_task_demo(task: str) -> str:
|
| 49 |
env = EmailTriageEnv(task=task, shuffle=False)
|
| 50 |
env.reset(seed=42)
|
| 51 |
email_queue = list(env._queue)
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
step = 0
|
| 57 |
-
|
| 58 |
while not terminated:
|
| 59 |
email = email_queue[step]
|
| 60 |
action = _classify(email)
|
| 61 |
_, norm_reward, terminated, _, info = env.step(action)
|
| 62 |
cumulative += norm_reward
|
| 63 |
-
|
| 64 |
raw = info["raw_reward"]
|
| 65 |
ca = info["correct_actions"]
|
| 66 |
-
|
| 67 |
verdict = ("β
EXACT" if raw >= 1.0 else
|
| 68 |
"πΆ PARTIAL" if raw > 0 else
|
| 69 |
"π¨ SECURITY MISS" if raw < 0 else "β WRONG")
|
| 70 |
-
|
| 71 |
lines.append(
|
| 72 |
f"#{step+1:02d} [{email['difficulty'].upper()}] "
|
| 73 |
f"{email['description'][:40]}\n"
|
| 74 |
-
f"
|
| 75 |
f"{ROUTING_LABELS[action[1]]} | {RESOLUTION_LABELS[action[2]]}\n"
|
| 76 |
-
f"
|
| 77 |
f"{ROUTING_LABELS[ca[1]]} | {RESOLUTION_LABELS[ca[2]]}\n"
|
| 78 |
-
f"
|
| 79 |
)
|
| 80 |
step += 1
|
| 81 |
-
|
| 82 |
final = max(0.0, min(1.0, cumulative))
|
| 83 |
lines.append(f"\n{'β'*50}")
|
| 84 |
lines.append(f"Final Score : {final:.3f} / 1.0")
|
| 85 |
return "\n".join(lines)
|
| 86 |
|
| 87 |
-
|
| 88 |
with gr.Blocks(title="Email Gatekeeper RL") as demo:
|
| 89 |
gr.Markdown("""
|
| 90 |
# 📧 Email Gatekeeper — RL Environment Demo
|
| 91 |
**Meta x PyTorch Hackathon** | Gymnasium-based email triage agent
|
| 92 |
-
|
| 93 |
-
The agent classifies each email across **3 simultaneous dimensions**:
|
| 94 |
-
`Urgency` × `Department` × `Resolution Action`
|
| 95 |
""")
|
| 96 |
-
|
| 97 |
with gr.Row():
|
| 98 |
task_dropdown = gr.Dropdown(
|
| 99 |
choices=["easy", "medium", "hard"],
|
|
@@ -101,26 +84,17 @@ The agent classifies each email across **3 simultaneous dimensions**:
|
|
| 101 |
label="Select Task",
|
| 102 |
)
|
| 103 |
run_btn = gr.Button("βΆ Run Episode", variant="primary")
|
| 104 |
-
|
| 105 |
output_box = gr.Textbox(
|
| 106 |
label="Episode Results",
|
| 107 |
lines=30,
|
| 108 |
max_lines=50,
|
| 109 |
)
|
| 110 |
-
|
| 111 |
run_btn.click(fn=run_task_demo, inputs=task_dropdown, outputs=output_box)
|
| 112 |
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
| Result | Reward |
|
| 116 |
-
|---|---|
|
| 117 |
-
| ✅ Exact Match (all 3 correct) | +1.0 |
|
| 118 |
-
| 🔶 Partial (urgency correct, 1 wrong) | +0.2 |
|
| 119 |
-
| 🔶 Partial (urgency correct, 2 wrong) | +0.1 |
|
| 120 |
-
| 🚨 Security Miss | **-2.0** |
|
| 121 |
-
| ❌ Wrong urgency | 0.0 |
|
| 122 |
-
""")
|
| 123 |
-
|
| 124 |
|
| 125 |
if __name__ == "__main__":
|
| 126 |
-
|
|
|
|
|
|
|
|
|
| 1 |
from fastapi import FastAPI, Request
|
| 2 |
import gradio as gr
|
| 3 |
+
import numpy as np
|
| 4 |
+
from env import (
|
| 5 |
+
EmailTriageEnv, TASK_SPLITS,
|
| 6 |
+
URGENCY_LABELS, ROUTING_LABELS, RESOLUTION_LABELS,
|
| 7 |
+
)
|
| 8 |
|
| 9 |
+
# 1. FastAPI App Setup
|
| 10 |
app = FastAPI()
|
| 11 |
|
| 12 |
+
# --- Hackathon Grader Endpoints ---
|
| 13 |
@app.post("/reset")
async def reset(request: Request):
    """Acknowledge a reset request with a fixed success payload.

    The incoming request is accepted but not inspected; the handler only
    reports that the environment was reset.
    """
    acknowledgement = {"status": "success", "message": "Environment reset"}
    return acknowledgement
|
| 16 |
|
| 17 |
@app.post("/step")
async def step(request: Request):
    """Acknowledge a step request with a fixed success payload.

    The incoming request is accepted but not inspected.
    """
    acknowledgement = {"status": "success"}
    return acknowledgement
|
| 20 |
|
| 21 |
+
# 2. Logic Functions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
_LEGAL_SECURITY_KW = {"lawsuit", "attorney", "sue", "ransomware", "extortion"}
|
| 23 |
_BILLING_ESCALATE_KW = {"refund"}
|
| 24 |
|
|
|
|
| 25 |
def _classify(email: dict) -> np.ndarray:
|
| 26 |
kw = set(email.get("keywords", []))
|
| 27 |
context = email.get("context", "").lower()
|
|
|
|
| 38 |
return np.array([0, 1, 1], dtype=np.int64)
|
| 39 |
return np.array([0, 0, 0], dtype=np.int64)
|
| 40 |
|
|
|
|
| 41 |
def run_task_demo(task: str) -> str:
    """Run one episode of the email-triage environment and return a report.

    Builds an ``EmailTriageEnv`` for ``task`` with shuffling disabled, resets
    it with seed 42, classifies every queued email with ``_classify`` and
    feeds the resulting action to ``env.step``.

    Args:
        task: Task name forwarded to ``EmailTriageEnv``.

    Returns:
        A newline-joined report: one entry per processed email (predicted
        labels, correct labels, raw reward and verdict), then a separator and
        the cumulative normalised reward clamped to ``[0.0, 1.0]``.
    """
    env = EmailTriageEnv(task=task, shuffle=False)
    env.reset(seed=42)
    # Snapshot the queue so each step can be paired with the email it scored.
    # NOTE(review): this reads the private ``_queue`` attribute and presumes
    # its order matches the order in which env.step consumes emails — confirm.
    email_queue = list(env._queue)

    lines = []
    cumulative = 0.0

    # Iterating the snapshot bounds the loop: an empty queue, or an episode
    # that outlives the queue, can no longer raise IndexError.
    for step, email in enumerate(email_queue):
        action = _classify(email)
        _, norm_reward, terminated, truncated, info = env.step(action)
        cumulative += norm_reward

        raw = info["raw_reward"]
        ca = info["correct_actions"]

        if raw >= 1.0:
            verdict = "✅ EXACT"
        elif raw > 0:
            verdict = "🔶 PARTIAL"
        elif raw < 0:
            verdict = "🚨 SECURITY MISS"
        else:
            verdict = "❌ WRONG"

        lines.append(
            f"#{step+1:02d} [{email['difficulty'].upper()}] "
            f"{email['description'][:40]}\n"
            f" Predicted : {URGENCY_LABELS[action[0]]} | "
            f"{ROUTING_LABELS[action[1]]} | {RESOLUTION_LABELS[action[2]]}\n"
            f" Correct : {URGENCY_LABELS[ca[0]]} | "
            f"{ROUTING_LABELS[ca[1]]} | {RESOLUTION_LABELS[ca[2]]}\n"
            f" Reward : {raw:+.1f} {verdict}\n"
        )

        # An episode ends on either flag; the truncation flag used to be
        # discarded, so a truncated episode kept stepping.
        if terminated or truncated:
            break

    final = max(0.0, min(1.0, cumulative))
    # NOTE(review): the separator glyph was mis-encoded in the original text;
    # a box-drawing dash is assumed here — confirm the intended character.
    separator = "─" * 50
    lines.append(f"\n{separator}")
    lines.append(f"Final Score : {final:.3f} / 1.0")
    return "\n".join(lines)
|
| 73 |
|
| 74 |
+
# 3. Gradio interface (bound to the name 'demo')
|
| 75 |
with gr.Blocks(title="Email Gatekeeper RL") as demo:
|
| 76 |
gr.Markdown("""
|
| 77 |
# 📧 Email Gatekeeper — RL Environment Demo
|
| 78 |
**Meta x PyTorch Hackathon** | Gymnasium-based email triage agent
|
|
|
|
|
|
|
|
|
|
| 79 |
""")
|
|
|
|
| 80 |
with gr.Row():
|
| 81 |
task_dropdown = gr.Dropdown(
|
| 82 |
choices=["easy", "medium", "hard"],
|
|
|
|
| 84 |
label="Select Task",
|
| 85 |
)
|
| 86 |
run_btn = gr.Button("βΆ Run Episode", variant="primary")
|
|
|
|
| 87 |
output_box = gr.Textbox(
|
| 88 |
label="Episode Results",
|
| 89 |
lines=30,
|
| 90 |
max_lines=50,
|
| 91 |
)
|
|
|
|
| 92 |
run_btn.click(fn=run_task_demo, inputs=task_dropdown, outputs=output_box)
|
| 93 |
|
| 94 |
+
# 4. MOST IMPORTANT: mount Gradio onto FastAPI (at this point 'demo' is defined)
|
| 95 |
+
app = gr.mount_gradio_app(app, demo, path="/")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
if __name__ == "__main__":
    import uvicorn

    # Port 7860 is required for Hugging Face.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)
|