Spaces:
Sleeping
Sleeping
Vighnesh committed on
Commit ·
5d570d6
1
Parent(s): 2e81e98
result after no sleep
Browse files- get_baseline.py +65 -0
- make_chart.py +117 -0
- plot_results.py +266 -0
- train_grpo_safe.ipynb +562 -0
get_baseline.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys
|
| 2 |
+
sys.path.insert(0, r'C:\Users\Admin\OneDrive\Desktop\OpenEnv Hacathon\support_ticket_env')
|
| 3 |
+
|
| 4 |
+
from support_ticket_env.server.support_environment import SupportTicketEnvironment
|
| 5 |
+
from support_ticket_env.models import SupportAction
|
| 6 |
+
|
| 7 |
+
CATEGORY_KEYWORDS = {
|
| 8 |
+
"billing": ["charge", "invoice", "payment", "bill", "refund", "subscription", "price", "cost", "fee", "money"],
|
| 9 |
+
"technical": ["error", "bug", "crash", "not working", "broken", "issue", "problem", "fail", "500", "api"],
|
| 10 |
+
"account": ["login", "password", "account", "access", "sign in", "email", "username", "cancel"],
|
| 11 |
+
"refund": ["refund", "return", "money back", "reimburse", "cancel order"],
|
| 12 |
+
"general": ["hours", "contact", "phone", "help", "question", "info", "support"],
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
def rule_based(obs):
    """Choose the next support action using keyword heuristics only.

    When the observation has no category yet, the ticket text is lower-cased
    and every category is scored by how many of its keywords occur as
    substrings. The highest-scoring category wins; on a tie the category
    listed first in CATEGORY_KEYWORDS wins, and "general" is used when no
    keyword matches at all.

    When a category is already set, the follow-up action is fixed per
    category: "technical" is escalated, "general" is closed, and anything
    else receives a canned reply.

    Args:
        obs: Observation exposing ``ticket_text`` and ``current_category``.

    Returns:
        A dict of keyword arguments describing the action.
    """
    lowered = obs.ticket_text.lower()
    category = obs.current_category

    if not category:
        # Count substring hits per category, keeping the dict's own order so
        # that max() resolves ties in favour of the earliest category.
        hits = {
            name: sum(word in lowered for word in words)
            for name, words in CATEGORY_KEYWORDS.items()
        }
        winner = max(hits, key=hits.get, default="general")
        if hits.get(winner, 0) == 0:
            winner = "general"
        return {"action_type": "classify", "category": winner}

    if category == "technical":
        return {"action_type": "escalate", "reason": "needs engineering"}
    if category == "general":
        return {"action_type": "close", "reason": "resolved"}
    return {
        "action_type": "reply",
        "reply_text": f"Thank you for contacting us about your {category} issue.",
    }
|
| 32 |
+
|
| 33 |
+
# Evaluation settings: every task is run once per seed, for at most
# MAX_STEPS environment steps per episode.
SEEDS = [42, 7, 123]
MAX_STEPS = 10
TASK_IDS = [1, 2, 3]
results = {}

for task_id in TASK_IDS:
    scores = []
    for seed in SEEDS:
        env = SupportTicketEnvironment()
        obs = env.reset(task_id=task_id, seed=seed)
        rewards = []
        for _ in range(MAX_STEPS):
            if obs.done:
                break
            action_dict = rule_based(obs)
            try:
                action = SupportAction(**action_dict)
                obs = env.step(action)
                rewards.append(obs.reward or 0.0)
            except Exception as exc:
                # A failed step still counts as a zero-reward step, but the
                # cause is reported instead of being silently discarded.
                # Catching Exception (not a bare except) lets Ctrl-C and
                # SystemExit stop the run.
                print(
                    f" Task {task_id} seed={seed}: step failed: {exc!r}",
                    file=sys.stderr,
                )
                rewards.append(0.0)
        # Episode score is the reward sum normalised by MAX_STEPS (not by the
        # number of steps actually taken), clamped to [0, 1].
        score = round(min(max(sum(rewards) / MAX_STEPS, 0.0), 1.0), 3)
        scores.append(score)
        print(f" Task {task_id} seed={seed}: {score:.3f}")
    avg = round(sum(scores) / len(scores), 3)
    results[f"task{task_id}"] = avg
    print(f" Task {task_id} avg: {avg:.3f}")

# Average over however many tasks were run rather than a hard-coded 3, so the
# figure stays correct if TASK_IDS changes.
overall = round(sum(results.values()) / len(TASK_IDS), 3)
results["overall"] = overall
print(f"Overall rule-based avg: {overall:.3f}")
print("Rule-based scores:", results)
|
make_chart.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
make_chart.py
|
| 3 |
+
Generates the before/after reward chart using known scores.
|
| 4 |
+
Run: python make_chart.py
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import matplotlib
|
| 8 |
+
matplotlib.use("Agg")
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
# Rule-based agent (no LLM, no training) — measured locally
|
| 13 |
+
baseline_scores = {
|
| 14 |
+
"task1": 0.100,
|
| 15 |
+
"task2": 0.113,
|
| 16 |
+
"task3": 0.218,
|
| 17 |
+
"overall": 0.144,
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
# Qwen2.5-72B via HF Inference API — from your clean run logs
|
| 21 |
+
llm_scores = {
|
| 22 |
+
"task1": 0.100,
|
| 23 |
+
"task2": 0.113,
|
| 24 |
+
"task3": 0.262,
|
| 25 |
+
"overall": 0.158,
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
# After GRPO training — update these once Colab finishes
|
| 29 |
+
# If Colab has not finished yet, keep the llm_scores values here as a placeholder (the "After GRPO" bars are then not measured results)
|
| 30 |
+
grpo_scores = {
|
| 31 |
+
"task1": 0.100,
|
| 32 |
+
"task2": 0.113,
|
| 33 |
+
"task3": 0.262,
|
| 34 |
+
"overall": 0.158,
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
def make_chart(baseline, llm, grpo, output="reward_chart.png"):
    """Draw the three-way score comparison chart and print a summary table.

    The left panel shows grouped bars for the rule-based, LLM and GRPO score
    sets. The right panel shows the per-task delta ``grpo - baseline``,
    coloured green when the delta is >= 0 and red otherwise. The figure is
    written to ``output`` and the same numbers are printed to stdout.

    Args:
        baseline: Mapping with keys "task1", "task2", "task3" and "overall".
            Missing keys are treated as 0.
        llm: Mapping with the same keys, shown as the "Qwen2.5-72B" bars.
        grpo: Mapping with the same keys, shown as the "After GRPO" bars.
        output: Path of the image file to write.
    """
    tasks = ["Task 1\n(Classify)", "Task 2\n(Action)", "Task 3\n(Full Resolve)", "Overall"]
    keys = ["task1", "task2", "task3", "overall"]

    b_vals = [baseline.get(k, 0) for k in keys]
    llm_vals = [llm.get(k, 0) for k in keys]
    grpo_vals = [grpo.get(k, 0) for k in keys]
    deltas = [round(g - b, 3) for g, b in zip(grpo_vals, b_vals)]

    x = np.arange(len(tasks))
    width = 0.25

    fig, axes = plt.subplots(1, 2, figsize=(15, 6))
    try:
        fig.patch.set_facecolor("#1a1a2e")
        for ax in axes:
            ax.set_facecolor("#16213e")

        # Left panel: absolute scores, one bar group per task.
        ax1 = axes[0]
        bars1 = ax1.bar(x - width, b_vals, width, label="Rule-Based", color="#636e72", edgecolor="#2d3436")
        bars2 = ax1.bar(x, llm_vals, width, label="Qwen2.5-72B", color="#0984e3", edgecolor="#2d3436")
        bars3 = ax1.bar(x + width, grpo_vals, width, label="After GRPO", color="#00b894", edgecolor="#2d3436")

        # bar_label pads in points, so labels stay clear of the bars at any
        # y-axis scale (a fixed data-unit offset does not).
        for bars in (bars1, bars2, bars3):
            ax1.bar_label(bars, fmt="%.2f", padding=3, fontsize=8.5, color="white")

        ax1.set_xticks(x)
        ax1.set_xticklabels(tasks, color="white", fontsize=10)
        ax1.set_ylabel("Score (0 - 1)", color="white", fontsize=11)
        ax1.set_title("Score Comparison Across Training Stages", color="white", fontsize=12, fontweight="bold", pad=10)
        ax1.set_ylim(0, 1.2)
        ax1.tick_params(colors="white")
        ax1.spines[:].set_color("#2d3436")
        ax1.yaxis.grid(True, alpha=0.2, color="white")
        ax1.set_axisbelow(True)
        ax1.legend(facecolor="#0f3460", edgecolor="#2d3436", labelcolor="white", fontsize=9)

        # Right panel: improvement of the GRPO scores over the baseline.
        ax2 = axes[1]
        colors = ["#00b894" if d >= 0 else "#d63031" for d in deltas]
        bars4 = ax2.bar(x, deltas, width=0.4, color=colors, edgecolor="#2d3436")

        # For negative bars bar_label places the text below the bar end.
        ax2.bar_label(
            bars4,
            labels=[f"{d:+.3f}" for d in deltas],
            padding=3,
            fontsize=11,
            fontweight="bold",
            color="white",
        )

        ax2.axhline(0, color="white", linewidth=0.8, alpha=0.4)
        ax2.set_xticks(x)
        ax2.set_xticklabels(tasks, color="white", fontsize=10)
        ax2.set_ylabel("Score Delta (GRPO vs Rule-Based)", color="white", fontsize=10)
        ax2.set_title("Improvement: Rule-Based → After GRPO", color="white", fontsize=12, fontweight="bold", pad=10)
        ax2.tick_params(colors="white")
        ax2.spines[:].set_color("#2d3436")
        ax2.yaxis.grid(True, alpha=0.2, color="white")
        ax2.set_axisbelow(True)

        fig.suptitle(
            "Support Ticket Env — Training Results\nModel: Qwen2.5-0.5B-Instruct + GRPO | OpenEnv x Scalar Hackathon 2026",
            color="white", fontsize=11, y=1.02
        )

        fig.tight_layout()
        fig.savefig(output, dpi=180, bbox_inches="tight", facecolor=fig.get_facecolor())
    finally:
        # Release the figure even if drawing or saving fails; pyplot keeps
        # every figure alive until it is closed.
        plt.close(fig)
    print(f"Chart saved: {output}")

    # Plain-text summary of the plotted numbers (delta is left unrounded here
    # and formatted to three decimals).
    print("\n" + "="*52)
    print(f"{'Task':<14} {'Rule-Based':>10} {'Qwen-72B':>10} {'GRPO':>8} {'Delta':>8}")
    print("-"*52)
    for k, label in [("task1","Task 1"),("task2","Task 2"),("task3","Task 3"),("overall","Overall")]:
        b = baseline.get(k, 0)
        q = llm.get(k, 0)
        g = grpo.get(k, 0)
        d = g - b
        print(f"{label:<14} {b:>10.3f} {q:>10.3f} {g:>8.3f} {d:>+8.3f}")
    print("="*52)
|
| 115 |
+
|
| 116 |
+
if __name__ == "__main__":
|
| 117 |
+
make_chart(baseline_scores, llm_scores, grpo_scores)
|
plot_results.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
plot_results.py
|
| 3 |
+
Run inference across 3 seeds for all tasks and plot before/after bar chart.
|
| 4 |
+
Usage:
|
| 5 |
+
set HF_TOKEN=hf_...
|
| 6 |
+
set API_BASE_URL=https://router.huggingface.co/v1
|
| 7 |
+
set MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
|
| 8 |
+
python plot_results.py
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import sys
|
| 13 |
+
import json
|
| 14 |
+
import re
|
| 15 |
+
import random
|
| 16 |
+
import matplotlib
|
| 17 |
+
matplotlib.use("Agg")
|
| 18 |
+
import matplotlib.pyplot as plt
|
| 19 |
+
import matplotlib.patches as mpatches
|
| 20 |
+
import numpy as np
|
| 21 |
+
|
| 22 |
+
ROOT = os.path.dirname(os.path.abspath(__file__))
|
| 23 |
+
sys.path.insert(0, ROOT)
|
| 24 |
+
|
| 25 |
+
from openai import OpenAI
|
| 26 |
+
from support_ticket_env.server.support_environment import SupportTicketEnvironment
|
| 27 |
+
from support_ticket_env.models import SupportAction
|
| 28 |
+
|
| 29 |
+
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
|
| 30 |
+
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
|
| 31 |
+
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
|
| 32 |
+
MAX_STEPS = 10
|
| 33 |
+
SEEDS = [42, 7, 123]
|
| 34 |
+
|
| 35 |
+
VALID_CATEGORIES = ["billing", "technical", "account", "general", "refund"]
|
| 36 |
+
VALID_ACTIONS = ["classify", "reply", "escalate", "close"]
|
| 37 |
+
|
| 38 |
+
SYSTEM_PROMPT = """You are a customer support AI agent handling tickets.
|
| 39 |
+
Respond ONLY with a JSON object:
|
| 40 |
+
{
|
| 41 |
+
"action_type": "classify" | "reply" | "escalate" | "close",
|
| 42 |
+
"category": "billing" | "technical" | "account" | "general" | "refund",
|
| 43 |
+
"reply_text": "...",
|
| 44 |
+
"reason": "..."
|
| 45 |
+
}
|
| 46 |
+
Rules:
|
| 47 |
+
- Task 1: action_type=classify, pick correct category
|
| 48 |
+
- Task 2: first classify, then reply/escalate/close
|
| 49 |
+
- Task 3: classify each ticket then resolve it
|
| 50 |
+
- category only needed for classify
|
| 51 |
+
- reply_text only needed for reply
|
| 52 |
+
- technical issues: escalate
|
| 53 |
+
- resolved issues: close
|
| 54 |
+
- billing/account/refund: reply"""
|
| 55 |
+
|
| 56 |
+
CATEGORY_KEYWORDS = {
|
| 57 |
+
"billing": ["charge", "invoice", "payment", "bill", "refund", "subscription", "price", "cost", "fee", "money"],
|
| 58 |
+
"technical": ["error", "bug", "crash", "not working", "broken", "issue", "problem", "fail", "500", "api"],
|
| 59 |
+
"account": ["login", "password", "account", "access", "sign in", "email", "username", "cancel"],
|
| 60 |
+
"refund": ["refund", "return", "money back", "reimburse", "cancel order"],
|
| 61 |
+
"general": ["hours", "contact", "phone", "help", "question", "info", "support"],
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
def rule_based_action(obs):
    """Return a heuristic action for ``obs`` without calling any model.

    A ticket that already has a category gets a fixed follow-up: escalate for
    "technical", close for "general", and a canned reply for anything else.
    A ticket without a category is classified by counting how many keywords
    of each CATEGORY_KEYWORDS entry occur in the lower-cased ticket text. The
    first category with the highest count wins, and "general" is used when
    nothing matches.

    Args:
        obs: Observation exposing ``ticket_text`` and ``current_category``.

    Returns:
        A dict of keyword arguments describing the action.
    """
    ticket = obs.ticket_text.lower()

    if obs.current_category:
        current = obs.current_category
        if current == "technical":
            return {"action_type": "escalate", "reason": "Technical issue requires engineering team"}
        if current == "general":
            return {"action_type": "close", "reason": "General inquiry resolved"}
        return {
            "action_type": "reply",
            "reply_text": f"Thank you for contacting us about your {current} issue. We are looking into it and will resolve it shortly.",
        }

    # (hit count, category) pairs in declaration order, so picking the first
    # pair that reaches the top count reproduces first-wins tie-breaking.
    tallies = [
        (sum(1 for kw in kws if kw in ticket), name)
        for name, kws in CATEGORY_KEYWORDS.items()
    ]
    top = max((count for count, _ in tallies), default=0)
    chosen = "general"
    if top > 0:
        chosen = next(name for count, name in tallies if count == top)
    return {"action_type": "classify", "category": chosen}
|
| 81 |
+
|
| 82 |
+
def parse_response(text):
    """Parse a model reply as JSON, tolerating code fences and extra chatter.

    A leading ```` ``` ```` / ```` ```json ```` fence and a trailing
    ```` ``` ```` fence are stripped first. If what remains is not valid JSON
    as a whole, the JSON value starting at the first "{" is decoded and any
    text after it is ignored. A reply such as ``Sure! {...} hope that helps
    {smile}`` therefore still yields the object.

    Args:
        text: Raw reply string from the model.

    Returns:
        The decoded JSON value (normally a dict).

    Raises:
        json.JSONDecodeError: If neither the whole text nor the value starting
            at the first "{" is valid JSON.
    """
    text = text.strip()
    text = re.sub(r"^```(?:json)?\s*", "", text)
    text = re.sub(r"\s*```$", "", text)
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        start = text.find("{")
        if start == -1:
            # Nothing that could be an object: surface the original error.
            raise
        # raw_decode stops at the end of the first complete value, so trailing
        # text (even text containing braces) no longer breaks parsing.
        parsed, _ = json.JSONDecoder().raw_decode(text, start)
        return parsed
|
| 93 |
+
|
| 94 |
+
def get_action(client, obs):
    """Ask the chat model for the next action, falling back to the heuristic.

    When no API key is configured the rule-based agent is used directly.
    Otherwise the observation is serialised to JSON, sent with SYSTEM_PROMPT
    to ``client.chat.completions.create``, and the reply is parsed with
    ``parse_response``. Any failure (request error, unparsable reply, or a
    reply that is valid JSON but not an object) is reported on stdout and
    the rule-based action is returned instead.

    Args:
        client: Object exposing ``chat.completions.create`` (the script
            passes an ``OpenAI`` client).
        obs: Current environment observation.

    Returns:
        A dict of keyword arguments describing the action.
    """
    if not API_KEY:
        return rule_based_action(obs)
    user_prompt = json.dumps({
        "ticket_id": obs.ticket_id,
        "ticket_text": obs.ticket_text,
        "task_id": obs.task_id,
        "current_category": obs.current_category,
        "step_count": obs.step_count,
        "feedback": obs.feedback,
    })
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0.0,
            max_tokens=256,
        )
        text = (completion.choices[0].message.content or "").strip()
        action = parse_response(text)
        # A JSON list/number/string cannot be expanded into action keyword
        # arguments; treat it like any other bad reply so the fallback is
        # used instead of failing later as a zero-reward step.
        if not isinstance(action, dict):
            raise ValueError(f"expected a JSON object, got {type(action).__name__}")
        return action
    except Exception as e:
        # Deliberately broad: this is the best-effort boundary around the
        # remote model, and every failure degrades to the heuristic agent.
        print(f" [fallback] {e}")
        return rule_based_action(obs)
|
| 120 |
+
|
| 121 |
+
def run_task(task_id, seed, client):
    """Run one episode of a task and return its normalised score.

    A fresh environment is reset with ``task_id`` and ``seed`` and stepped
    with actions from ``get_action`` until the observation reports ``done``
    or MAX_STEPS steps have been attempted. A step that raises (invalid
    action fields, environment error) is reported on stdout and counted as a
    zero reward. The observation is left unchanged in that case, so the next
    iteration asks for an action on the same state.

    Args:
        task_id: Task identifier passed to ``env.reset``.
        seed: Seed passed to ``env.reset``.
        client: Chat client forwarded to ``get_action``.

    Returns:
        ``sum(rewards) / MAX_STEPS`` clamped to [0, 1] and rounded to three
        decimals. The divisor is MAX_STEPS, not the number of steps taken.
    """
    env = SupportTicketEnvironment()
    obs = env.reset(task_id=task_id, seed=seed)
    rewards = []
    for step in range(1, MAX_STEPS + 1):
        if obs.done:
            break
        action_dict = get_action(client, obs)
        try:
            action = SupportAction(**action_dict)
            obs = env.step(action)
            rewards.append(obs.reward or 0.0)
        except Exception as e:
            # Keep the episode going, but say why this step scored zero.
            print(f" [step {step} failed] {e}")
            rewards.append(0.0)
    total = sum(rewards)
    score = round(min(max(total / MAX_STEPS, 0.0), 1.0), 3)
    return score
|
| 140 |
+
|
| 141 |
+
def run_all_tasks(client, label=""):
    """Evaluate tasks 1-3 over every seed in SEEDS and report the averages.

    Each (task, seed) pair is scored with ``run_task``. Per-seed scores,
    per-task averages and the overall average are printed as they become
    available.

    Args:
        client: Chat client forwarded to ``run_task``.
        label: Accepted for the caller's convenience; not used here.

    Returns:
        Dict with keys "task1", "task2", "task3" (per-task mean rounded to
        three decimals) and "overall" (mean of those three, rounded).
    """
    averages = {}
    for task_number in (1, 2, 3):
        per_seed = []
        for seed_value in SEEDS:
            per_seed.append(run_task(task_number, seed_value, client))
            print(f" Task {task_number} seed={seed_value}: {per_seed[-1]:.3f}")
        mean_score = round(sum(per_seed) / len(per_seed), 3)
        averages[f"task{task_number}"] = mean_score
        print(f" Task {task_number} avg: {mean_score:.3f}")

    overall = round(sum(averages.values()) / 3, 3)
    averages["overall"] = overall
    print(f" Overall avg: {overall:.3f}")
    return averages
|
| 155 |
+
|
| 156 |
+
def plot_chart(before, after, output_path="reward_chart.png"):
    """Draw the before/after score chart and save it as an image.

    The left panel shows paired bars ("Before Training" vs "After GRPO") per
    task. The right panel shows the delta ``after - before`` per task,
    coloured green when the delta is >= 0 and red otherwise.

    Args:
        before: Mapping with keys "task1", "task2", "task3" and "overall".
            Missing keys are treated as 0.
        after: Mapping with the same keys.
        output_path: Path of the image file to write.

    Returns:
        ``output_path``, unchanged.
    """
    tasks = ["Task 1\n(Classify)", "Task 2\n(Action)", "Task 3\n(Full Resolve)", "Overall"]
    keys = ["task1", "task2", "task3", "overall"]
    before_vals = [before.get(k, 0) for k in keys]
    after_vals = [after.get(k, 0) for k in keys]
    deltas = [round(a - b, 3) for a, b in zip(after_vals, before_vals)]

    x = np.arange(len(tasks))
    width = 0.32

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    try:
        fig.patch.set_facecolor("#1a1a2e")
        for ax in axes:
            ax.set_facecolor("#16213e")

        # Left panel: absolute scores before and after.
        ax1 = axes[0]
        bars1 = ax1.bar(x - width/2, before_vals, width, label="Before Training", color="#636e72", edgecolor="#2d3436", linewidth=1.2)
        bars2 = ax1.bar(x + width/2, after_vals, width, label="After GRPO", color="#00b894", edgecolor="#2d3436", linewidth=1.2)

        # bar_label pads in points, so labels stay clear of the bars at any
        # y-axis scale (a fixed data-unit offset does not).
        ax1.bar_label(bars1, fmt="%.2f", padding=3, fontsize=10, color="#b2bec3")
        ax1.bar_label(bars2, fmt="%.2f", padding=3, fontsize=11, fontweight="bold", color="#00b894")

        ax1.set_xticks(x)
        ax1.set_xticklabels(tasks, color="white", fontsize=10)
        ax1.set_ylabel("Score (0 - 1)", color="white", fontsize=11)
        ax1.set_title("Before vs After GRPO Training", color="white", fontsize=13, fontweight="bold", pad=12)
        ax1.set_ylim(0, 1.2)
        ax1.tick_params(colors="white")
        ax1.spines[:].set_color("#2d3436")
        ax1.yaxis.grid(True, alpha=0.2, color="white")
        ax1.set_axisbelow(True)
        ax1.legend(facecolor="#0f3460", edgecolor="#2d3436", labelcolor="white", fontsize=10)

        # Right panel: per-task improvement.
        ax2 = axes[1]
        bar_colors = ["#00b894" if d >= 0 else "#d63031" for d in deltas]
        bars3 = ax2.bar(x, deltas, width=0.45, color=bar_colors, edgecolor="#2d3436", linewidth=1.2)

        # For negative bars bar_label places the text below the bar end.
        ax2.bar_label(
            bars3,
            labels=[f"{d:+.3f}" for d in deltas],
            padding=3,
            fontsize=11,
            fontweight="bold",
            color="white",
        )

        ax2.axhline(0, color="white", linewidth=0.8, alpha=0.4)
        ax2.set_xticks(x)
        ax2.set_xticklabels(tasks, color="white", fontsize=10)
        ax2.set_ylabel("Score Delta", color="white", fontsize=11)
        ax2.set_title("Improvement After GRPO", color="white", fontsize=13, fontweight="bold", pad=12)
        ax2.tick_params(colors="white")
        ax2.spines[:].set_color("#2d3436")
        ax2.yaxis.grid(True, alpha=0.2, color="white")
        ax2.set_axisbelow(True)

        fig.suptitle(
            "Support Ticket Env — GRPO Training Results\nModel: Qwen2.5-0.5B-Instruct | 3 Seeds | OpenEnv x Scalar Hackathon",
            color="white", fontsize=12, y=1.01
        )

        fig.tight_layout()
        fig.savefig(output_path, dpi=180, bbox_inches="tight", facecolor=fig.get_facecolor())
    finally:
        # Release the figure even if drawing or saving fails; pyplot keeps
        # every figure alive until it is closed.
        plt.close(fig)
    print(f"\nChart saved: {output_path}")
    return output_path
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
if __name__ == "__main__":
|
| 228 |
+
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY or "no-key")
|
| 229 |
+
|
| 230 |
+
print("=" * 50)
|
| 231 |
+
print("RUNNING INFERENCE — 3 seeds x 3 tasks")
|
| 232 |
+
print("=" * 50)
|
| 233 |
+
|
| 234 |
+
print("\n--- Current Model Scores ---")
|
| 235 |
+
current_scores = run_all_tasks(client, label="current")
|
| 236 |
+
|
| 237 |
+
# Baseline = rule-based agent (no LLM, no training)
|
| 238 |
+
baseline_scores = {
|
| 239 |
+
"task1": 0.100,
|
| 240 |
+
"task2": 0.113,
|
| 241 |
+
"task3": 0.218,
|
| 242 |
+
"overall": 0.144,
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
print("\n--- Baseline (from earlier run) ---")
|
| 246 |
+
for k, v in baseline_scores.items():
|
| 247 |
+
print(f" {k}: {v:.3f}")
|
| 248 |
+
|
| 249 |
+
print("\n--- Generating Chart ---")
|
| 250 |
+
plot_chart(
|
| 251 |
+
before=baseline_scores,
|
| 252 |
+
after=current_scores,
|
| 253 |
+
output_path="reward_chart.png"
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
print("\n" + "=" * 50)
|
| 257 |
+
print("SUMMARY")
|
| 258 |
+
print("=" * 50)
|
| 259 |
+
print(f"{'Task':<12} {'Before':>8} {'After':>8} {'Delta':>8}")
|
| 260 |
+
print("-" * 40)
|
| 261 |
+
for k, label in [("task1","Task 1"),("task2","Task 2"),("task3","Task 3"),("overall","Overall")]:
|
| 262 |
+
b = baseline_scores.get(k, 0)
|
| 263 |
+
a = current_scores.get(k, 0)
|
| 264 |
+
print(f"{label:<12} {b:>8.3f} {a:>8.3f} {a-b:>+8.3f}")
|
| 265 |
+
print("=" * 50)
|
| 266 |
+
print("reward_chart.png saved in your project folder.")
|
train_grpo_safe.ipynb
ADDED
|
@@ -0,0 +1,562 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"nbformat": 4,
|
| 3 |
+
"nbformat_minor": 0,
|
| 4 |
+
"metadata": {
|
| 5 |
+
"colab": {
|
| 6 |
+
"provenance": [],
|
| 7 |
+
"gpuType": "T4"
|
| 8 |
+
},
|
| 9 |
+
"kernelspec": {
|
| 10 |
+
"display_name": "Python 3",
|
| 11 |
+
"name": "python3"
|
| 12 |
+
},
|
| 13 |
+
"language_info": {
|
| 14 |
+
"name": "python"
|
| 15 |
+
},
|
| 16 |
+
"accelerator": "GPU"
|
| 17 |
+
},
|
| 18 |
+
"cells": [
|
| 19 |
+
{
|
| 20 |
+
"cell_type": "markdown",
|
| 21 |
+
"metadata": {},
|
| 22 |
+
"source": [
|
| 23 |
+
"# Support Ticket Env - GRPO Fine-Tuning\n",
|
| 24 |
+
"**OpenEnv x Scalar Hackathon**\n",
|
| 25 |
+
"\n",
|
| 26 |
+
"Fine-tunes `Qwen/Qwen2.5-0.5B-Instruct` using GRPO (Group Relative Policy Optimization) from HuggingFace TRL against the live Support Ticket Environment API.\n",
|
| 27 |
+
"\n",
|
| 28 |
+
"- Model: Qwen2.5-0.5B-Instruct\n",
|
| 29 |
+
"- Algorithm: GRPO\n",
|
| 30 |
+
"- Environment: https://algocore-support-ticket-env.hf.space\n",
|
| 31 |
+
"- Runtime: ~45-60 min on free Colab T4"
|
| 32 |
+
]
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"cell_type": "code",
|
| 36 |
+
"execution_count": null,
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"outputs": [],
|
| 39 |
+
"source": [
|
| 40 |
+
"!pip install -q trl transformers peft accelerate\n",
|
| 41 |
+
"!pip install -q torch bitsandbytes requests datasets\n",
|
| 42 |
+
"print('Done')"
|
| 43 |
+
]
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"cell_type": "code",
|
| 47 |
+
"execution_count": null,
|
| 48 |
+
"metadata": {},
|
| 49 |
+
"outputs": [],
|
| 50 |
+
"source": [
|
| 51 |
+
"import os\n",
|
| 52 |
+
"\n",
|
| 53 |
+
"HF_TOKEN = \"YOUR_HF_TOKEN_HERE\"\n",
|
| 54 |
+
"ENV_BASE_URL = \"https://algocore-support-ticket-env.hf.space\"\n",
|
| 55 |
+
"MODEL_NAME = \"Qwen/Qwen2.5-0.5B-Instruct\"\n",
|
| 56 |
+
"OUTPUT_DIR = \"/content/support-ticket-grpo\"\n",
|
| 57 |
+
"HF_REPO_ID = \"AlgoCore/support-ticket-grpo-model\"\n",
|
| 58 |
+
"\n",
|
| 59 |
+
"os.environ[\"HF_TOKEN\"] = HF_TOKEN\n",
|
| 60 |
+
"os.environ[\"HUGGING_FACE_HUB_TOKEN\"] = HF_TOKEN\n",
|
| 61 |
+
"\n",
|
| 62 |
+
"import torch\n",
|
| 63 |
+
"print(\"GPU:\", torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"NO GPU - switch runtime!\")\n",
|
| 64 |
+
"if torch.cuda.is_available():\n",
|
| 65 |
+
" print(\"VRAM:\", round(torch.cuda.get_device_properties(0).total_memory / 1e9, 1), \"GB\")\n",
|
| 66 |
+
"print(\"Model:\", MODEL_NAME)\n",
|
| 67 |
+
"print(\"Env:\", ENV_BASE_URL)"
|
| 68 |
+
]
|
| 69 |
+
},
|
| 70 |
+
{
|
| 71 |
+
"cell_type": "code",
|
| 72 |
+
"execution_count": null,
|
| 73 |
+
"metadata": {},
|
| 74 |
+
"outputs": [],
|
| 75 |
+
"source": [
|
| 76 |
+
"import requests\n",
|
| 77 |
+
"import json\n",
|
| 78 |
+
"import re\n",
|
| 79 |
+
"from dataclasses import dataclass\n",
|
| 80 |
+
"from typing import Optional\n",
|
| 81 |
+
"\n",
|
| 82 |
+
"@dataclass\n",
|
| 83 |
+
"class Obs:\n",
|
| 84 |
+
" ticket_id: str\n",
|
| 85 |
+
" ticket_text: str\n",
|
| 86 |
+
" task_id: int\n",
|
| 87 |
+
" current_category: Optional[str]\n",
|
| 88 |
+
" resolved: bool\n",
|
| 89 |
+
" step_count: int\n",
|
| 90 |
+
" feedback: str\n",
|
| 91 |
+
" score: float\n",
|
| 92 |
+
" reward: float\n",
|
| 93 |
+
" done: bool\n",
|
| 94 |
+
"\n",
|
| 95 |
+
"class SupportEnvClient:\n",
|
| 96 |
+
" def __init__(self, base_url):\n",
|
| 97 |
+
" self.base_url = base_url.rstrip('/')\n",
|
| 98 |
+
" self.session = requests.Session()\n",
|
| 99 |
+
" self.session.headers.update({'Content-Type': 'application/json'})\n",
|
| 100 |
+
"\n",
|
| 101 |
+
" def health(self):\n",
|
| 102 |
+
" try:\n",
|
| 103 |
+
" r = self.session.get(f\"{self.base_url}/health\", timeout=10)\n",
|
| 104 |
+
" return r.status_code == 200\n",
|
| 105 |
+
" except:\n",
|
| 106 |
+
" return False\n",
|
| 107 |
+
"\n",
|
| 108 |
+
" def reset(self, task_id=1, seed=42):\n",
|
| 109 |
+
" r = self.session.post(f\"{self.base_url}/reset\", json={\"task_id\": task_id, \"seed\": seed}, timeout=15)\n",
|
| 110 |
+
" r.raise_for_status()\n",
|
| 111 |
+
" return self._parse(r.json())\n",
|
| 112 |
+
"\n",
|
| 113 |
+
" def step(self, action):\n",
|
| 114 |
+
" r = self.session.post(f\"{self.base_url}/step\", json={\"action\": action}, timeout=15)\n",
|
| 115 |
+
" r.raise_for_status()\n",
|
| 116 |
+
" return self._parse(r.json())\n",
|
| 117 |
+
"\n",
|
| 118 |
+
" def _parse(self, data):\n",
|
| 119 |
+
" obs = data.get('observation', data)\n",
|
| 120 |
+
" return Obs(\n",
|
| 121 |
+
" ticket_id=obs.get('ticket_id', ''),\n",
|
| 122 |
+
" ticket_text=obs.get('ticket_text', ''),\n",
|
| 123 |
+
" task_id=obs.get('task_id', 1),\n",
|
| 124 |
+
" current_category=obs.get('current_category'),\n",
|
| 125 |
+
" resolved=obs.get('resolved', False),\n",
|
| 126 |
+
" step_count=obs.get('step_count', 0),\n",
|
| 127 |
+
" feedback=obs.get('feedback', ''),\n",
|
| 128 |
+
" score=obs.get('score', 0.0),\n",
|
| 129 |
+
" reward=obs.get('reward', 0.0),\n",
|
| 130 |
+
" done=obs.get('done', False),\n",
|
| 131 |
+
" )\n",
|
| 132 |
+
"\n",
|
| 133 |
+
"env_client = SupportEnvClient(ENV_BASE_URL)\n",
|
| 134 |
+
"if env_client.health():\n",
|
| 135 |
+
" print('Environment API reachable')\n",
|
| 136 |
+
" obs = env_client.reset(task_id=1, seed=42)\n",
|
| 137 |
+
" print(f'Ticket: {obs.ticket_id} - {obs.ticket_text[:70]}')\n",
|
| 138 |
+
"else:\n",
|
| 139 |
+
" print('Cannot reach environment - check ENV_BASE_URL')"
|
| 140 |
+
]
|
| 141 |
+
},
|
| 142 |
+
{
|
| 143 |
+
"cell_type": "code",
|
| 144 |
+
"execution_count": null,
|
| 145 |
+
"metadata": {},
|
| 146 |
+
"outputs": [],
|
| 147 |
+
"source": [
|
| 148 |
+
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
|
| 149 |
+
"import torch\n",
|
| 150 |
+
"\n",
|
| 151 |
print(f"Loading {MODEL_NAME}...")

# Hub options passed identically to the tokenizer and the model loader.
hub_kwargs = {"token": HF_TOKEN, "trust_remote_code": True}

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, **hub_kwargs)
# The EOS token doubles as padding, and padding goes on the left.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map='auto',
    **hub_kwargs,
)

param_count = sum(p.numel() for p in model.parameters())
first_param = next(model.parameters())
print(f'Model loaded - {param_count/1e6:.0f}M parameters')
print(f'Device: {first_param.device}')
|
| 167 |
+
]
|
| 168 |
+
},
|
| 169 |
+
{
|
| 170 |
+
"cell_type": "code",
|
| 171 |
+
"execution_count": null,
|
| 172 |
+
"metadata": {},
|
| 173 |
+
"outputs": [],
|
| 174 |
+
"source": [
|
| 175 |
+
"from peft import LoraConfig, get_peft_model, TaskType\n",
|
| 176 |
+
"\n",
|
| 177 |
# Wrap the loaded model with LoRA adapters on the four attention projections.
attention_projections = ["q_proj", "v_proj", "k_proj", "o_proj"]

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=attention_projections,
    bias="none",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
|
| 188 |
+
]
|
| 189 |
+
},
|
| 190 |
+
{
|
| 191 |
+
"cell_type": "code",
|
| 192 |
+
"execution_count": null,
|
| 193 |
+
"metadata": {},
|
| 194 |
+
"outputs": [],
|
| 195 |
+
"source": [
|
| 196 |
+
"SYSTEM_PROMPT = \"\"\"You are a customer support AI agent. Given a ticket, respond with a JSON action.\n",
|
| 197 |
+
"\n",
|
| 198 |
+
"Respond ONLY with valid JSON:\n",
|
| 199 |
+
"{\"action_type\": \"classify\"|\"reply\"|\"escalate\"|\"close\", \"category\": \"billing\"|\"technical\"|\"account\"|\"general\"|\"refund\", \"reply_text\": \"...\", \"reason\": \"...\"}\n",
|
| 200 |
+
"\n",
|
| 201 |
+
"Rules:\n",
|
| 202 |
+
"- Task 1: action_type=classify, pick correct category\n",
|
| 203 |
+
"- Task 2: first classify, then reply/escalate/close\n",
|
| 204 |
+
"- Task 3: classify each ticket then resolve it\n",
|
| 205 |
+
"- category only needed for classify\n",
|
| 206 |
+
"- reply_text only needed for reply\n",
|
| 207 |
+
"- technical issues: escalate\n",
|
| 208 |
+
"- resolved issues: close\n",
|
| 209 |
+
"- billing/account/refund: reply\"\"\"\n",
|
| 210 |
+
"\n",
|
| 211 |
def build_prompt(obs):
    """Render an observation as a chat-formatted prompt string.

    Six observation attributes are serialized as indented JSON for the user
    turn, SYSTEM_PROMPT is the system turn, and the tokenizer's chat template
    is applied with the generation prompt appended (text output, not tokens).
    """
    fields = (
        "ticket_id",
        "ticket_text",
        "task_id",
        "current_category",
        "feedback",
        "step_count",
    )
    # Dict insertion order follows `fields`, so the JSON key order is fixed.
    payload = {name: getattr(obs, name) for name in fields}
    conversation = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": json.dumps(payload, indent=2)},
    ]
    return tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
|
| 225 |
+
"\n",
|
| 226 |
def parse_action(text):
    """Extract the first JSON object from a model completion.

    A leading ```/```json fence and a trailing ``` fence are stripped, then
    the text is scanned for the first position where a complete JSON object
    can be decoded. Surrounding chatter, nested objects and braces inside
    string values are all handled, which a non-greedy `{...}` regex cannot do.

    Args:
        text: Raw completion produced by the model.

    Returns:
        The decoded dict, or {"action_type": "classify", "category": "general"}
        when no JSON object can be recovered. Non-dict JSON values (lists,
        numbers, strings) are never returned as an action.
    """
    fallback = {"action_type": "classify", "category": "general"}
    if not isinstance(text, str):
        return fallback

    cleaned = text.strip()
    cleaned = re.sub(r'^```(?:json)?\s*', '', cleaned)
    cleaned = re.sub(r'\s*```$', '', cleaned)

    decoder = json.JSONDecoder()
    start = cleaned.find('{')
    while start != -1:
        try:
            # raw_decode parses one value starting at `start` and ignores
            # whatever follows it, so trailing prose does not matter.
            candidate, _ = decoder.raw_decode(cleaned, start)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(candidate, dict):
                return candidate
        start = cleaned.find('{', start + 1)
    return fallback
|
| 240 |
+
"\n",
|
| 241 |
# Sanity check: reset once and make sure a prompt can be rendered from it.
obs = env_client.reset(task_id=1, seed=42)
prompt = build_prompt(obs)
for line in ('Prompt builder OK', f'Prompt length: {len(prompt)} chars'):
    print(line)
|
| 245 |
+
]
|
| 246 |
+
},
|
| 247 |
+
{
|
| 248 |
+
"cell_type": "code",
|
| 249 |
+
"execution_count": null,
|
| 250 |
+
"metadata": {},
|
| 251 |
+
"outputs": [],
|
| 252 |
+
"source": [
|
| 253 |
+
"import random\n",
|
| 254 |
+
"\n",
|
| 255 |
# Seeds passed to env reset; evaluate() takes a prefix of this list and the
# training loop samples from it.
SEEDS = [42, 7, 123, 0, 99]
# Task ids the training loop samples from.
TASK_IDS = list(range(1, 4))
# Maximum number of agent steps run_episode() takes per episode.
MAX_STEPS = 6
|
| 258 |
+
"\n",
|
| 259 |
def generate_action(prompt, max_new_tokens=150):
    """Sample one completion for `prompt` and return only the new text.

    The prompt is tokenized (truncated to 1024 tokens) and moved to the
    model's device. Sampling uses temperature 0.7 / top-p 0.9 with the EOS
    token as padding id. Only the tokens after the prompt are decoded, with
    special tokens removed.

    Generation runs with the model temporarily in eval mode so that dropout
    (including the LoRA dropout configured for training) does not perturb the
    sampled policy when this is called from inside the training loop. The
    previous train/eval mode is restored afterwards, even on error.

    Args:
        prompt: Fully rendered prompt string.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        The decoded completion string.
    """
    inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=1024).to(model.device)
    prompt_tokens = inputs['input_ids'].shape[1]

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
            )
    finally:
        if was_training:
            model.train()

    new_tokens = outputs[0][prompt_tokens:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
|
| 272 |
+
"\n",
|
| 273 |
def run_episode(task_id, seed):
    """Play one episode with the current model and collect its trajectory.

    The environment is reset for (task_id, seed); then, for at most MAX_STEPS
    steps and while the observation is not done, a prompt is built, a
    completion is sampled, parsed into an action and sent to the environment.

    A failing step (HTTP error, unusable reward value, ...) is reported,
    scored as -0.1 and ends the episode. Only Exception subclasses are
    handled, so KeyboardInterrupt still stops the notebook.

    Args:
        task_id: Task identifier forwarded to the environment reset.
        seed: Seed forwarded to the environment reset.

    Returns:
        (prompts, completions, total_reward): the per-step prompt strings,
        the matching completions, and the sum of per-step rewards.
    """
    obs = env_client.reset(task_id=task_id, seed=seed)
    prompts, completions, rewards = [], [], []
    for _ in range(MAX_STEPS):
        if obs.done:
            break
        prompt = build_prompt(obs)
        completion = generate_action(prompt)
        action = parse_action(completion)
        try:
            obs = env_client.step(action)
            reward = float(obs.reward or 0.0)
        except Exception as exc:
            print(f'  step failed ({type(exc).__name__}: {exc}); ending episode')
            reward = -0.1
            # Mark whichever observation we hold as finished so the
            # top-of-loop check stops the episode.
            obs.done = True
        prompts.append(prompt)
        completions.append(completion)
        rewards.append(reward)
    return prompts, completions, sum(rewards)
|
| 294 |
+
"\n",
|
| 295 |
# Run a single task-1 episode end to end before committing to evaluation.
print('Running smoke test...')
smoke_result = run_episode(task_id=1, seed=42)
p, c, r = smoke_result
print(f'Smoke test passed - steps={len(p)}, total_reward={r:.3f}')
|
| 298 |
+
]
|
| 299 |
+
},
|
| 300 |
+
{
|
| 301 |
+
"cell_type": "code",
|
| 302 |
+
"execution_count": null,
|
| 303 |
+
"metadata": {},
|
| 304 |
+
"outputs": [],
|
| 305 |
+
"source": [
|
| 306 |
def evaluate(n_seeds=3):
    """Score the current model on every task in TASK_IDS.

    For each task, one episode is run per seed in SEEDS[:n_seeds]. Each
    episode's total reward is divided by MAX_STEPS, clamped to [0, 1] and
    rounded to 3 decimals; the task score is the rounded mean of those
    values. 'overall' is the rounded mean of the task scores.
    NOTE(review): dividing by MAX_STEPS presumes a per-step reward of at
    most 1 - confirm against the environment's reward scale.

    Args:
        n_seeds: How many seeds from the front of SEEDS to use.

    Returns:
        dict with one 'task<id>' entry per task plus 'overall'.

    Raises:
        ValueError: if n_seeds selects no seeds (previously an opaque
            ZeroDivisionError).
    """
    seeds = SEEDS[:n_seeds]
    if not seeds:
        raise ValueError(f'n_seeds={n_seeds!r} selects no seeds from SEEDS')

    results = {}
    for task_id in TASK_IDS:
        task_rewards = []
        for seed in seeds:
            _, _, total = run_episode(task_id=task_id, seed=seed)
            task_rewards.append(round(max(0, min(1, total / MAX_STEPS)), 3))
        avg = round(sum(task_rewards) / len(task_rewards), 3)
        results[f'task{task_id}'] = avg
        print(f'  Task {task_id}: {avg:.3f}')

    # Average over however many tasks were evaluated rather than a literal 3.
    results['overall'] = round(sum(results.values()) / len(results), 3)
    print(f'  Overall: {results["overall"]:.3f}')
    return results
|
| 321 |
+
"\n",
|
| 322 |
# Reference scores for the untuned model; later cells compare against these.
baseline_banner = '=== BASELINE (before training) ==='
print(baseline_banner)
baseline_scores = evaluate(n_seeds=3)
|
| 324 |
+
]
|
| 325 |
+
},
|
| 326 |
+
{
|
| 327 |
+
"cell_type": "code",
|
| 328 |
+
"execution_count": null,
|
| 329 |
+
"metadata": {},
|
| 330 |
+
"outputs": [],
|
| 331 |
+
"source": [
|
| 332 |
+
"from torch.optim import AdamW\n",
|
| 333 |
+
"from transformers import get_linear_schedule_with_warmup\n",
|
| 334 |
+
"import numpy as np\n",
|
| 335 |
+
"\n",
|
| 336 |
# --- Hyperparameters ---------------------------------------------------------
LEARNING_RATE = 5e-5
N_EPISODES = 60   # number of optimizer updates; one (task, seed) per update
GROUP_SIZE = 4    # completions sampled per update
KL_COEFF = 0.01   # weight of the log-prob regularizer (see NOTE in the loop)
GRAD_CLIP = 1.0   # max gradient norm
LOG_EVERY = 5     # print progress every N updates

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=5, num_training_steps=N_EPISODES)

# One (episode, task_id, mean group reward) tuple per update; read by the
# plotting cell.
training_log = []


def _mean_completion_log_prob(prompt, completion):
    """Return the mean token log-probability of `completion` given `prompt`.

    The concatenated text is tokenized (truncated to 1200 tokens) and the
    prompt is tokenized separately to find where the completion starts.
    Returns None when truncation leaves no completion tokens; that check
    happens before the forward pass so no compute is wasted on such samples.
    """
    inputs = tokenizer(prompt + completion, return_tensors='pt', truncation=True, max_length=1200).to(model.device)
    prompt_len = tokenizer(prompt, return_tensors='pt')["input_ids"].shape[1]
    target_ids = inputs['input_ids'][:, prompt_len:]
    if target_ids.shape[1] == 0:
        return None
    # No `labels=`: the built-in LM loss was computed and then discarded.
    # Position i of the logits predicts token i + 1, hence the shift by one.
    logits = model(**inputs).logits[:, prompt_len - 1:-1, :]
    # Upcast before the softmax; the model was loaded in float16.
    log_probs = torch.nn.functional.log_softmax(logits.float(), dim=-1)
    token_log_probs = log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1)
    return token_log_probs.mean()


print(f'Starting GRPO training: {N_EPISODES} episodes, group_size={GROUP_SIZE}')
print('=' * 60)

model.train()

for episode in range(1, N_EPISODES + 1):
    task_id = random.choice(TASK_IDS)
    seed = random.choice(SEEDS)

    # --- Rollouts: GROUP_SIZE single-step samples from the same start state.
    group_rewards = []
    group_prompts = []
    group_completions = []

    for g in range(GROUP_SIZE):
        obs = env_client.reset(task_id=task_id, seed=seed)
        prompt = build_prompt(obs)
        completion = generate_action(prompt)
        action = parse_action(completion)
        try:
            obs = env_client.step(action)
            reward = float(obs.reward or 0.0)
        except Exception as exc:
            # Keep training on a failed step, but say why it was penalized.
            print(f'  step failed ({type(exc).__name__}: {exc}); reward=-0.1')
            reward = -0.1
        group_rewards.append(reward)
        group_prompts.append(prompt)
        group_completions.append(completion)

    # --- Group-relative advantages: standardize rewards within the group.
    rewards_arr = np.array(group_rewards, dtype=np.float32)
    advantages = (rewards_arr - rewards_arr.mean()) / (rewards_arr.std() + 1e-8)

    # --- Policy update. Each sample is back-propagated on its own so only
    # one computation graph is alive at a time; gradients accumulate to the
    # same sum a single backward over the summed loss would give.
    optimizer.zero_grad()
    episode_loss = 0.0

    for prompt, completion, adv in zip(group_prompts, group_completions, advantages):
        if not completion.strip():
            continue
        seq_log_prob = _mean_completion_log_prob(prompt, completion)
        if seq_log_prob is None:
            continue
        pg_loss = -torch.tensor(float(adv), device=seq_log_prob.device) * seq_log_prob
        # NOTE(review): this is a squared mean-log-prob penalty, not a KL
        # divergence against a reference policy. Kept as in the original;
        # confirm whether a real KL term was intended.
        kl_loss = KL_COEFF * (seq_log_prob ** 2)
        sample_loss = (pg_loss + kl_loss) / GROUP_SIZE
        sample_loss.backward()
        episode_loss += sample_loss.item()

    torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
    optimizer.step()
    scheduler.step()

    avg_reward = float(rewards_arr.mean())
    training_log.append((episode, task_id, avg_reward))

    if episode % LOG_EVERY == 0:
        print(f'Episode {episode:3d}/{N_EPISODES} | task={task_id} | avg_reward={avg_reward:.3f} | loss={episode_loss:.4f}')

print('Training complete!')
|
| 411 |
+
]
|
| 412 |
+
},
|
| 413 |
+
{
|
| 414 |
+
"cell_type": "code",
|
| 415 |
+
"execution_count": null,
|
| 416 |
+
"metadata": {},
|
| 417 |
+
"outputs": [],
|
| 418 |
+
"source": [
|
| 419 |
# Switch to eval mode, re-score the model and tabulate the change per task.
model.eval()

print('=== POST-TRAINING EVALUATION ===')
trained_scores = evaluate(n_seeds=3)

print('\n=== IMPROVEMENT SUMMARY ===')
print(f'{"Task":<10} {"Before":>8} {"After":>8} {"Delta":>8}')
print('-' * 38)
row_labels = {"task1": "Task 1", "task2": "Task 2", "task3": "Task 3", "overall": "Overall"}
for key, label in row_labels.items():
    b = baseline_scores.get(key, 0)
    a = trained_scores.get(key, 0)
    d = a - b
    print(f'{label:<10} {b:>8.3f} {a:>8.3f} {d:>+8.3f}')
|
| 432 |
+
]
|
| 433 |
+
},
|
| 434 |
+
{
|
| 435 |
+
"cell_type": "code",
|
| 436 |
+
"execution_count": null,
|
| 437 |
+
"metadata": {},
|
| 438 |
+
"outputs": [],
|
| 439 |
+
"source": [
|
| 440 |
+
"import matplotlib.pyplot as plt\n",
|
| 441 |
+
"import numpy as np\n",
|
| 442 |
+
"\n",
|
| 443 |
# Split the (episode, task_id, avg_reward) log into three parallel lists.
if training_log:
    episodes, task_ids, ep_rewards = (list(column) for column in zip(*training_log))
else:
    episodes, task_ids, ep_rewards = [], [], []
|
| 446 |
+
"\n",
|
| 447 |
def moving_avg(data, window=5):
    """Return the simple moving average of `data` over full windows only.

    Args:
        data: Sequence of numbers.
        window: Number of consecutive points averaged per output value.

    Returns:
        A float array of length len(data) - window + 1, or an empty array
        when `data` has fewer than `window` points. Without that guard
        np.convolve raises on empty input and, for short input, silently
        swaps its operands and returns values that are not window averages.

    Raises:
        ValueError: if `window` is smaller than 1.
    """
    if window < 1:
        raise ValueError('window must be a positive integer')
    values = np.asarray(data, dtype=float)
    if values.size < window:
        return np.empty(0, dtype=float)
    return np.convolve(values, np.ones(window) / window, mode='valid')
|
| 449 |
+
"\n",
|
| 450 |
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
fig.suptitle('Support Ticket Env - GRPO Training Results', fontsize=14, fontweight='bold')

# Left panel: mean group reward per update, one colour per task. Tasks with
# at least 5 points get a 5-point moving average; x[2:-2] re-centres it.
ax1 = axes[0]
colors = {1: '#3498db', 2: '#2ecc71', 3: '#e74c3c'}
for tid, shade in colors.items():
    points = [(ep, rew) for ep, t, rew in zip(episodes, task_ids, ep_rewards) if t == tid]
    if not points:
        continue
    x = [ep for ep, _ in points]
    y = [rew for _, rew in points]
    ax1.scatter(x, y, alpha=0.3, color=shade, s=15)
    if len(y) < 5:
        ax1.plot(x, y, color=shade, linewidth=2, label=f'Task {tid}')
    else:
        ax1.plot(x[2:-2], moving_avg(y), color=shade, linewidth=2, label=f'Task {tid}')

ax1.set_xlabel('Episode')
ax1.set_ylabel('Avg Reward')
ax1.set_title('Training Reward per Episode')
ax1.legend()
ax1.grid(True, alpha=0.3)
ax1.set_ylim(-0.1, 1.1)

# Right panel: grouped bars comparing baseline and post-training scores.
ax2 = axes[1]
tasks = ['Task 1', 'Task 2', 'Task 3', 'Overall']
keys = ['task1', 'task2', 'task3', 'overall']
before_vals = [baseline_scores.get(k, 0) for k in keys]
after_vals = [trained_scores.get(k, 0) for k in keys]

x = np.arange(len(tasks))
width = 0.35

bars1 = ax2.bar(x - width/2, before_vals, width, label='Before Training', color='#95a5a6')
bars2 = ax2.bar(x + width/2, after_vals, width, label='After GRPO', color='#2ecc71')

# Value labels above each bar; the "after" labels are bold and green.
label_styles = (
    (bars1, {}),
    (bars2, {'fontweight': 'bold', 'color': '#27ae60'}),
)
for bars, extra in label_styles:
    for bar in bars:
        height = bar.get_height()
        ax2.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                 f'{height:.2f}', ha='center', va='bottom', fontsize=9, **extra)

ax2.set_xticks(x)
ax2.set_xticklabels(tasks)
ax2.set_ylabel('Score (0-1)')
ax2.set_title('Before vs After GRPO Training')
ax2.legend()
ax2.grid(True, alpha=0.3, axis='y')
ax2.set_ylim(0, 1.15)

plt.tight_layout()
plt.savefig('/content/grpo_results.png', dpi=150, bbox_inches='tight')
plt.show()
print('Chart saved to /content/grpo_results.png')
|
| 506 |
+
]
|
| 507 |
+
},
|
| 508 |
+
{
|
| 509 |
+
"cell_type": "code",
|
| 510 |
+
"execution_count": null,
|
| 511 |
+
"metadata": {},
|
| 512 |
+
"outputs": [],
|
| 513 |
+
"source": [
|
| 514 |
+
"import os\n",
|
| 515 |
# Persist the adapter and tokenizer locally, then try to publish them (plus
# the results chart) to the Hub. A failed push only prints a message.
os.makedirs(OUTPUT_DIR, exist_ok=True)

for artifact in (model, tokenizer):
    artifact.save_pretrained(OUTPUT_DIR)
print(f'Model saved to {OUTPUT_DIR}')

try:
    from huggingface_hub import HfApi

    api = HfApi(token=HF_TOKEN)
    api.create_repo(HF_REPO_ID, exist_ok=True, private=False)
    api.upload_folder(folder_path=OUTPUT_DIR, repo_id=HF_REPO_ID, repo_type='model')
    chart_path = '/content/grpo_results.png'
    api.upload_file(
        path_or_fileobj=chart_path,
        path_in_repo='grpo_results.png',
        repo_id=HF_REPO_ID,
        repo_type='model',
    )
    print(f'Model pushed to: https://huggingface.co/{HF_REPO_ID}')
except Exception as e:
    print(f'Push failed: {e}')
    print(f'Model is saved locally at {OUTPUT_DIR}')
|
| 531 |
+
]
|
| 532 |
+
},
|
| 533 |
+
{
|
| 534 |
+
"cell_type": "code",
|
| 535 |
+
"execution_count": null,
|
| 536 |
+
"metadata": {},
|
| 537 |
+
"outputs": [],
|
| 538 |
+
"source": [
|
| 539 |
# Offer the chart as a browser download when running in Colab. Outside Colab
# the import fails; skip the download so the summary below is still printed
# instead of the whole cell aborting on the first line.
try:
    from google.colab import files
except ImportError:
    files = None
    print('google.colab not available - skipping chart download')
else:
    files.download('/content/grpo_results.png')

print('\n' + '='*50)
print('FINAL TRAINING SUMMARY')
print('='*50)
print(f'Model:     {MODEL_NAME}')
print(f'Algorithm: GRPO')
print(f'Episodes:  {N_EPISODES}')
print(f'Env:       {ENV_BASE_URL}')
print()
print(f'{"Task":<10} {"Before":>8} {"After":>8} {"Delta":>8}')
print('-' * 38)
# Same before/after table as the post-training cell; missing keys show as 0.
for key, label in [("task1","Task 1"),("task2","Task 2"),("task3","Task 3"),("overall","Overall")]:
    b = baseline_scores.get(key, 0)
    a = trained_scores.get(key, 0)
    d = a - b
    print(f'{label:<10} {b:>8.3f} {a:>8.3f} {d:>+8.3f}')
print('='*50)
print(f'Model: https://huggingface.co/{HF_REPO_ID}')
|
| 559 |
+
]
|
| 560 |
+
}
|
| 561 |
+
]
|
| 562 |
+
}
|