Spaces:
Sleeping
Sleeping
Update src/App.js
#1
by Shree2604 - opened
- src/App.js +33 -17
src/App.js
CHANGED
|
@@ -116,7 +116,7 @@ const ThemeToggle = ({ theme, onToggle }) => {
|
|
| 116 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 117 |
// SCORE BADGE
|
| 118 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 119 |
-
const ScoreBadge = ({ score }) => {
|
| 120 |
const pct = parseFloat(score) * 100;
|
| 121 |
const color = pct >= 60 ? "#22c55e" : pct >= 35 ? "#f59e0b" : "#ef4444";
|
| 122 |
return (
|
|
@@ -125,7 +125,7 @@ const ScoreBadge = ({ score }) => {
|
|
| 125 |
border: `1px solid ${color}55`,
|
| 126 |
borderRadius: 6, padding: "2px 10px",
|
| 127 |
fontFamily: "monospace", fontWeight: 700, fontSize: 13,
|
| 128 |
-
}}>
|
| 129 |
);
|
| 130 |
};
|
| 131 |
|
|
@@ -270,6 +270,8 @@ export default function App() {
|
|
| 270 |
const [rewardOutput, setRewardOutput] = useState("");
|
| 271 |
const [grpoOutput, setGrpoOutput] = useState("");
|
| 272 |
const [rewardScore, setRewardScore] = useState(null);
|
|
|
|
|
|
|
| 273 |
// Breakdown states removed - no longer used
|
| 274 |
const [loading, setLoading] = useState(false);
|
| 275 |
const [dragging, setDragging] = useState(false);
|
|
@@ -280,7 +282,7 @@ export default function App() {
|
|
| 280 |
if (!image || !imageFile) return;
|
| 281 |
setLoading(true);
|
| 282 |
setSftOutput(""); setRewardOutput(""); setGrpoOutput("");
|
| 283 |
-
setRewardScore(null);
|
| 284 |
const BASE = "";
|
| 285 |
|
| 286 |
try {
|
|
@@ -299,6 +301,10 @@ export default function App() {
|
|
| 299 |
const rmData = await rmRes.json();
|
| 300 |
setRewardOutput(rmData.feedback);
|
| 301 |
setRewardScore(parseFloat(rmData.score).toFixed(2));
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
|
| 303 |
// 3. GRPO + its reward breakdown
|
| 304 |
const grpoForm = new FormData();
|
|
@@ -307,6 +313,10 @@ export default function App() {
|
|
| 307 |
const grpoRes = await fetch(`${BASE}/grpo_reward`, { method: "POST", body: grpoForm });
|
| 308 |
const grpoData = await grpoRes.json();
|
| 309 |
setGrpoOutput(grpoData.report);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
|
| 311 |
} catch (err) {
|
| 312 |
console.error("Inference error:", err);
|
|
@@ -333,12 +343,18 @@ export default function App() {
|
|
| 333 |
const clearAll = () => {
|
| 334 |
setImage(null); setImageFile(null);
|
| 335 |
setSftOutput(""); setRewardOutput(""); setGrpoOutput("");
|
| 336 |
-
setRewardScore(null);
|
| 337 |
if (fileRef.current) fileRef.current.value = "";
|
| 338 |
};
|
| 339 |
|
| 340 |
-
|
| 341 |
-
const
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
|
| 343 |
return (
|
| 344 |
<div style={{
|
|
@@ -529,7 +545,7 @@ export default function App() {
|
|
| 529 |
<OutputCard
|
| 530 |
theme={theme} title="SFT Model Output " icon="π§ " accent="#3b82f6"
|
| 531 |
content={sftOutput} loading={loading && !sftOutput}
|
| 532 |
-
badge={
|
| 533 |
/>
|
| 534 |
</div>
|
| 535 |
|
|
@@ -572,7 +588,7 @@ export default function App() {
|
|
| 572 |
<OutputCard
|
| 573 |
theme={theme} title="NeMo Gym Output" icon="π―" accent="#22c55e"
|
| 574 |
content={grpoOutput} loading={loading && rewardOutput && !grpoOutput}
|
| 575 |
-
badge={
|
| 576 |
/>
|
| 577 |
</div>
|
| 578 |
|
|
@@ -583,14 +599,14 @@ export default function App() {
|
|
| 583 |
boxShadow: theme.cardShadow, transition: "background .3s, border-color .3s",
|
| 584 |
}}>
|
| 585 |
<div style={{ fontSize: 10, color: theme.textMuted, textTransform: "uppercase", letterSpacing: 2, fontWeight: 700, marginBottom: 16 }}>
|
| 586 |
-
π
|
| 587 |
</div>
|
| 588 |
<div style={{ display: "grid", gridTemplateColumns: "1fr 1fr", gap: 16 }}>
|
| 589 |
{[
|
| 590 |
-
{ label: "SFT (Original)", score:
|
| 591 |
-
{ label: "GRPO (Final)", score:
|
| 592 |
].map(({ label, score, color }) => {
|
| 593 |
-
const
|
| 594 |
return (
|
| 595 |
<div key={label} style={{
|
| 596 |
background: theme.surfaceAlt, borderRadius: 10, padding: 16,
|
|
@@ -598,17 +614,17 @@ export default function App() {
|
|
| 598 |
}}>
|
| 599 |
<div style={{ fontSize: 12, color: theme.textMuted, marginBottom: 8 }}>{label}</div>
|
| 600 |
<div style={{ fontSize: 28, fontWeight: 800, color, fontFamily: "monospace" }}>
|
| 601 |
-
{
|
| 602 |
</div>
|
| 603 |
<div style={{ marginTop: 10, height: 5, background: theme.barTrack, borderRadius: 3 }}>
|
| 604 |
<div style={{
|
| 605 |
-
width: `${
|
| 606 |
background: color, borderRadius: 3, transition: "width 1.2s ease",
|
| 607 |
}} />
|
| 608 |
</div>
|
| 609 |
-
{parseFloat(
|
| 610 |
<div style={{ fontSize: 11, color: "#22c55e", marginTop: 7, fontWeight: 600 }}>
|
| 611 |
-
β² +{(parseFloat(
|
| 612 |
</div>
|
| 613 |
)}
|
| 614 |
</div>
|
|
@@ -666,4 +682,4 @@ export default function App() {
|
|
| 666 |
</footer>
|
| 667 |
</div>
|
| 668 |
);
|
| 669 |
-
}
|
|
|
|
| 116 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 117 |
// SCORE BADGE
|
| 118 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 119 |
+
const ScoreBadge = ({ score, label = "ROUGE-L" }) => {
|
| 120 |
const pct = parseFloat(score) * 100;
|
| 121 |
const color = pct >= 60 ? "#22c55e" : pct >= 35 ? "#f59e0b" : "#ef4444";
|
| 122 |
return (
|
|
|
|
| 125 |
border: `1px solid ${color}55`,
|
| 126 |
borderRadius: 6, padding: "2px 10px",
|
| 127 |
fontFamily: "monospace", fontWeight: 700, fontSize: 13,
|
| 128 |
+
}}>{label}: {pct.toFixed(1)}%</span>
|
| 129 |
);
|
| 130 |
};
|
| 131 |
|
|
|
|
| 270 |
const [rewardOutput, setRewardOutput] = useState("");
|
| 271 |
const [grpoOutput, setGrpoOutput] = useState("");
|
| 272 |
const [rewardScore, setRewardScore] = useState(null);
|
| 273 |
+
const [hfJudgeSFT, setHfJudgeSFT] = useState(null);
|
| 274 |
+
const [hfJudgeGRPO, setHfJudgeGRPO] = useState(null);
|
| 275 |
// Breakdown states removed - no longer used
|
| 276 |
const [loading, setLoading] = useState(false);
|
| 277 |
const [dragging, setDragging] = useState(false);
|
|
|
|
| 282 |
if (!image || !imageFile) return;
|
| 283 |
setLoading(true);
|
| 284 |
setSftOutput(""); setRewardOutput(""); setGrpoOutput("");
|
| 285 |
+
setRewardScore(null); setHfJudgeSFT(null); setHfJudgeGRPO(null);
|
| 286 |
const BASE = "";
|
| 287 |
|
| 288 |
try {
|
|
|
|
| 301 |
const rmData = await rmRes.json();
|
| 302 |
setRewardOutput(rmData.feedback);
|
| 303 |
setRewardScore(parseFloat(rmData.score).toFixed(2));
|
| 304 |
+
// Use hf_judge score for comparison when ground truth is provided
|
| 305 |
+
if (groundTruth && rmData.breakdown) {
|
| 306 |
+
setHfJudgeSFT(rmData.breakdown.hf_judge ?? null);
|
| 307 |
+
}
|
| 308 |
|
| 309 |
// 3. GRPO + its reward breakdown
|
| 310 |
const grpoForm = new FormData();
|
|
|
|
| 313 |
const grpoRes = await fetch(`${BASE}/grpo_reward`, { method: "POST", body: grpoForm });
|
| 314 |
const grpoData = await grpoRes.json();
|
| 315 |
setGrpoOutput(grpoData.report);
|
| 316 |
+
// Use hf_judge score for comparison when ground truth is provided
|
| 317 |
+
if (groundTruth && grpoData.breakdown) {
|
| 318 |
+
setHfJudgeGRPO(grpoData.breakdown.hf_judge ?? null);
|
| 319 |
+
}
|
| 320 |
|
| 321 |
} catch (err) {
|
| 322 |
console.error("Inference error:", err);
|
|
|
|
| 343 |
const clearAll = () => {
|
| 344 |
setImage(null); setImageFile(null);
|
| 345 |
setSftOutput(""); setRewardOutput(""); setGrpoOutput("");
|
| 346 |
+
setRewardScore(null); setHfJudgeSFT(null); setHfJudgeGRPO(null);
|
| 347 |
if (fileRef.current) fileRef.current.value = "";
|
| 348 |
};
|
| 349 |
|
| 350 |
+
// When ground truth is provided, use HF judge score from backend; otherwise fall back to ROUGE-L
|
| 351 |
+
const scoreSFT = groundTruth && sftOutput
|
| 352 |
+
? (hfJudgeSFT !== null ? hfJudgeSFT : ROUGE_L(sftOutput, groundTruth))
|
| 353 |
+
: null;
|
| 354 |
+
const scoreGRPO = groundTruth && grpoOutput
|
| 355 |
+
? (hfJudgeGRPO !== null ? hfJudgeGRPO : ROUGE_L(grpoOutput, groundTruth))
|
| 356 |
+
: null;
|
| 357 |
+
const scoreLabel = groundTruth ? (hfJudgeSFT !== null || hfJudgeGRPO !== null ? "HF Judge" : "ROUGE-L") : "ROUGE-L";
|
| 358 |
|
| 359 |
return (
|
| 360 |
<div style={{
|
|
|
|
| 545 |
<OutputCard
|
| 546 |
theme={theme} title="SFT Model Output " icon="π§ " accent="#3b82f6"
|
| 547 |
content={sftOutput} loading={loading && !sftOutput}
|
| 548 |
+
badge={scoreSFT !== null && <ScoreBadge score={scoreSFT} label={scoreLabel} />}
|
| 549 |
/>
|
| 550 |
</div>
|
| 551 |
|
|
|
|
| 588 |
<OutputCard
|
| 589 |
theme={theme} title="NeMo Gym Output" icon="π―" accent="#22c55e"
|
| 590 |
content={grpoOutput} loading={loading && rewardOutput && !grpoOutput}
|
| 591 |
+
badge={scoreGRPO !== null && <ScoreBadge score={scoreGRPO} label={scoreLabel} />}
|
| 592 |
/>
|
| 593 |
</div>
|
| 594 |
|
|
|
|
| 599 |
boxShadow: theme.cardShadow, transition: "background .3s, border-color .3s",
|
| 600 |
}}>
|
| 601 |
<div style={{ fontSize: 10, color: theme.textMuted, textTransform: "uppercase", letterSpacing: 2, fontWeight: 700, marginBottom: 16 }}>
|
| 602 |
+
π {scoreLabel} Comparison vs Ground Truth
|
| 603 |
</div>
|
| 604 |
<div style={{ display: "grid", gridTemplateColumns: "1fr 1fr", gap: 16 }}>
|
| 605 |
{[
|
| 606 |
+
{ label: "SFT (Original)", score: scoreSFT, color: "#3b82f6" },
|
| 607 |
+
{ label: "GRPO (Final)", score: scoreGRPO, color: "#22c55e" },
|
| 608 |
].map(({ label, score, color }) => {
|
| 609 |
+
const displayPct = (parseFloat(score) * 100).toFixed(1);
|
| 610 |
return (
|
| 611 |
<div key={label} style={{
|
| 612 |
background: theme.surfaceAlt, borderRadius: 10, padding: 16,
|
|
|
|
| 614 |
}}>
|
| 615 |
<div style={{ fontSize: 12, color: theme.textMuted, marginBottom: 8 }}>{label}</div>
|
| 616 |
<div style={{ fontSize: 28, fontWeight: 800, color, fontFamily: "monospace" }}>
|
| 617 |
+
{displayPct}%
|
| 618 |
</div>
|
| 619 |
<div style={{ marginTop: 10, height: 5, background: theme.barTrack, borderRadius: 3 }}>
|
| 620 |
<div style={{
|
| 621 |
+
width: `${displayPct}%`, height: "100%",
|
| 622 |
background: color, borderRadius: 3, transition: "width 1.2s ease",
|
| 623 |
}} />
|
| 624 |
</div>
|
| 625 |
+
{parseFloat(displayPct) > parseFloat(scoreSFT) * 100 && label.includes("GRPO") && (
|
| 626 |
<div style={{ fontSize: 11, color: "#22c55e", marginTop: 7, fontWeight: 600 }}>
|
| 627 |
+
β² +{(parseFloat(displayPct) - parseFloat(scoreSFT) * 100).toFixed(1)}% improvement
|
| 628 |
</div>
|
| 629 |
)}
|
| 630 |
</div>
|
|
|
|
| 682 |
</footer>
|
| 683 |
</div>
|
| 684 |
);
|
| 685 |
+
}
|