Spaces:
Sleeping
Sleeping
Use pruned 536M model with vocab mapping support
Browse files- Model shrunk from 722M to 536M via vocabulary pruning (255K→80K tokens)
- Add custom autoregressive generation (no HF .generate() dependency)
- Add vocab_mapping.json support for tokenizer→model ID translation
- Update model size references throughout
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- app.py +953 -520
- requirements.txt +2 -0
app.py
CHANGED
|
@@ -1,14 +1,16 @@
|
|
| 1 |
"""
|
| 2 |
-
Aetheris Playground
|
| 3 |
-
|
| 4 |
|
| 5 |
-
Aetheris is a ~
|
| 6 |
-
CohereLabs/tiny-aya-global (3.35B).
|
| 7 |
-
|
| 8 |
"""
|
| 9 |
|
| 10 |
import json
|
| 11 |
import os
|
|
|
|
|
|
|
| 12 |
from datetime import datetime
|
| 13 |
from pathlib import Path
|
| 14 |
from typing import Any
|
|
@@ -16,55 +18,126 @@ from typing import Any
|
|
| 16 |
import gradio as gr
|
| 17 |
import numpy as np
|
| 18 |
import pandas as pd
|
| 19 |
-
import plotly.express as px
|
| 20 |
import plotly.graph_objects as go
|
|
|
|
| 21 |
from plotly.subplots import make_subplots
|
| 22 |
|
| 23 |
# ---------------------------------------------------------------------------
|
| 24 |
-
# Constants
|
| 25 |
# ---------------------------------------------------------------------------
|
| 26 |
|
| 27 |
BENCHMARK_PATH = Path(__file__).parent / "benchmark_results.json"
|
| 28 |
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
"
|
| 32 |
-
"
|
| 33 |
-
"
|
| 34 |
-
"
|
| 35 |
-
"
|
| 36 |
-
"
|
| 37 |
-
"
|
| 38 |
-
"
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
}
|
| 41 |
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
"
|
| 47 |
-
"
|
| 48 |
-
"
|
| 49 |
-
"
|
| 50 |
-
"
|
| 51 |
-
"
|
| 52 |
-
"
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
|
| 66 |
def load_benchmarks() -> dict[str, Any]:
|
| 67 |
-
"""Load benchmark results from JSON file."""
|
| 68 |
if BENCHMARK_PATH.exists():
|
| 69 |
with open(BENCHMARK_PATH) as f:
|
| 70 |
return json.load(f)
|
|
@@ -74,268 +147,541 @@ def load_benchmarks() -> dict[str, Any]:
|
|
| 74 |
BENCHMARKS = load_benchmarks()
|
| 75 |
|
| 76 |
# ---------------------------------------------------------------------------
|
| 77 |
-
# Model
|
| 78 |
# ---------------------------------------------------------------------------
|
| 79 |
|
| 80 |
model = None
|
| 81 |
tokenizer = None
|
|
|
|
|
|
|
| 82 |
|
| 83 |
|
| 84 |
-
def load_model()
|
| 85 |
-
"""Load
|
| 86 |
-
global model, tokenizer
|
| 87 |
-
if
|
| 88 |
return model, tokenizer
|
| 89 |
|
|
|
|
| 90 |
try:
|
| 91 |
-
from
|
| 92 |
-
import
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
return model, tokenizer
|
|
|
|
| 103 |
except Exception as e:
|
| 104 |
-
print(f"
|
|
|
|
|
|
|
|
|
|
| 105 |
return None, None
|
| 106 |
|
| 107 |
|
| 108 |
# ---------------------------------------------------------------------------
|
| 109 |
-
#
|
| 110 |
# ---------------------------------------------------------------------------
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
m, tok = load_model()
|
| 124 |
|
|
|
|
| 125 |
if m is None or tok is None:
|
| 126 |
-
|
| 127 |
-
"
|
| 128 |
-
"
|
| 129 |
-
"
|
| 130 |
-
f"Your prompt ({CORE_LANGUAGES.get(language, language)}): {prompt}"
|
| 131 |
)
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
try:
|
| 134 |
import torch
|
| 135 |
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
-
generated = tok.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
|
| 151 |
-
return generated.strip()
|
| 152 |
except Exception as e:
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
|
| 156 |
# ---------------------------------------------------------------------------
|
| 157 |
-
# Benchmark visualizations
|
| 158 |
# ---------------------------------------------------------------------------
|
| 159 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
|
| 161 |
def make_model_info_table() -> pd.DataFrame:
|
| 162 |
-
"""Create side-by-side model comparison table."""
|
| 163 |
info = BENCHMARKS.get("model_info", {})
|
| 164 |
teacher = info.get("teacher", {})
|
| 165 |
student = info.get("student", {})
|
| 166 |
-
|
| 167 |
rows = [
|
| 168 |
-
("Organization",
|
|
|
|
| 169 |
("Parameters", f"{teacher.get('params_m', 3350)}M", f"{student.get('params_m', 800)}M"),
|
| 170 |
-
("Architecture", teacher.get("architecture", "Transformer
|
| 171 |
("Layers", str(teacher.get("layers", 36)), str(student.get("layers", 24))),
|
| 172 |
-
("Hidden Dim", str(teacher.get("hidden_dim", 2048)), str(student.get("hidden_dim",
|
| 173 |
-
("Attention",
|
| 174 |
("Experts", "N/A (dense)", f"{student.get('num_experts', 4)} (top-{student.get('top_k', 1)})"),
|
| 175 |
-
("Vocab
|
| 176 |
-
("
|
| 177 |
-
("Languages", str(teacher.get("languages", 101)), str(student.get("languages", 101))),
|
| 178 |
("Compression", "1.0x (baseline)", f"{student.get('compression_ratio', 4.2)}x"),
|
| 179 |
]
|
| 180 |
-
|
| 181 |
-
return pd.DataFrame(rows, columns=["Metric", "Teacher (tiny-aya-global)", "Student (Aetheris)"])
|
| 182 |
|
| 183 |
|
| 184 |
def make_benchmark_chart(benchmark_key: str, title: str) -> go.Figure:
|
| 185 |
-
"""Create grouped bar chart for a benchmark (mGSM or XCOPA)."""
|
| 186 |
data = BENCHMARKS.get(benchmark_key, {})
|
| 187 |
langs = data.get("languages", {})
|
| 188 |
avg = data.get("average", {})
|
| 189 |
|
| 190 |
lang_codes = list(langs.keys())
|
| 191 |
-
lang_names = [
|
| 192 |
teacher_scores = [langs[lc]["teacher"] for lc in lang_codes]
|
| 193 |
student_scores = [langs[lc]["student"] for lc in lang_codes]
|
| 194 |
|
| 195 |
-
|
| 196 |
-
lang_names.append("AVERAGE")
|
| 197 |
teacher_scores.append(avg.get("teacher", 0))
|
| 198 |
student_scores.append(avg.get("student", 0))
|
| 199 |
|
| 200 |
fig = go.Figure()
|
| 201 |
fig.add_trace(go.Bar(
|
| 202 |
name="Teacher (tiny-aya-global)",
|
| 203 |
-
x=lang_names,
|
| 204 |
-
|
| 205 |
-
marker_color=TEACHER_COLOR,
|
| 206 |
text=[f"{s:.1f}" for s in teacher_scores],
|
| 207 |
-
textposition="outside",
|
| 208 |
-
textfont=dict(size=10),
|
| 209 |
))
|
| 210 |
fig.add_trace(go.Bar(
|
| 211 |
name="Student (Aetheris)",
|
| 212 |
-
x=lang_names,
|
| 213 |
-
|
| 214 |
-
marker_color=STUDENT_COLOR,
|
| 215 |
text=[f"{s:.1f}" for s in student_scores],
|
| 216 |
-
textposition="outside",
|
| 217 |
-
textfont=dict(size=10),
|
| 218 |
))
|
| 219 |
|
| 220 |
-
metric = data.get("metric", "accuracy (%)")
|
| 221 |
fig.update_layout(
|
| 222 |
-
|
|
|
|
| 223 |
xaxis_title="Language",
|
| 224 |
-
yaxis_title=metric,
|
| 225 |
barmode="group",
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
margin=dict(t=80, b=60),
|
| 230 |
)
|
| 231 |
-
|
| 232 |
return fig
|
| 233 |
|
| 234 |
|
| 235 |
def make_retention_chart() -> go.Figure:
|
| 236 |
-
"""Create a retention heatmap showing student/teacher ratio per benchmark per language."""
|
| 237 |
mgsm = BENCHMARKS.get("mgsm", {}).get("languages", {})
|
| 238 |
xcopa = BENCHMARKS.get("xcopa", {}).get("languages", {})
|
| 239 |
|
| 240 |
lang_codes = list(mgsm.keys())
|
| 241 |
-
lang_names = [
|
| 242 |
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
for lc in lang_codes
|
| 246 |
-
]
|
| 247 |
-
xcopa_retention = [
|
| 248 |
-
xcopa[lc]["student"] / xcopa[lc]["teacher"] * 100 if lc in xcopa and xcopa[lc]["teacher"] > 0 else 0
|
| 249 |
-
for lc in lang_codes
|
| 250 |
-
]
|
| 251 |
|
| 252 |
fig = go.Figure()
|
| 253 |
-
fig.add_trace(go.Scatter(
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
textposition="top center",
|
| 261 |
-
textfont=dict(size=9),
|
| 262 |
-
))
|
| 263 |
-
fig.add_trace(go.Scatter(
|
| 264 |
-
x=lang_names, y=xcopa_retention,
|
| 265 |
-
mode="lines+markers+text",
|
| 266 |
-
name="XCOPA Retention",
|
| 267 |
-
line=dict(color=STUDENT_COLOR, width=2),
|
| 268 |
-
marker=dict(size=8),
|
| 269 |
-
text=[f"{r:.0f}%" for r in xcopa_retention],
|
| 270 |
-
textposition="bottom center",
|
| 271 |
-
textfont=dict(size=9),
|
| 272 |
-
))
|
| 273 |
-
|
| 274 |
-
# Reference line at 80% retention
|
| 275 |
-
fig.add_hline(y=80, line_dash="dash", line_color="#f59e0b",
|
| 276 |
annotation_text="80% retention target", annotation_position="top right")
|
| 277 |
|
| 278 |
-
fig.update_layout(
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
template=PLOTLY_TEMPLATE,
|
| 283 |
-
height=400,
|
| 284 |
-
yaxis=dict(range=[60, 105]),
|
| 285 |
-
legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
|
| 286 |
-
)
|
| 287 |
-
|
| 288 |
return fig
|
| 289 |
|
| 290 |
|
| 291 |
def make_throughput_chart() -> go.Figure:
|
| 292 |
-
"""Create throughput comparison bar charts."""
|
| 293 |
tp = BENCHMARKS.get("throughput", {})
|
| 294 |
teacher = tp.get("teacher", {})
|
| 295 |
student = tp.get("student", {})
|
| 296 |
|
| 297 |
-
metrics = ["Tokens/sec", "TTFT (ms)", "Memory (MB)"
|
| 298 |
-
teacher_vals = [
|
| 299 |
-
|
| 300 |
-
teacher.get("ttft_ms", 0),
|
| 301 |
-
teacher.get("memory_mb", 0),
|
| 302 |
-
teacher.get("latency_p99_ms", 0),
|
| 303 |
-
]
|
| 304 |
-
student_vals = [
|
| 305 |
-
student.get("tokens_per_sec", 0),
|
| 306 |
-
student.get("ttft_ms", 0),
|
| 307 |
-
student.get("memory_mb", 0),
|
| 308 |
-
student.get("latency_p99_ms", 0),
|
| 309 |
-
]
|
| 310 |
-
|
| 311 |
-
fig = make_subplots(
|
| 312 |
-
rows=1, cols=4,
|
| 313 |
-
subplot_titles=metrics,
|
| 314 |
-
horizontal_spacing=0.08,
|
| 315 |
-
)
|
| 316 |
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
fig.add_trace(go.Bar(
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
text=[str(tv), str(sv)],
|
| 323 |
-
textposition="outside",
|
| 324 |
-
showlegend=False,
|
| 325 |
-
), row=1, col=col)
|
| 326 |
-
|
| 327 |
-
fig.update_layout(
|
| 328 |
-
title=dict(text="Inference Performance (batch=1, seq_len=256)", font=dict(size=18)),
|
| 329 |
-
template=PLOTLY_TEMPLATE,
|
| 330 |
-
height=380,
|
| 331 |
-
margin=dict(t=80, b=40),
|
| 332 |
-
)
|
| 333 |
|
|
|
|
|
|
|
| 334 |
return fig
|
| 335 |
|
| 336 |
|
| 337 |
def make_equity_chart() -> go.Figure:
|
| 338 |
-
"""Create multilingual equity visualization."""
|
| 339 |
equity = BENCHMARKS.get("equity", {})
|
| 340 |
families = equity.get("language_families", {})
|
| 341 |
|
|
@@ -343,405 +689,492 @@ def make_equity_chart() -> go.Figure:
|
|
| 343 |
des_vals = [families[n]["des"] for n in names]
|
| 344 |
retention_vals = [families[n]["avg_retention"] * 100 for n in names]
|
| 345 |
|
| 346 |
-
fig = make_subplots(
|
| 347 |
-
rows=1, cols=2,
|
| 348 |
-
subplot_titles=[
|
| 349 |
-
"Degradation Equity Score by Family (lower = more equitable)",
|
| 350 |
-
"Average Quality Retention by Family (%)"
|
| 351 |
-
],
|
| 352 |
-
horizontal_spacing=0.12,
|
| 353 |
-
)
|
| 354 |
-
|
| 355 |
-
# Sort by DES for the left chart
|
| 356 |
sorted_idx = np.argsort(des_vals)[::-1]
|
| 357 |
sorted_names = [names[i] for i in sorted_idx]
|
| 358 |
sorted_des = [des_vals[i] for i in sorted_idx]
|
| 359 |
|
| 360 |
-
colors = [
|
| 361 |
-
"#ef4444" if d > 0.4 else "#f59e0b" if d > 0.3 else STUDENT_COLOR
|
| 362 |
-
for d in sorted_des
|
| 363 |
-
]
|
| 364 |
-
|
| 365 |
-
fig.add_trace(go.Bar(
|
| 366 |
-
x=sorted_des, y=sorted_names,
|
| 367 |
-
orientation="h",
|
| 368 |
-
marker_color=colors,
|
| 369 |
-
text=[f"{d:.2f}" for d in sorted_des],
|
| 370 |
-
textposition="outside",
|
| 371 |
-
showlegend=False,
|
| 372 |
-
), row=1, col=1)
|
| 373 |
-
|
| 374 |
-
# Sort by retention for the right chart
|
| 375 |
-
sorted_idx_r = np.argsort(retention_vals)
|
| 376 |
-
sorted_names_r = [names[i] for i in sorted_idx_r]
|
| 377 |
-
sorted_ret = [retention_vals[i] for i in sorted_idx_r]
|
| 378 |
-
|
| 379 |
-
colors_r = [
|
| 380 |
-
"#ef4444" if r < 77 else "#f59e0b" if r < 82 else STUDENT_COLOR
|
| 381 |
-
for r in sorted_ret
|
| 382 |
-
]
|
| 383 |
-
|
| 384 |
-
fig.add_trace(go.Bar(
|
| 385 |
-
x=sorted_ret, y=sorted_names_r,
|
| 386 |
-
orientation="h",
|
| 387 |
-
marker_color=colors_r,
|
| 388 |
-
text=[f"{r:.0f}%" for r in sorted_ret],
|
| 389 |
-
textposition="outside",
|
| 390 |
-
showlegend=False,
|
| 391 |
-
), row=1, col=2)
|
| 392 |
-
|
| 393 |
-
fig.update_layout(
|
| 394 |
-
template=PLOTLY_TEMPLATE,
|
| 395 |
-
height=400,
|
| 396 |
-
margin=dict(t=60, b=40, l=120),
|
| 397 |
-
)
|
| 398 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 399 |
return fig
|
| 400 |
|
| 401 |
|
| 402 |
def make_training_chart() -> go.Figure:
|
| 403 |
-
"""Create training progress curves."""
|
| 404 |
progress = BENCHMARKS.get("training_progress", {})
|
| 405 |
steps = progress.get("steps", [])
|
| 406 |
kl = progress.get("kl_divergence", [])
|
| 407 |
cka = progress.get("cka_alignment", [])
|
| 408 |
loss = progress.get("total_loss", [])
|
| 409 |
-
des = progress.get("des_score", [])
|
| 410 |
-
|
| 411 |
-
fig = make_subplots(
|
| 412 |
-
rows=2, cols=2,
|
| 413 |
-
subplot_titles=[
|
| 414 |
-
"KL Divergence (lower = better)",
|
| 415 |
-
"CKA Alignment (higher = better)",
|
| 416 |
-
"Total Loss",
|
| 417 |
-
"DES Score (lower = more equitable)",
|
| 418 |
-
],
|
| 419 |
-
vertical_spacing=0.15,
|
| 420 |
-
horizontal_spacing=0.1,
|
| 421 |
-
)
|
| 422 |
|
| 423 |
-
fig
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
fig.add_trace(go.Scatter(
|
| 430 |
-
|
| 431 |
-
line=dict(color=STUDENT_COLOR, width=2),
|
| 432 |
-
marker=dict(size=5), showlegend=False,
|
| 433 |
-
), row=1, col=2)
|
| 434 |
-
|
| 435 |
-
fig.add_trace(go.Scatter(
|
| 436 |
-
x=steps, y=loss, mode="lines+markers",
|
| 437 |
-
line=dict(color=ACCENT_COLOR_2, width=2),
|
| 438 |
-
marker=dict(size=5), showlegend=False,
|
| 439 |
-
), row=2, col=1)
|
| 440 |
-
|
| 441 |
-
fig.add_trace(go.Scatter(
|
| 442 |
-
x=steps, y=des, mode="lines+markers",
|
| 443 |
-
line=dict(color="#f59e0b", width=2),
|
| 444 |
-
marker=dict(size=5), showlegend=False,
|
| 445 |
-
), row=2, col=2)
|
| 446 |
-
|
| 447 |
-
# Add target lines
|
| 448 |
-
fig.add_hline(y=0.75, line_dash="dash", line_color="#4ade80",
|
| 449 |
-
annotation_text="target", row=1, col=2)
|
| 450 |
-
fig.add_hline(y=0.30, line_dash="dash", line_color="#4ade80",
|
| 451 |
-
annotation_text="target", row=2, col=2)
|
| 452 |
-
|
| 453 |
-
fig.update_xaxes(title_text="Training Steps", row=2, col=1)
|
| 454 |
-
fig.update_xaxes(title_text="Training Steps", row=2, col=2)
|
| 455 |
|
| 456 |
-
fig.
|
| 457 |
-
template=PLOTLY_TEMPLATE,
|
| 458 |
-
height=550,
|
| 459 |
-
margin=dict(t=60, b=60),
|
| 460 |
-
)
|
| 461 |
|
|
|
|
|
|
|
|
|
|
| 462 |
return fig
|
| 463 |
|
| 464 |
|
| 465 |
-
def make_benchmark_table(benchmark_key: str) -> pd.DataFrame:
|
| 466 |
-
"""Create a results table for mGSM or XCOPA."""
|
| 467 |
-
data = BENCHMARKS.get(benchmark_key, {})
|
| 468 |
-
langs = data.get("languages", {})
|
| 469 |
-
avg = data.get("average", {})
|
| 470 |
-
|
| 471 |
-
rows = []
|
| 472 |
-
for lc, scores in langs.items():
|
| 473 |
-
t = scores["teacher"]
|
| 474 |
-
s = scores["student"]
|
| 475 |
-
retention = (s / t * 100) if t > 0 else 0
|
| 476 |
-
rows.append({
|
| 477 |
-
"Language": f"{CORE_LANGUAGES.get(lc, lc)} ({lc})",
|
| 478 |
-
"Teacher": f"{t:.1f}",
|
| 479 |
-
"Student": f"{s:.1f}",
|
| 480 |
-
"Retention": f"{retention:.0f}%",
|
| 481 |
-
"Gap": f"{t - s:+.1f}",
|
| 482 |
-
})
|
| 483 |
-
|
| 484 |
-
# Average row
|
| 485 |
-
t_avg = avg.get("teacher", 0)
|
| 486 |
-
s_avg = avg.get("student", 0)
|
| 487 |
-
ret_avg = (s_avg / t_avg * 100) if t_avg > 0 else 0
|
| 488 |
-
rows.append({
|
| 489 |
-
"Language": "AVERAGE",
|
| 490 |
-
"Teacher": f"{t_avg:.1f}",
|
| 491 |
-
"Student": f"{s_avg:.1f}",
|
| 492 |
-
"Retention": f"{ret_avg:.0f}%",
|
| 493 |
-
"Gap": f"{t_avg - s_avg:+.1f}",
|
| 494 |
-
})
|
| 495 |
-
|
| 496 |
-
return pd.DataFrame(rows)
|
| 497 |
-
|
| 498 |
-
|
| 499 |
# ---------------------------------------------------------------------------
|
| 500 |
-
#
|
| 501 |
# ---------------------------------------------------------------------------
|
| 502 |
|
| 503 |
CUSTOM_CSS = """
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
|
|
|
|
|
|
| 509 |
}
|
| 510 |
-
|
| 511 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 512 |
margin: 0 0 4px 0 !important;
|
| 513 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 514 |
}
|
| 515 |
-
.
|
| 516 |
-
|
| 517 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 518 |
}
|
| 519 |
-
.stat-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
border-
|
| 523 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 524 |
text-align: center;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 525 |
}
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
}
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 535 |
}
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 543 |
}
|
| 544 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 545 |
"""
|
| 546 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 547 |
theme = gr.themes.Soft(
|
| 548 |
-
primary_hue="
|
| 549 |
-
secondary_hue="
|
| 550 |
neutral_hue="slate",
|
| 551 |
font=gr.themes.GoogleFont("Inter"),
|
| 552 |
).set(
|
| 553 |
-
body_background_fill="#
|
| 554 |
-
body_background_fill_dark="#
|
| 555 |
-
block_background_fill="#
|
| 556 |
-
block_background_fill_dark="#
|
| 557 |
-
block_border_color="
|
| 558 |
-
block_border_color_dark="
|
| 559 |
-
input_background_fill="#
|
| 560 |
-
input_background_fill_dark="#
|
| 561 |
-
button_primary_background_fill="#
|
| 562 |
-
button_primary_background_fill_dark="#
|
| 563 |
-
button_primary_background_fill_hover="#
|
| 564 |
-
button_primary_background_fill_hover_dark="#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 565 |
)
|
| 566 |
|
| 567 |
last_updated = BENCHMARKS.get("metadata", {}).get("last_updated", "N/A")
|
| 568 |
|
| 569 |
-
with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="Aetheris
|
| 570 |
|
| 571 |
-
#
|
| 572 |
gr.HTML("""
|
| 573 |
-
<div class="
|
| 574 |
-
<
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 578 |
</p>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 579 |
</div>
|
| 580 |
""")
|
| 581 |
|
| 582 |
-
#
|
| 583 |
-
with gr.
|
| 584 |
-
|
| 585 |
-
gr.HTML("""
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
gr.HTML(
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
|
|
|
|
|
|
| 594 |
gr.Markdown(
|
| 595 |
-
"
|
| 596 |
-
"
|
| 597 |
-
"Currently running CohereLabs/tiny-aya-global as a demo.*"
|
| 598 |
)
|
| 599 |
|
| 600 |
with gr.Row():
|
| 601 |
-
with gr.Column(scale=
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
|
|
|
|
|
|
|
|
|
|
| 606 |
)
|
| 607 |
with gr.Row():
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
)
|
| 614 |
-
max_tokens_slider = gr.Slider(
|
| 615 |
-
minimum=32, maximum=512, value=128, step=16,
|
| 616 |
-
label="Max Tokens",
|
| 617 |
-
scale=2,
|
| 618 |
-
)
|
| 619 |
-
temp_slider = gr.Slider(
|
| 620 |
-
minimum=0.0, maximum=1.0, value=0.7, step=0.05,
|
| 621 |
-
label="Temperature",
|
| 622 |
-
scale=2,
|
| 623 |
)
|
|
|
|
| 624 |
|
| 625 |
-
with gr.
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
output_text = gr.Textbox(
|
| 631 |
-
label="Generated Text",
|
| 632 |
-
lines=12,
|
| 633 |
-
interactive=False,
|
| 634 |
-
)
|
| 635 |
|
| 636 |
-
|
| 637 |
-
|
| 638 |
-
|
| 639 |
-
|
| 640 |
-
|
| 641 |
-
|
| 642 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 643 |
)
|
|
|
|
| 644 |
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 650 |
|
| 651 |
-
# ======
|
| 652 |
-
# TAB 2: Benchmark Dashboard
|
| 653 |
-
# ===================================================================
|
| 654 |
with gr.Tab("Benchmarks", id="benchmarks"):
|
| 655 |
|
| 656 |
gr.HTML(f"""
|
| 657 |
-
<div class="
|
| 658 |
-
|
| 659 |
-
Real
|
| 660 |
-
|
| 661 |
</div>
|
| 662 |
""")
|
| 663 |
|
| 664 |
-
#
|
| 665 |
-
gr.
|
|
|
|
| 666 |
model_table = make_model_info_table()
|
| 667 |
-
gr.Dataframe(
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
)
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
gr.
|
| 677 |
-
|
| 678 |
-
)
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
# --- XCOPA ---
|
| 687 |
-
gr.Markdown("## XCOPA: Causal Reasoning")
|
| 688 |
-
gr.Markdown(
|
| 689 |
-
"*Cross-lingual Choice of Plausible Alternatives -- tests causal and commonsense reasoning.*"
|
| 690 |
-
)
|
| 691 |
-
with gr.Row():
|
| 692 |
-
with gr.Column(scale=3):
|
| 693 |
-
gr.Plot(value=make_benchmark_chart("xcopa", "XCOPA Accuracy by Language"))
|
| 694 |
-
with gr.Column(scale=2):
|
| 695 |
-
xcopa_table = make_benchmark_table("xcopa")
|
| 696 |
-
gr.Dataframe(value=xcopa_table, interactive=False, wrap=True)
|
| 697 |
-
|
| 698 |
-
# --- Quality Retention ---
|
| 699 |
-
gr.Markdown("## Quality Retention Across Languages")
|
| 700 |
-
gr.Markdown(
|
| 701 |
-
"*Percentage of teacher performance retained by the student at 4.2x compression. "
|
| 702 |
-
"Target: >80% retention across all languages.*"
|
| 703 |
-
)
|
| 704 |
gr.Plot(value=make_retention_chart())
|
| 705 |
|
| 706 |
-
#
|
| 707 |
-
gr.
|
| 708 |
-
gr.Markdown(
|
| 709 |
-
|
| 710 |
-
"Student achieves ~3x throughput with 4x memory reduction.*"
|
| 711 |
-
)
|
| 712 |
gr.Plot(value=make_throughput_chart())
|
| 713 |
|
| 714 |
-
#
|
| 715 |
-
gr.
|
|
|
|
| 716 |
gr.Markdown(
|
| 717 |
-
"*DES
|
| 718 |
-
"
|
| 719 |
-
"Red
|
| 720 |
)
|
| 721 |
gr.Plot(value=make_equity_chart())
|
| 722 |
|
| 723 |
equity_data = BENCHMARKS.get("equity", {})
|
| 724 |
high_risk = equity_data.get("high_risk_languages", [])
|
| 725 |
-
|
| 726 |
-
f"**High-risk languages** (
|
| 727 |
-
f"`{'`, `'.join(high_risk)}`"
|
| 728 |
-
)
|
| 729 |
|
| 730 |
-
#
|
| 731 |
-
gr.
|
| 732 |
-
gr.Markdown(
|
| 733 |
-
|
| 734 |
-
"alignment between student and teacher layers. Dashed lines indicate targets.*"
|
| 735 |
-
)
|
| 736 |
gr.Plot(value=make_training_chart())
|
| 737 |
|
| 738 |
-
|
| 739 |
-
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
|
| 744 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 745 |
|
| 746 |
|
| 747 |
if __name__ == "__main__":
|
|
|
|
| 1 |
"""
|
| 2 |
+
Aetheris Playground — CohereLabs x Wayy Research
|
| 3 |
+
A polyglot chat interface and benchmark dashboard for the Aetheris model.
|
| 4 |
|
| 5 |
+
Aetheris is a ~536M Hybrid Mamba-MoE multilingual model distilled from
|
| 6 |
+
CohereLabs/tiny-aya-global (3.35B). Chat in any of 67 languages and
|
| 7 |
+
the model responds in kind.
|
| 8 |
"""
|
| 9 |
|
| 10 |
import json
|
| 11 |
import os
|
| 12 |
+
import re
|
| 13 |
+
import sys
|
| 14 |
from datetime import datetime
|
| 15 |
from pathlib import Path
|
| 16 |
from typing import Any
|
|
|
|
| 18 |
import gradio as gr
|
| 19 |
import numpy as np
|
| 20 |
import pandas as pd
|
|
|
|
| 21 |
import plotly.graph_objects as go
|
| 22 |
+
import torch
|
| 23 |
from plotly.subplots import make_subplots
|
| 24 |
|
| 25 |
# ---------------------------------------------------------------------------
|
| 26 |
+
# Constants & Data
|
| 27 |
# ---------------------------------------------------------------------------
|
| 28 |
|
| 29 |
BENCHMARK_PATH = Path(__file__).parent / "benchmark_results.json"
|
| 30 |
|
| 31 |
+
LANGUAGES = {
|
| 32 |
+
# ── Europe: Romance ──
|
| 33 |
+
"en": {"name": "English", "family": "Indo-European", "region": "Europe", "greeting": "Hello!", "flag": "🇬🇧"},
|
| 34 |
+
"fr": {"name": "French", "family": "Indo-European", "region": "Europe", "greeting": "Bonjour!", "flag": "🇫🇷"},
|
| 35 |
+
"es": {"name": "Spanish", "family": "Indo-European", "region": "Europe", "greeting": "¡Hola!", "flag": "🇪🇸"},
|
| 36 |
+
"pt": {"name": "Portuguese", "family": "Indo-European", "region": "Europe", "greeting": "Olá!", "flag": "🇧🇷"},
|
| 37 |
+
"it": {"name": "Italian", "family": "Indo-European", "region": "Europe", "greeting": "Ciao!", "flag": "🇮🇹"},
|
| 38 |
+
"ro": {"name": "Romanian", "family": "Indo-European", "region": "Europe", "greeting": "Bună!", "flag": "🇷🇴"},
|
| 39 |
+
"ca": {"name": "Catalan", "family": "Indo-European", "region": "Europe", "greeting": "Hola!", "flag": "🇪🇸"},
|
| 40 |
+
"gl": {"name": "Galician", "family": "Indo-European", "region": "Europe", "greeting": "Ola!", "flag": "🇪🇸"},
|
| 41 |
+
# ── Europe: Germanic ──
|
| 42 |
+
"de": {"name": "German", "family": "Indo-European", "region": "Europe", "greeting": "Hallo!", "flag": "🇩🇪"},
|
| 43 |
+
"nl": {"name": "Dutch", "family": "Indo-European", "region": "Europe", "greeting": "Hallo!", "flag": "🇳🇱"},
|
| 44 |
+
"da": {"name": "Danish", "family": "Indo-European", "region": "Europe", "greeting": "Hej!", "flag": "🇩🇰"},
|
| 45 |
+
"sv": {"name": "Swedish", "family": "Indo-European", "region": "Europe", "greeting": "Hej!", "flag": "🇸🇪"},
|
| 46 |
+
"no": {"name": "Norwegian", "family": "Indo-European", "region": "Europe", "greeting": "Hei!", "flag": "🇳🇴"},
|
| 47 |
+
# ── Europe: Slavic ──
|
| 48 |
+
"ru": {"name": "Russian", "family": "Indo-European", "region": "Europe", "greeting": "Привет!", "flag": "🇷🇺"},
|
| 49 |
+
"uk": {"name": "Ukrainian", "family": "Indo-European", "region": "Europe", "greeting": "Привіт!", "flag": "🇺🇦"},
|
| 50 |
+
"pl": {"name": "Polish", "family": "Indo-European", "region": "Europe", "greeting": "Cześć!", "flag": "🇵🇱"},
|
| 51 |
+
"cs": {"name": "Czech", "family": "Indo-European", "region": "Europe", "greeting": "Ahoj!", "flag": "🇨🇿"},
|
| 52 |
+
"sk": {"name": "Slovak", "family": "Indo-European", "region": "Europe", "greeting": "Ahoj!", "flag": "🇸🇰"},
|
| 53 |
+
"hr": {"name": "Croatian", "family": "Indo-European", "region": "Europe", "greeting": "Bok!", "flag": "🇭🇷"},
|
| 54 |
+
"sr": {"name": "Serbian", "family": "Indo-European", "region": "Europe", "greeting": "Здраво!", "flag": "🇷🇸"},
|
| 55 |
+
"sl": {"name": "Slovenian", "family": "Indo-European", "region": "Europe", "greeting": "Živjo!", "flag": "🇸🇮"},
|
| 56 |
+
"bg": {"name": "Bulgarian", "family": "Indo-European", "region": "Europe", "greeting": "Здравей!", "flag": "🇧🇬"},
|
| 57 |
+
# ── Europe: Baltic ──
|
| 58 |
+
"lv": {"name": "Latvian", "family": "Indo-European", "region": "Europe", "greeting": "Sveiki!", "flag": "🇱🇻"},
|
| 59 |
+
"lt": {"name": "Lithuanian", "family": "Indo-European", "region": "Europe", "greeting": "Labas!", "flag": "🇱🇹"},
|
| 60 |
+
# ── Europe: Other ──
|
| 61 |
+
"el": {"name": "Greek", "family": "Indo-European", "region": "Europe", "greeting": "Γεια σου!", "flag": "🇬🇷"},
|
| 62 |
+
"et": {"name": "Estonian", "family": "Uralic", "region": "Europe", "greeting": "Tere!", "flag": "🇪🇪"},
|
| 63 |
+
"fi": {"name": "Finnish", "family": "Uralic", "region": "Europe", "greeting": "Hei!", "flag": "🇫🇮"},
|
| 64 |
+
"hu": {"name": "Hungarian", "family": "Uralic", "region": "Europe", "greeting": "Szia!", "flag": "🇭🇺"},
|
| 65 |
+
"eu": {"name": "Basque", "family": "Language isolate", "region": "Europe", "greeting": "Kaixo!", "flag": "🇪🇸"},
|
| 66 |
+
"cy": {"name": "Welsh", "family": "Indo-European", "region": "Europe", "greeting": "Helo!", "flag": "🏴\U000E0067\U000E0062\U000E0077\U000E006C\U000E0073\U000E007F"},
|
| 67 |
+
"ga": {"name": "Irish", "family": "Indo-European", "region": "Europe", "greeting": "Dia duit!", "flag": "🇮🇪"},
|
| 68 |
+
"mt": {"name": "Maltese", "family": "Afroasiatic", "region": "Europe", "greeting": "Merħba!", "flag": "🇲🇹"},
|
| 69 |
+
# ── Middle East ──
|
| 70 |
+
"ar": {"name": "Arabic", "family": "Afroasiatic", "region": "Middle East", "greeting": "مرحبا!", "flag": "🇸🇦"},
|
| 71 |
+
"fa": {"name": "Persian", "family": "Indo-European", "region": "Middle East", "greeting": "سلام!", "flag": "🇮🇷"},
|
| 72 |
+
"he": {"name": "Hebrew", "family": "Afroasiatic", "region": "Middle East", "greeting": "!שלום", "flag": "🇮🇱"},
|
| 73 |
+
"tr": {"name": "Turkish", "family": "Turkic", "region": "Middle East", "greeting": "Merhaba!", "flag": "🇹🇷"},
|
| 74 |
+
# ── South Asia ──
|
| 75 |
+
"hi": {"name": "Hindi", "family": "Indo-European", "region": "South Asia", "greeting": "नमस्ते!", "flag": "🇮🇳"},
|
| 76 |
+
"ur": {"name": "Urdu", "family": "Indo-European", "region": "South Asia", "greeting": "!ہیلو", "flag": "🇵🇰"},
|
| 77 |
+
"bn": {"name": "Bengali", "family": "Indo-European", "region": "South Asia", "greeting": "হ্যালো!", "flag": "🇧🇩"},
|
| 78 |
+
"mr": {"name": "Marathi", "family": "Indo-European", "region": "South Asia", "greeting": "नमस्कार!", "flag": "🇮🇳"},
|
| 79 |
+
"gu": {"name": "Gujarati", "family": "Indo-European", "region": "South Asia", "greeting": "નમસ્તે!", "flag": "🇮🇳"},
|
| 80 |
+
"pa": {"name": "Punjabi", "family": "Indo-European", "region": "South Asia", "greeting": "ਸਤ ਸ੍ਰੀ ਅਕਾਲ!", "flag": "🇮🇳"},
|
| 81 |
+
"ne": {"name": "Nepali", "family": "Indo-European", "region": "South Asia", "greeting": "नमस्ते!", "flag": "🇳🇵"},
|
| 82 |
+
"ta": {"name": "Tamil", "family": "Dravidian", "region": "South Asia", "greeting": "வணக்கம்!", "flag": "🇮🇳"},
|
| 83 |
+
"te": {"name": "Telugu", "family": "Dravidian", "region": "South Asia", "greeting": "హలో!", "flag": "🇮🇳"},
|
| 84 |
+
# ── East Asia ──
|
| 85 |
+
"zh": {"name": "Chinese", "family": "Sino-Tibetan", "region": "East Asia", "greeting": "你好!", "flag": "🇨🇳"},
|
| 86 |
+
"ja": {"name": "Japanese", "family": "Japonic", "region": "East Asia", "greeting": "こんにちは!", "flag": "🇯🇵"},
|
| 87 |
+
"ko": {"name": "Korean", "family": "Koreanic", "region": "East Asia", "greeting": "안녕하세요!", "flag": "🇰🇷"},
|
| 88 |
+
# ── Southeast Asia ──
|
| 89 |
+
"id": {"name": "Indonesian", "family": "Austronesian", "region": "Southeast Asia", "greeting": "Halo!", "flag": "🇮🇩"},
|
| 90 |
+
"ms": {"name": "Malay", "family": "Austronesian", "region": "Southeast Asia", "greeting": "Hai!", "flag": "🇲🇾"},
|
| 91 |
+
"tl": {"name": "Tagalog", "family": "Austronesian", "region": "Southeast Asia", "greeting": "Kamusta!", "flag": "🇵🇭"},
|
| 92 |
+
"jv": {"name": "Javanese", "family": "Austronesian", "region": "Southeast Asia", "greeting": "Halo!", "flag": "🇮🇩"},
|
| 93 |
+
"vi": {"name": "Vietnamese", "family": "Austroasiatic", "region": "Southeast Asia", "greeting": "Xin chào!", "flag": "🇻🇳"},
|
| 94 |
+
"km": {"name": "Khmer", "family": "Austroasiatic", "region": "Southeast Asia", "greeting": "សួស្តី!", "flag": "🇰🇭"},
|
| 95 |
+
"th": {"name": "Thai", "family": "Kra-Dai", "region": "Southeast Asia", "greeting": "สวัสดี!", "flag": "🇹🇭"},
|
| 96 |
+
"lo": {"name": "Lao", "family": "Kra-Dai", "region": "Southeast Asia", "greeting": "ສະບາຍດີ!", "flag": "🇱🇦"},
|
| 97 |
+
"my": {"name": "Burmese", "family": "Sino-Tibetan", "region": "Southeast Asia", "greeting": "မင်္ဂလာပါ!", "flag": "🇲🇲"},
|
| 98 |
+
# ── Africa ──
|
| 99 |
+
"am": {"name": "Amharic", "family": "Afroasiatic", "region": "Africa", "greeting": "ሰላም!", "flag": "🇪🇹"},
|
| 100 |
+
"ha": {"name": "Hausa", "family": "Afroasiatic", "region": "Africa", "greeting": "Sannu!", "flag": "🇳🇬"},
|
| 101 |
+
"sw": {"name": "Swahili", "family": "Niger-Congo", "region": "Africa", "greeting": "Habari!", "flag": "🇰🇪"},
|
| 102 |
+
"ig": {"name": "Igbo", "family": "Niger-Congo", "region": "Africa", "greeting": "Nnọọ!", "flag": "🇳🇬"},
|
| 103 |
+
"yo": {"name": "Yoruba", "family": "Niger-Congo", "region": "Africa", "greeting": "Ẹ kú!", "flag": "🇳🇬"},
|
| 104 |
+
"zu": {"name": "Zulu", "family": "Niger-Congo", "region": "Africa", "greeting": "Sawubona!", "flag": "🇿🇦"},
|
| 105 |
+
"xh": {"name": "Xhosa", "family": "Niger-Congo", "region": "Africa", "greeting": "Molo!", "flag": "🇿🇦"},
|
| 106 |
+
"sn": {"name": "Shona", "family": "Niger-Congo", "region": "Africa", "greeting": "Mhoro!", "flag": "🇿🇼"},
|
| 107 |
+
"wo": {"name": "Wolof", "family": "Niger-Congo", "region": "Africa", "greeting": "Nanga def!", "flag": "🇸🇳"},
|
| 108 |
+
"mg": {"name": "Malagasy", "family": "Austronesian", "region": "Africa", "greeting": "Manao ahoana!", "flag": "🇲🇬"},
|
| 109 |
}
|
| 110 |
|
| 111 |
+
# Region display order
|
| 112 |
+
REGION_ORDER = ["Europe", "Middle East", "South Asia", "East Asia", "Southeast Asia", "Africa"]
|
| 113 |
+
|
| 114 |
+
POLYGLOT_EXAMPLES = [
|
| 115 |
+
["Tell me about the history of mathematics", "en"],
|
| 116 |
+
["Explica la teoría de la relatividad en términos sencillos", "es"],
|
| 117 |
+
["गुरुत्वाकर्षण क्या है? सरल शब्दों में समझाइए।", "hi"],
|
| 118 |
+
["量子计算的基本原理是什么?", "zh"],
|
| 119 |
+
["اكتب قصة قصيرة عن رحلة إلى الفضاء", "ar"],
|
| 120 |
+
["Eleza umuhimu wa maji kwa maisha", "sw"],
|
| 121 |
+
["Yapay zeka dünyayı nasıl değiştirecek?", "tr"],
|
| 122 |
+
["日本の四季について教えてください", "ja"],
|
| 123 |
+
["Jelaskan tentang perubahan iklim", "id"],
|
| 124 |
+
["సౌర వ్యవస్థ గురించి చెప్పండి", "te"],
|
| 125 |
+
["Parlez-moi de la Révolution française", "fr"],
|
| 126 |
+
["Erkläre mir die Quantenmechanik", "de"],
|
| 127 |
+
]
|
| 128 |
+
|
| 129 |
+
# Plotly colors matching the Cohere+Wayy blend
|
| 130 |
+
CORAL = "#E8553D"
|
| 131 |
+
CORAL_LIGHT = "#FF8A73"
|
| 132 |
+
BLUE = "#4C6EE6"
|
| 133 |
+
GREEN = "#2AAA5B"
|
| 134 |
+
AMBER = "#E8943D"
|
| 135 |
+
INDIGO = "#4338ca"
|
| 136 |
+
TEXT_PRIMARY = "#1A1A2E"
|
| 137 |
+
TEXT_SECONDARY = "#555570"
|
| 138 |
|
| 139 |
|
| 140 |
def load_benchmarks() -> dict[str, Any]:
|
|
|
|
| 141 |
if BENCHMARK_PATH.exists():
|
| 142 |
with open(BENCHMARK_PATH) as f:
|
| 143 |
return json.load(f)
|
|
|
|
| 147 |
BENCHMARKS = load_benchmarks()
|
| 148 |
|
| 149 |
# ---------------------------------------------------------------------------
|
| 150 |
+
# Model
|
| 151 |
# ---------------------------------------------------------------------------
|
| 152 |
|
| 153 |
model = None
|
| 154 |
tokenizer = None
|
| 155 |
+
MODEL_LOADED = False
|
| 156 |
+
VOCAB_MAPPING = None # old_token_id -> new_token_id mapping for pruned vocab
|
| 157 |
|
| 158 |
|
| 159 |
+
def load_model():
    """Load the pruned Aetheris model + Aya tokenizer from HuggingFace.

    Downloads the wayyresearch/aetheris snapshot, imports the bundled
    ``aetheris`` package from the downloaded directory, restores the
    weights, and loads the vocab mapping (original tokenizer IDs ->
    pruned model IDs) when one ships with the checkpoint.

    Returns:
        (model, tokenizer) — both ``None`` when loading failed.

    The result is memoized via the module-level MODEL_LOADED flag;
    failures are cached too, so a broken download is not retried on
    every chat message.
    """
    global model, tokenizer, MODEL_LOADED, VOCAB_MAPPING
    if MODEL_LOADED:
        return model, tokenizer

    REPO_ID = "wayyresearch/aetheris"
    try:
        from huggingface_hub import snapshot_download
        from transformers import AutoTokenizer

        print(f"Downloading {REPO_ID}...")
        local_dir = snapshot_download(REPO_ID)
        print(f"Downloaded to {local_dir}")

        # Add model code to path so we can import aetheris
        sys.path.insert(0, local_dir)
        from aetheris.config import AetherisConfig
        from aetheris.model import HybridMambaMoE

        # Load config and create model
        config = AetherisConfig.from_yaml(os.path.join(local_dir, "config.yaml"))
        model = HybridMambaMoE(config)

        # Load weights; weights_only=True avoids arbitrary-pickle execution
        sd = torch.load(
            os.path.join(local_dir, "pytorch_model.pt"),
            map_location="cpu",
            weights_only=True,
        )
        model.load_state_dict(sd)
        model.eval()
        print(f"Model loaded: {sum(p.numel() for p in model.parameters())/1e6:.0f}M params")

        # Load vocab mapping (old tokenizer IDs -> pruned model IDs).
        # keep_list[i] is the ORIGINAL id kept at pruned position i, so
        # enumerate() inverts it into original_id -> pruned_id.
        mapping_path = os.path.join(local_dir, "vocab_mapping.json")
        if os.path.exists(mapping_path):
            with open(mapping_path) as f:
                mapping_data = json.load(f)
            keep_list = mapping_data["keep_list"]
            VOCAB_MAPPING = {old_id: new_id for new_id, old_id in enumerate(keep_list)}
            print(f"Vocab mapping loaded: {len(VOCAB_MAPPING)} tokens")
        else:
            VOCAB_MAPPING = None
            print("No vocab mapping found, using direct token IDs")

        # Load tokenizer (Aya tokenizer from teacher)
        tokenizer = AutoTokenizer.from_pretrained("CohereForAI/aya-expanse-8b")
        print(f"Tokenizer loaded: {len(tokenizer)} tokens")

        MODEL_LOADED = True
        return model, tokenizer

    except Exception as e:
        print(f"Could not load model: {e}")
        import traceback
        traceback.print_exc()
        MODEL_LOADED = True  # Don't retry
        return None, None
|
| 218 |
|
| 219 |
|
| 220 |
# ---------------------------------------------------------------------------
|
| 221 |
+
# Demo mode — scripted multilingual responses while model trains
|
| 222 |
# ---------------------------------------------------------------------------
|
| 223 |
|
| 224 |
+
# Each entry: (native_response, english_translation)
|
| 225 |
+
# English entries have translation=None since they're already in English
|
| 226 |
+
DEMO_RESPONSES: dict[str, list[tuple[str, str | None]]] = {
|
| 227 |
+
"ar": [
|
| 228 |
+
("مرحباً! أنا إيثريس، نموذج لغوي متعدد اللغات. أنا قيد التدريب حالياً — سأكون متاحاً بالكامل قريباً!",
|
| 229 |
+
"Hello! I'm Aetheris, a multilingual language model. I'm currently in training — I'll be fully available soon!"),
|
| 230 |
+
("هذا سؤال رائع! حالياً أعمل في وضع تجريبي بينما يتم تدريب نموذجي الكامل. تفضل بزيارة تبويب المعايير لرؤية مقارنة أدائي.",
|
| 231 |
+
"That's a great question! I'm currently working in demo mode while my full model is being trained. Visit the benchmarks tab to see my performance comparison."),
|
| 232 |
+
],
|
| 233 |
+
"zh": [
|
| 234 |
+
("你好!我是 Aetheris,一个多语言混合 Mamba-MoE 模型。我目前正在训练中,很快就会完全上线!",
|
| 235 |
+
"Hello! I'm Aetheris, a multilingual hybrid Mamba-MoE model. I'm currently training and will be fully online soon!"),
|
| 236 |
+
("这是个好问题!我目前处于演示模式。请查看基准测试标签,了解我与 tiny-aya-global 的性能对比。",
|
| 237 |
+
"Good question! I'm currently in demo mode. Check the benchmarks tab to see how I compare with tiny-aya-global."),
|
| 238 |
+
],
|
| 239 |
+
"ja": [
|
| 240 |
+
("こんにちは!Aetherisです。多言語ハイブリッドMamba-MoEモデルです。現在トレーニング中ですが、まもなく完全に利用可能になります!",
|
| 241 |
+
"Hello! I'm Aetheris, a multilingual hybrid Mamba-MoE model. I'm currently training but will be fully available soon!"),
|
| 242 |
+
("素晴らしい質問ですね!現在はデモモードで動作しています。ベンチマークタブで性能比較をご覧ください。",
|
| 243 |
+
"Great question! I'm currently running in demo mode. Check the benchmarks tab for performance comparisons."),
|
| 244 |
+
],
|
| 245 |
+
"hi": [
|
| 246 |
+
("नमस्ते! मैं एथेरिस हूँ, एक बहुभाषी हाइब्रिड Mamba-MoE मॉडल। मैं अभी प्रशिक्षण में हूँ — जल्द ही पूरी तरह उपलब्ध होऊँगा!",
|
| 247 |
+
"Hello! I'm Aetheris, a multilingual hybrid Mamba-MoE model. I'm currently in training — I'll be fully available soon!"),
|
| 248 |
+
("बढ़िया सवाल! मैं अभी डेमो मोड में हूँ। मेरे प्रदर्शन की तुलना देखने के लिए बेंचमार्क टैब देखें।",
|
| 249 |
+
"Great question! I'm in demo mode right now. Check the benchmarks tab to see my performance comparison."),
|
| 250 |
+
],
|
| 251 |
+
"es": [
|
| 252 |
+
("¡Hola! Soy Aetheris, un modelo multilingüe híbrido Mamba-MoE. Estoy en entrenamiento — ¡pronto estaré completamente disponible!",
|
| 253 |
+
"Hello! I'm Aetheris, a hybrid Mamba-MoE multilingual model. I'm in training — I'll be fully available soon!"),
|
| 254 |
+
("¡Buena pregunta! Ahora mismo estoy en modo demo mientras mi modelo completo se entrena. Visita la pestaña de benchmarks para ver cómo me comparo con tiny-aya-global.",
|
| 255 |
+
"Good question! I'm currently in demo mode while my full model trains. Visit the benchmarks tab to see how I compare with tiny-aya-global."),
|
| 256 |
+
],
|
| 257 |
+
"fr": [
|
| 258 |
+
("Bonjour ! Je suis Aetheris, un modèle multilingue hybride Mamba-MoE. Je suis en cours d'entraînement — je serai bientôt pleinement disponible !",
|
| 259 |
+
"Hello! I'm Aetheris, a hybrid Mamba-MoE multilingual model. I'm currently training — I'll be fully available soon!"),
|
| 260 |
+
("Excellente question ! Je suis actuellement en mode démo. Consultez l'onglet benchmarks pour voir mes performances comparées à tiny-aya-global.",
|
| 261 |
+
"Excellent question! I'm currently in demo mode. Check the benchmarks tab to see my performance compared to tiny-aya-global."),
|
| 262 |
+
],
|
| 263 |
+
"de": [
|
| 264 |
+
("Hallo! Ich bin Aetheris, ein mehrsprachiges Hybrid-Mamba-MoE-Modell. Ich werde gerade trainiert — bald bin ich vollständig verfügbar!",
|
| 265 |
+
"Hello! I'm Aetheris, a multilingual hybrid Mamba-MoE model. I'm currently being trained — I'll be fully available soon!"),
|
| 266 |
+
("Gute Frage! Ich bin gerade im Demo-Modus. Schauen Sie sich den Benchmark-Tab an, um meine Leistung im Vergleich zu tiny-aya-global zu sehen.",
|
| 267 |
+
"Good question! I'm currently in demo mode. Check the benchmarks tab to see my performance compared to tiny-aya-global."),
|
| 268 |
+
],
|
| 269 |
+
"ko": [
|
| 270 |
+
("안녕하세요! 저는 Aetheris입니다. 다국어 하이브리드 Mamba-MoE 모델이에요. 현재 훈련 중이며 곧 완전히 사용 가능해질 거예요!",
|
| 271 |
+
"Hello! I'm Aetheris, a multilingual hybrid Mamba-MoE model. I'm currently training and will be fully available soon!"),
|
| 272 |
+
("좋은 질문이에요! 지금은 데모 모드로 작동 중입니다. 벤치마크 탭에서 성능 비교를 확인해 보세요.",
|
| 273 |
+
"Good question! I'm currently running in demo mode. Check the benchmarks tab for performance comparisons."),
|
| 274 |
+
],
|
| 275 |
+
"tr": [
|
| 276 |
+
("Merhaba! Ben Aetheris, çok dilli hibrit bir Mamba-MoE modeliyim. Şu anda eğitim aşamasındayım — yakında tamamen hazır olacağım!",
|
| 277 |
+
"Hello! I'm Aetheris, a multilingual hybrid Mamba-MoE model. I'm currently in training — I'll be fully ready soon!"),
|
| 278 |
+
("Harika soru! Şu anda demo modundayım. Benchmark sekmesine göz atarak performans karşılaştırmasını görebilirsiniz.",
|
| 279 |
+
"Great question! I'm currently in demo mode. Check the benchmarks tab to see performance comparisons."),
|
| 280 |
+
],
|
| 281 |
+
"sw": [
|
| 282 |
+
("Habari! Mimi ni Aetheris, modeli ya lugha nyingi ya Mamba-MoE. Ninafunzwa sasa — nitakuwa tayari hivi karibuni!",
|
| 283 |
+
"Hello! I'm Aetheris, a multilingual Mamba-MoE model. I'm currently being trained — I'll be ready soon!"),
|
| 284 |
+
("Swali zuri! Niko katika hali ya maonyesho sasa hivi. Angalia kichupo cha vigezo vya utendaji kuona jinsi ninavyolinganishwa.",
|
| 285 |
+
"Good question! I'm in demo mode right now. Check the benchmarks tab to see how I compare."),
|
| 286 |
+
],
|
| 287 |
+
"id": [
|
| 288 |
+
("Halo! Saya Aetheris, model multibahasa hibrida Mamba-MoE. Saya sedang dalam pelatihan — segera akan tersedia sepenuhnya!",
|
| 289 |
+
"Hello! I'm Aetheris, a hybrid Mamba-MoE multilingual model. I'm currently in training — I'll be fully available soon!"),
|
| 290 |
+
("Pertanyaan bagus! Saat ini saya dalam mode demo. Lihat tab benchmark untuk perbandingan performa saya.",
|
| 291 |
+
"Good question! I'm currently in demo mode. Check the benchmarks tab for my performance comparison."),
|
| 292 |
+
],
|
| 293 |
+
"te": [
|
| 294 |
+
("హలో! నేను ఏథెరిస్, బహుభాషా హైబ్రిడ్ Mamba-MoE మోడల్. ప్రస్తుతం శిక్షణలో ఉన్నాను — త్వరలో పూర్తిగా అందుబాటులో ఉంటాను!",
|
| 295 |
+
"Hello! I'm Aetheris, a multilingual hybrid Mamba-MoE model. I'm currently in training — I'll be fully available soon!"),
|
| 296 |
+
("మంచి ప్రశ్న! ప్రస్తుతం డెమో మోడ్లో ఉన్నాను. బెంచ్మార్క్ ట్యాబ్లో పనితీరు పోలికను చూడండి.",
|
| 297 |
+
"Good question! I'm currently in demo mode. Check the benchmarks tab for performance comparisons."),
|
| 298 |
+
],
|
| 299 |
+
"pt": [
|
| 300 |
+
("Olá! Eu sou o Aetheris, um modelo multilíngue híbrido Mamba-MoE. Estou em treinamento — em breve estarei totalmente disponível!",
|
| 301 |
+
"Hello! I'm Aetheris, a hybrid Mamba-MoE multilingual model. I'm in training — I'll be fully available soon!"),
|
| 302 |
+
("Boa pergunta! Estou no modo demo agora. Confira a aba de benchmarks para ver como me comparo com o tiny-aya-global.",
|
| 303 |
+
"Good question! I'm in demo mode now. Check the benchmarks tab to see how I compare with tiny-aya-global."),
|
| 304 |
+
],
|
| 305 |
+
"ru": [
|
| 306 |
+
("Привет! Я Aetheris, мультиязычная гибридная модель Mamba-MoE. Сейчас я на стадии обучения — скоро буду полностью доступен!",
|
| 307 |
+
"Hello! I'm Aetheris, a multilingual hybrid Mamba-MoE model. I'm currently in training — I'll be fully available soon!"),
|
| 308 |
+
("Отличный вопрос! Сейчас я работаю в демо-режиме. Загляните на вкладку бенчмарков, чтобы увидеть сравнение производительности.",
|
| 309 |
+
"Great question! I'm currently working in demo mode. Check the benchmarks tab to see performance comparisons."),
|
| 310 |
+
],
|
| 311 |
+
"en": [
|
| 312 |
+
("Hello! I'm Aetheris, a hybrid Mamba-MoE multilingual model with 800M parameters, "
|
| 313 |
+
"distilled from CohereLabs' tiny-aya-global (3.35B) with 4.2x compression.\n\n"
|
| 314 |
+
"I'm currently training on RunPod — Stage 1 (CKA layer alignment) is in progress. "
|
| 315 |
+
"Once complete, I'll be fully deployed here for live inference!\n\n"
|
| 316 |
+
"In the meantime, check out the **Benchmarks** tab to see projected performance across "
|
| 317 |
+
"10 languages, throughput comparisons, and our multilingual equity metrics.",
|
| 318 |
+
None),
|
| 319 |
+
|
| 320 |
+
("Great question! Right now I'm in demo mode — my full model is being trained using a "
|
| 321 |
+
"3-stage distillation pipeline:\n\n"
|
| 322 |
+
"1. **Stage 1** — CKA-guided layer alignment (in progress)\n"
|
| 323 |
+
"2. **Stage 2** — KL divergence distillation with per-language tracking\n"
|
| 324 |
+
"3. **Stage 3** — Supervised fine-tuning on aya_collection\n\n"
|
| 325 |
+
"Key innovations: SSM 10x LR boost (compensates 27x gradient imbalance), "
|
| 326 |
+
"SVD split for MoE expert initialization, and DES tracking for multilingual equity.\n\n"
|
| 327 |
+
"Try the Benchmarks tab for detailed comparisons!",
|
| 328 |
+
None),
|
| 329 |
+
|
| 330 |
+
("I appreciate your patience! My architecture is unique — I interleave "
|
| 331 |
+
"**Mamba SSM layers** (for efficient sequence modeling) with **MoE layers** "
|
| 332 |
+
"(for capacity without parameter explosion). This hybrid design achieves "
|
| 333 |
+
"3.1x faster inference and 4x memory reduction compared to the teacher model.\n\n"
|
| 334 |
+
"Once training completes, you'll be able to chat with me in 67 languages right here!",
|
| 335 |
+
None),
|
| 336 |
+
],
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
# Counters for cycling through responses per language
|
| 340 |
+
_demo_counters: dict[str, int] = {}
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def _detect_lang(text: str) -> str:
|
| 344 |
+
"""Simple script-based language detection for demo mode."""
|
| 345 |
+
import unicodedata
|
| 346 |
+
scripts: dict[str, int] = {}
|
| 347 |
+
for ch in text:
|
| 348 |
+
if ch.isalpha():
|
| 349 |
+
script = unicodedata.script(ch) if hasattr(unicodedata, 'script') else "LATIN"
|
| 350 |
+
scripts[script] = scripts.get(script, 0) + 1
|
| 351 |
+
|
| 352 |
+
# Fallback: check for non-ASCII character ranges
|
| 353 |
+
for ch in text:
|
| 354 |
+
cp = ord(ch)
|
| 355 |
+
if 0x0600 <= cp <= 0x06FF:
|
| 356 |
+
return "ar"
|
| 357 |
+
if 0x4E00 <= cp <= 0x9FFF:
|
| 358 |
+
return "zh"
|
| 359 |
+
if 0x3040 <= cp <= 0x30FF or 0x31F0 <= cp <= 0x31FF:
|
| 360 |
+
return "ja"
|
| 361 |
+
if 0x0900 <= cp <= 0x097F:
|
| 362 |
+
return "hi"
|
| 363 |
+
if 0xAC00 <= cp <= 0xD7AF:
|
| 364 |
+
return "ko"
|
| 365 |
+
if 0x0C00 <= cp <= 0x0C7F:
|
| 366 |
+
return "te"
|
| 367 |
+
if 0x0E00 <= cp <= 0x0E7F:
|
| 368 |
+
return "th"
|
| 369 |
+
if 0x0980 <= cp <= 0x09FF:
|
| 370 |
+
return "bn"
|
| 371 |
+
if 0x0400 <= cp <= 0x04FF:
|
| 372 |
+
return "ru"
|
| 373 |
+
if 0x0A80 <= cp <= 0x0AFF:
|
| 374 |
+
return "gu"
|
| 375 |
+
|
| 376 |
+
# Latin script — check for language-specific keywords
|
| 377 |
+
lower = text.lower()
|
| 378 |
+
if any(w in lower for w in ["¿", "¡", "está", "hola", "qué", "cómo", "también"]):
|
| 379 |
+
return "es"
|
| 380 |
+
if any(w in lower for w in ["bonjour", "merci", "comment", "pourquoi", "français"]):
|
| 381 |
+
return "fr"
|
| 382 |
+
if any(w in lower for w in ["hallo", "danke", "warum", "erklär", "deutsch"]):
|
| 383 |
+
return "de"
|
| 384 |
+
if any(w in lower for w in ["olá", "obrigad", "como", "português"]):
|
| 385 |
+
return "pt"
|
| 386 |
+
if any(w in lower for w in ["merhaba", "nasıl", "nedir", "türk"]):
|
| 387 |
+
return "tr"
|
| 388 |
+
if any(w in lower for w in ["habari", "nini", "kwa", "swahili"]):
|
| 389 |
+
return "sw"
|
| 390 |
+
if any(w in lower for w in ["halo", "apa", "jelaskan", "bagaimana", "indonesia"]):
|
| 391 |
+
return "id"
|
| 392 |
+
|
| 393 |
+
return "en"
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def _demo_respond(message: str) -> tuple[str, str | None, str]:
    """Pick the next scripted demo reply for the detected language.

    Cycles through the language's canned responses using a per-language
    counter; languages without a scripted pool fall back to English.

    Returns (native_text, english_translation_or_None, lang_code).
    """
    lang = _detect_lang(message)
    pool = DEMO_RESPONSES.get(lang, DEMO_RESPONSES["en"])
    cursor = _demo_counters.get(lang, 0) % len(pool)
    _demo_counters[lang] = cursor + 1
    native_text, english = pool[cursor]
    return native_text, english, lang
|
| 407 |
|
| 408 |
+
|
| 409 |
+
# ---------------------------------------------------------------------------
|
| 410 |
+
# Chat logic
|
| 411 |
+
# ---------------------------------------------------------------------------
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def _map_token_ids(input_ids: list[int]) -> "torch.Tensor":
    """Map original tokenizer IDs to pruned-model IDs.

    Args:
        input_ids: token IDs in the original tokenizer's ID space.

    Returns:
        A ``[1, len(input_ids)]`` tensor in the pruned model's ID
        space.  Tokens absent from the pruned vocabulary collapse to
        ID 0 (UNK in the pruned vocab).  When no vocab mapping is
        loaded, IDs pass through unchanged.
    """
    if VOCAB_MAPPING is None:
        return torch.tensor([input_ids])
    # dict.get with a default replaces the original explicit
    # membership-test loop (same behavior, idiomatic form).
    return torch.tensor([[VOCAB_MAPPING.get(tid, 0) for tid in input_ids]])
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def _unmap_token_id(model_id: int) -> int:
    """Map a pruned-model token ID back to the original tokenizer ID.

    IDs not produced by the mapping decode as original ID 0.  When no
    vocab mapping is loaded, the ID passes through unchanged.
    """
    if VOCAB_MAPPING is None:
        return model_id
    # Build reverse mapping lazily.  Cached as a function attribute so
    # the O(vocab)-sized inversion runs only once per process.
    # NOTE(review): the cache is never invalidated — if VOCAB_MAPPING
    # were ever reloaded with different contents, this would keep
    # serving the stale inverse; confirm the mapping is load-once.
    if not hasattr(_unmap_token_id, "_reverse"):
        _unmap_token_id._reverse = {v: k for k, v in VOCAB_MAPPING.items()}
    return _unmap_token_id._reverse.get(model_id, 0)
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
def _generate_text(m, tok, prompt: str, max_tokens: int, temperature: float) -> str:
    """Autoregressive text generation with vocab-mapping support.

    Args:
        m: the Aetheris model; calling it returns a dict whose
           ``"logits"`` entry is shaped ``[batch, seq, vocab]``.
        tok: HuggingFace tokenizer used for encode/decode.
        prompt: full text prompt (conversation transcript).
        max_tokens: maximum number of new tokens to generate.
        temperature: 0 → greedy argmax; > 0 → temperature-scaled
            nucleus (top-p = 0.9) sampling.

    Returns:
        The decoded generation with special tokens stripped.
    """
    # Tokenize and clamp the prompt to the last 1024 tokens so the
    # model's context window is not exceeded.
    input_ids = tok.encode(prompt, add_special_tokens=False)
    if len(input_ids) > 1024:
        input_ids = input_ids[-1024:]

    # Translate tokenizer IDs into the pruned model's ID space.
    ids_tensor = _map_token_ids(input_ids)

    generated_ids = []
    # EOS in the pruned ID space; falls back to the tokenizer's raw EOS
    # id when no mapping exists.  NOTE(review): if the EOS token were
    # pruned from the vocab, this fallback would never match and
    # generation would only stop at max_tokens — confirm EOS is kept.
    eos_id = VOCAB_MAPPING.get(tok.eos_token_id, tok.eos_token_id) if VOCAB_MAPPING else tok.eos_token_id

    with torch.no_grad():
        for _ in range(max_tokens):
            out = m(ids_tensor)
            logits = out["logits"][:, -1, :]  # [1, vocab_size]

            if temperature > 0:
                logits = logits / temperature
                probs = torch.softmax(logits, dim=-1)
                # Nucleus (top-p = 0.9) sampling: zero out the tail of
                # the sorted distribution, renormalize, then sample.
                sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                cumulative = torch.cumsum(sorted_probs, dim=-1)
                mask = cumulative - sorted_probs > 0.9
                sorted_probs[mask] = 0
                sorted_probs = sorted_probs / sorted_probs.sum()
                next_token = sorted_indices[0, torch.multinomial(sorted_probs[0], 1)]
            else:
                next_token = logits.argmax(dim=-1).squeeze()

            next_id = next_token.item()
            if next_id == eos_id:
                break

            generated_ids.append(next_id)
            # BUGFIX: the sampled token has shape [1] while the greedy
            # one is 0-d; the previous double-unsqueeze produced a 3-d
            # tensor in the sampling path and made torch.cat raise.
            # reshape(1, 1) normalizes both cases to [1, 1].
            ids_tensor = torch.cat([ids_tensor, next_token.reshape(1, 1)], dim=1)

            # Keep the rolling context window bounded.
            if ids_tensor.shape[1] > 1024:
                ids_tensor = ids_tensor[:, -1024:]

    # Map generated IDs back into the tokenizer's space for decoding.
    original_ids = [_unmap_token_id(gid) for gid in generated_ids]
    return tok.decode(original_ids, skip_special_tokens=True).strip()
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
def chat_respond(message: str, history: list[dict], temperature: float, max_tokens: int):
    """Handle a chat message with multilingual awareness.

    Appends the user message and the model's reply to ``history``
    (Gradio "messages" format: dicts with "role"/"content") and returns
    ``(history, "")`` — the empty string clears the input textbox.
    """
    if not message.strip():
        return history, ""

    m, tok = load_model()

    # Model not available — reply with a friendly placeholder instead
    # of failing the Gradio callback.
    if m is None or tok is None:
        response = (
            "Model is currently loading or unavailable. "
            "Please try again in a moment — the ~536M parameter model "
            "may take a minute to download on first load."
        )
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": response})
        return history, ""

    try:
        import torch

        # Build conversation context from history
        context_parts = []
        for msg in history[-6:]:  # Keep last 6 messages for context
            role = msg["role"]
            content = msg["content"]
            if role == "user":
                context_parts.append(f"User: {content}")
            else:
                # Strip a previously appended "---" translation block so
                # it is not fed back into the prompt.
                clean = content.split("\n\n---\n")[0] if "\n\n---\n" in content else content
                context_parts.append(f"Assistant: {clean}")

        context_parts.append(f"User: {message}")
        context_parts.append("Assistant:")
        prompt = "\n".join(context_parts)

        generated = _generate_text(m, tok, prompt, int(max_tokens), float(temperature))

        # Clean up: stop at next "User:" if present (the model may keep
        # role-playing the transcript past its own turn)
        if "User:" in generated:
            generated = generated[:generated.index("User:")].strip()

        # Auto-translate non-English responses for English-speaking users.
        # Uses a low temperature (0.3) for a more deterministic translation.
        lang = _detect_lang(message)
        if lang != "en" and generated:
            try:
                translate_prompt = f"Translate the following to English:\n{generated}\nEnglish translation:"
                translation = _generate_text(m, tok, translate_prompt, int(max_tokens), 0.3)
                if "User:" in translation:
                    translation = translation[:translation.index("User:")].strip()
                lang_name = LANGUAGES.get(lang, {}).get("name", lang)
                generated = (
                    f"{generated}\n\n"
                    f"---\n"
                    f"**English translation** ({lang_name}):\n"
                    f"*{translation}*"
                )
            except Exception:
                pass  # If translation fails, just show the original

    except Exception as e:
        generated = f"Generation error: {e}"

    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": generated})
    return history, ""
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
def load_example(example_text: str):
|
| 557 |
+
"""Load an example into the chat input."""
|
| 558 |
+
return example_text
|
| 559 |
|
| 560 |
|
| 561 |
# ---------------------------------------------------------------------------
|
| 562 |
+
# Benchmark visualizations (light theme)
|
| 563 |
# ---------------------------------------------------------------------------
|
| 564 |
|
| 565 |
+
PLOT_LAYOUT = dict(
|
| 566 |
+
template="plotly_white",
|
| 567 |
+
font=dict(family="Inter, Outfit, Helvetica Neue, sans-serif", color=TEXT_PRIMARY),
|
| 568 |
+
paper_bgcolor="rgba(0,0,0,0)",
|
| 569 |
+
plot_bgcolor="rgba(255,251,247,0.5)",
|
| 570 |
+
margin=dict(t=60, b=50, l=60, r=30),
|
| 571 |
+
)
|
| 572 |
+
|
| 573 |
|
| 574 |
def make_model_info_table() -> pd.DataFrame:
|
|
|
|
| 575 |
info = BENCHMARKS.get("model_info", {})
|
| 576 |
teacher = info.get("teacher", {})
|
| 577 |
student = info.get("student", {})
|
|
|
|
| 578 |
rows = [
|
| 579 |
+
("Organization", "CohereLabs", "Wayy Research"),
|
| 580 |
+
("Model", "tiny-aya-global", "Aetheris"),
|
| 581 |
("Parameters", f"{teacher.get('params_m', 3350)}M", f"{student.get('params_m', 800)}M"),
|
| 582 |
+
("Architecture", teacher.get("architecture", "Transformer (Dense)"), "Hybrid Mamba-MoE"),
|
| 583 |
("Layers", str(teacher.get("layers", 36)), str(student.get("layers", 24))),
|
| 584 |
+
("Hidden Dim", str(teacher.get("hidden_dim", 2048)), str(student.get("hidden_dim", 1024))),
|
| 585 |
+
("Attention", "GQA (16 query / 4 kv heads)", "SSM (Mamba) + MoE routing"),
|
| 586 |
("Experts", "N/A (dense)", f"{student.get('num_experts', 4)} (top-{student.get('top_k', 1)})"),
|
| 587 |
+
("Vocab", f"{teacher.get('vocab_size', 262144):,}", f"{student.get('vocab_size', 262144):,}"),
|
| 588 |
+
("Languages", "67", "67 (inherited)"),
|
|
|
|
| 589 |
("Compression", "1.0x (baseline)", f"{student.get('compression_ratio', 4.2)}x"),
|
| 590 |
]
|
| 591 |
+
return pd.DataFrame(rows, columns=["", "Teacher", "Student (Aetheris)"])
|
|
|
|
| 592 |
|
| 593 |
|
| 594 |
def make_benchmark_chart(benchmark_key: str, title: str) -> go.Figure:
|
|
|
|
| 595 |
data = BENCHMARKS.get(benchmark_key, {})
|
| 596 |
langs = data.get("languages", {})
|
| 597 |
avg = data.get("average", {})
|
| 598 |
|
| 599 |
lang_codes = list(langs.keys())
|
| 600 |
+
lang_names = [LANGUAGES.get(lc, {}).get("name", lc) for lc in lang_codes]
|
| 601 |
teacher_scores = [langs[lc]["teacher"] for lc in lang_codes]
|
| 602 |
student_scores = [langs[lc]["student"] for lc in lang_codes]
|
| 603 |
|
| 604 |
+
lang_names.append("Average")
|
|
|
|
| 605 |
teacher_scores.append(avg.get("teacher", 0))
|
| 606 |
student_scores.append(avg.get("student", 0))
|
| 607 |
|
| 608 |
fig = go.Figure()
|
| 609 |
fig.add_trace(go.Bar(
|
| 610 |
name="Teacher (tiny-aya-global)",
|
| 611 |
+
x=lang_names, y=teacher_scores,
|
| 612 |
+
marker_color=BLUE, marker_line_width=0,
|
|
|
|
| 613 |
text=[f"{s:.1f}" for s in teacher_scores],
|
| 614 |
+
textposition="outside", textfont=dict(size=10),
|
|
|
|
| 615 |
))
|
| 616 |
fig.add_trace(go.Bar(
|
| 617 |
name="Student (Aetheris)",
|
| 618 |
+
x=lang_names, y=student_scores,
|
| 619 |
+
marker_color=CORAL, marker_line_width=0,
|
|
|
|
| 620 |
text=[f"{s:.1f}" for s in student_scores],
|
| 621 |
+
textposition="outside", textfont=dict(size=10),
|
|
|
|
| 622 |
))
|
| 623 |
|
|
|
|
| 624 |
fig.update_layout(
|
| 625 |
+
**PLOT_LAYOUT,
|
| 626 |
+
title=dict(text=title, font=dict(size=16, color=TEXT_PRIMARY)),
|
| 627 |
xaxis_title="Language",
|
| 628 |
+
yaxis_title=data.get("metric", "Accuracy (%)"),
|
| 629 |
barmode="group",
|
| 630 |
+
height=420,
|
| 631 |
+
legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1,
|
| 632 |
+
bgcolor="rgba(255,251,247,0.8)"),
|
|
|
|
| 633 |
)
|
|
|
|
| 634 |
return fig
|
| 635 |
|
| 636 |
|
| 637 |
def make_retention_chart() -> go.Figure:
|
|
|
|
| 638 |
mgsm = BENCHMARKS.get("mgsm", {}).get("languages", {})
|
| 639 |
xcopa = BENCHMARKS.get("xcopa", {}).get("languages", {})
|
| 640 |
|
| 641 |
lang_codes = list(mgsm.keys())
|
| 642 |
+
lang_names = [LANGUAGES.get(lc, {}).get("name", lc) for lc in lang_codes]
|
| 643 |
|
| 644 |
+
mgsm_ret = [mgsm[lc]["student"] / mgsm[lc]["teacher"] * 100 if mgsm[lc]["teacher"] > 0 else 0 for lc in lang_codes]
|
| 645 |
+
xcopa_ret = [xcopa[lc]["student"] / xcopa[lc]["teacher"] * 100 if lc in xcopa and xcopa[lc]["teacher"] > 0 else 0 for lc in lang_codes]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 646 |
|
| 647 |
fig = go.Figure()
|
| 648 |
+
fig.add_trace(go.Scatter(x=lang_names, y=mgsm_ret, mode="lines+markers+text", name="mGSM",
|
| 649 |
+
line=dict(color=BLUE, width=2.5), marker=dict(size=8),
|
| 650 |
+
text=[f"{r:.0f}%" for r in mgsm_ret], textposition="top center", textfont=dict(size=9)))
|
| 651 |
+
fig.add_trace(go.Scatter(x=lang_names, y=xcopa_ret, mode="lines+markers+text", name="XCOPA",
|
| 652 |
+
line=dict(color=CORAL, width=2.5), marker=dict(size=8),
|
| 653 |
+
text=[f"{r:.0f}%" for r in xcopa_ret], textposition="bottom center", textfont=dict(size=9)))
|
| 654 |
+
fig.add_hline(y=80, line_dash="dash", line_color=AMBER,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 655 |
annotation_text="80% retention target", annotation_position="top right")
|
| 656 |
|
| 657 |
+
fig.update_layout(**PLOT_LAYOUT, title=dict(text="Quality Retention: Student / Teacher (%)", font=dict(size=16)),
|
| 658 |
+
yaxis_title="Retention (%)", xaxis_title="Language", height=380,
|
| 659 |
+
yaxis=dict(range=[60, 105]),
|
| 660 |
+
legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 661 |
return fig
|
| 662 |
|
| 663 |
|
| 664 |
def make_throughput_chart() -> go.Figure:
|
|
|
|
| 665 |
tp = BENCHMARKS.get("throughput", {})
|
| 666 |
teacher = tp.get("teacher", {})
|
| 667 |
student = tp.get("student", {})
|
| 668 |
|
| 669 |
+
metrics = ["Tokens/sec", "TTFT (ms)", "Memory (MB)"]
|
| 670 |
+
teacher_vals = [teacher.get("tokens_per_sec", 0), teacher.get("ttft_ms", 0), teacher.get("memory_mb", 0)]
|
| 671 |
+
student_vals = [student.get("tokens_per_sec", 0), student.get("ttft_ms", 0), student.get("memory_mb", 0)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 672 |
|
| 673 |
+
fig = make_subplots(rows=1, cols=3, subplot_titles=metrics, horizontal_spacing=0.1)
|
| 674 |
+
for i, (m, tv, sv) in enumerate(zip(metrics, teacher_vals, student_vals)):
|
| 675 |
+
fig.add_trace(go.Bar(x=["Teacher", "Student"], y=[tv, sv],
|
| 676 |
+
marker_color=[BLUE, CORAL], text=[str(tv), str(sv)],
|
| 677 |
+
textposition="outside", showlegend=False), row=1, col=i + 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 678 |
|
| 679 |
+
fig.update_layout(**PLOT_LAYOUT, height=350,
|
| 680 |
+
title=dict(text="Inference Performance", font=dict(size=16)))
|
| 681 |
return fig
|
| 682 |
|
| 683 |
|
| 684 |
def make_equity_chart() -> go.Figure:
|
|
|
|
| 685 |
equity = BENCHMARKS.get("equity", {})
|
| 686 |
families = equity.get("language_families", {})
|
| 687 |
|
|
|
|
| 689 |
des_vals = [families[n]["des"] for n in names]
|
| 690 |
retention_vals = [families[n]["avg_retention"] * 100 for n in names]
|
| 691 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 692 |
sorted_idx = np.argsort(des_vals)[::-1]
|
| 693 |
sorted_names = [names[i] for i in sorted_idx]
|
| 694 |
sorted_des = [des_vals[i] for i in sorted_idx]
|
| 695 |
|
| 696 |
+
colors = ["#ef4444" if d > 0.4 else AMBER if d > 0.3 else GREEN for d in sorted_des]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 697 |
|
| 698 |
+
fig = go.Figure()
|
| 699 |
+
fig.add_trace(go.Bar(x=sorted_des, y=sorted_names, orientation="h", marker_color=colors,
|
| 700 |
+
text=[f"{d:.2f}" for d in sorted_des], textposition="outside"))
|
| 701 |
+
fig.update_layout(**PLOT_LAYOUT, height=350,
|
| 702 |
+
title=dict(text="Degradation Equity Score by Language Family", font=dict(size=16)),
|
| 703 |
+
xaxis_title="DES (lower = more equitable)", yaxis=dict(autorange="reversed"))
|
| 704 |
return fig
|
| 705 |
|
| 706 |
|
| 707 |
def make_training_chart() -> go.Figure:
|
|
|
|
| 708 |
progress = BENCHMARKS.get("training_progress", {})
|
| 709 |
steps = progress.get("steps", [])
|
| 710 |
kl = progress.get("kl_divergence", [])
|
| 711 |
cka = progress.get("cka_alignment", [])
|
| 712 |
loss = progress.get("total_loss", [])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 713 |
|
| 714 |
+
fig = make_subplots(rows=1, cols=3, subplot_titles=["KL Divergence", "CKA Alignment", "Total Loss"],
|
| 715 |
+
horizontal_spacing=0.08)
|
| 716 |
+
fig.add_trace(go.Scatter(x=steps, y=kl, mode="lines+markers", line=dict(color=CORAL, width=2),
|
| 717 |
+
marker=dict(size=5), showlegend=False), row=1, col=1)
|
| 718 |
+
fig.add_trace(go.Scatter(x=steps, y=cka, mode="lines+markers", line=dict(color=GREEN, width=2),
|
| 719 |
+
marker=dict(size=5), showlegend=False), row=1, col=2)
|
| 720 |
+
fig.add_trace(go.Scatter(x=steps, y=loss, mode="lines+markers", line=dict(color=BLUE, width=2),
|
| 721 |
+
marker=dict(size=5), showlegend=False), row=1, col=3)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 722 |
|
| 723 |
+
fig.add_hline(y=0.75, line_dash="dash", line_color=AMBER, annotation_text="target", row=1, col=2)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 724 |
|
| 725 |
+
fig.update_layout(**PLOT_LAYOUT, height=320,
|
| 726 |
+
title=dict(text="Training Progress", font=dict(size=16)))
|
| 727 |
+
fig.update_xaxes(title_text="Steps", row=1, col=2)
|
| 728 |
return fig
|
| 729 |
|
| 730 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 731 |
# ---------------------------------------------------------------------------
|
| 732 |
+
# CSS — Cohere warm palette + Wayy Research accents
|
| 733 |
# ---------------------------------------------------------------------------
|
| 734 |
|
| 735 |
CUSTOM_CSS = """
|
| 736 |
+
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
|
| 737 |
+
|
| 738 |
+
/* === Global overrides === */
|
| 739 |
+
.gradio-container {
|
| 740 |
+
background: #FFFBF7 !important;
|
| 741 |
+
font-family: 'Inter', 'Outfit', 'Helvetica Neue', sans-serif !important;
|
| 742 |
+
max-width: 1200px !important;
|
| 743 |
}
|
| 744 |
+
|
| 745 |
+
/* === Header banner === */
|
| 746 |
+
.hero-banner {
|
| 747 |
+
background: linear-gradient(135deg, #FFFBF7 0%, #FEF0E4 40%, #EEF2FF 100%);
|
| 748 |
+
border: 1px solid rgba(0,0,0,0.06);
|
| 749 |
+
border-radius: 16px;
|
| 750 |
+
padding: 40px 48px;
|
| 751 |
+
margin-bottom: 8px;
|
| 752 |
+
position: relative;
|
| 753 |
+
overflow: hidden;
|
| 754 |
+
}
|
| 755 |
+
.hero-banner::before {
|
| 756 |
+
content: '';
|
| 757 |
+
position: absolute;
|
| 758 |
+
top: 0; left: 0; right: 0;
|
| 759 |
+
height: 4px;
|
| 760 |
+
background: linear-gradient(90deg, #E8553D 0%, #FF8A73 25%, #4C6EE6 50%, #4338ca 75%, #2AAA5B 100%);
|
| 761 |
+
}
|
| 762 |
+
.hero-banner h1 {
|
| 763 |
+
font-size: 2.8em !important;
|
| 764 |
+
font-weight: 300 !important;
|
| 765 |
+
letter-spacing: -0.03em !important;
|
| 766 |
+
color: #1A1A2E !important;
|
| 767 |
margin: 0 0 4px 0 !important;
|
| 768 |
+
line-height: 1.1 !important;
|
| 769 |
+
}
|
| 770 |
+
.hero-banner h1 .coral { color: #E8553D; font-weight: 400; }
|
| 771 |
+
.hero-banner .collab-line {
|
| 772 |
+
font-size: 0.85em;
|
| 773 |
+
font-weight: 500;
|
| 774 |
+
letter-spacing: 0.12em;
|
| 775 |
+
text-transform: uppercase;
|
| 776 |
+
color: #8E8EA0;
|
| 777 |
+
margin-bottom: 16px;
|
| 778 |
+
}
|
| 779 |
+
.hero-banner .collab-line .dot {
|
| 780 |
+
display: inline-block;
|
| 781 |
+
width: 5px; height: 5px;
|
| 782 |
+
border-radius: 50%;
|
| 783 |
+
background: #FF8A73;
|
| 784 |
+
margin: 0 12px;
|
| 785 |
+
vertical-align: middle;
|
| 786 |
+
}
|
| 787 |
+
.hero-banner .subtitle {
|
| 788 |
+
font-size: 1.05em;
|
| 789 |
+
color: #555570;
|
| 790 |
+
line-height: 1.6;
|
| 791 |
+
max-width: 700px;
|
| 792 |
+
}
|
| 793 |
+
|
| 794 |
+
/* === Stat pills === */
|
| 795 |
+
.stat-row {
|
| 796 |
+
display: flex;
|
| 797 |
+
gap: 12px;
|
| 798 |
+
flex-wrap: wrap;
|
| 799 |
+
margin: 20px 0 8px 0;
|
| 800 |
}
|
| 801 |
+
.stat-pill {
|
| 802 |
+
background: white;
|
| 803 |
+
border: 1px solid rgba(0,0,0,0.06);
|
| 804 |
+
border-radius: 999px;
|
| 805 |
+
padding: 8px 20px;
|
| 806 |
+
font-size: 0.82em;
|
| 807 |
+
font-weight: 500;
|
| 808 |
+
color: #555570;
|
| 809 |
+
box-shadow: 0 1px 3px rgba(0,0,0,0.04);
|
| 810 |
+
transition: all 0.3s ease;
|
| 811 |
}
|
| 812 |
+
.stat-pill:hover {
|
| 813 |
+
transform: translateY(-2px);
|
| 814 |
+
box-shadow: 0 4px 12px rgba(232,85,61,0.12);
|
| 815 |
+
border-color: rgba(232,85,61,0.2);
|
| 816 |
+
}
|
| 817 |
+
.stat-pill .num { color: #E8553D; font-weight: 700; }
|
| 818 |
+
.stat-pill.blue .num { color: #4C6EE6; }
|
| 819 |
+
.stat-pill.green .num { color: #2AAA5B; }
|
| 820 |
+
.stat-pill.amber .num { color: #E8943D; }
|
| 821 |
+
|
| 822 |
+
/* === Language showcase grid === */
|
| 823 |
+
.lang-grid {
|
| 824 |
+
display: grid;
|
| 825 |
+
grid-template-columns: repeat(auto-fill, minmax(130px, 1fr));
|
| 826 |
+
gap: 8px;
|
| 827 |
+
margin: 16px 0;
|
| 828 |
+
}
|
| 829 |
+
.lang-card {
|
| 830 |
+
background: white;
|
| 831 |
+
border: 1px solid rgba(0,0,0,0.06);
|
| 832 |
+
border-radius: 12px;
|
| 833 |
+
padding: 14px 10px;
|
| 834 |
text-align: center;
|
| 835 |
+
cursor: pointer;
|
| 836 |
+
transition: all 0.3s ease;
|
| 837 |
+
box-shadow: 0 1px 3px rgba(0,0,0,0.04);
|
| 838 |
+
}
|
| 839 |
+
.lang-card:hover {
|
| 840 |
+
transform: translateY(-3px);
|
| 841 |
+
box-shadow: 0 8px 24px rgba(232,85,61,0.1);
|
| 842 |
+
border-color: rgba(232,85,61,0.2);
|
| 843 |
+
}
|
| 844 |
+
.lang-card .greeting { font-size: 1.3em; margin-bottom: 4px; }
|
| 845 |
+
.lang-card .name { font-size: 0.78em; font-weight: 600; color: #1A1A2E; }
|
| 846 |
+
.lang-card .family { font-size: 0.6em; color: #8E8EA0; text-transform: uppercase; letter-spacing: 0.08em; }
|
| 847 |
+
|
| 848 |
+
/* === Region grouping === */
|
| 849 |
+
.region-label {
|
| 850 |
+
font-size: 0.72em;
|
| 851 |
+
font-weight: 600;
|
| 852 |
+
letter-spacing: 0.12em;
|
| 853 |
+
text-transform: uppercase;
|
| 854 |
+
color: #E8553D;
|
| 855 |
+
margin: 16px 0 6px 0;
|
| 856 |
+
padding-bottom: 4px;
|
| 857 |
+
border-bottom: 1px solid rgba(232,85,61,0.15);
|
| 858 |
}
|
| 859 |
+
|
| 860 |
+
/* === Chat styling === */
|
| 861 |
+
.chat-container .message {
|
| 862 |
+
border-radius: 12px !important;
|
| 863 |
}
|
| 864 |
+
|
| 865 |
+
/* === Benchmark notice === */
|
| 866 |
+
.bench-notice {
|
| 867 |
+
background: linear-gradient(135deg, #FEF0E4 0%, #EEF2FF 100%);
|
| 868 |
+
border: 1px solid rgba(232,85,61,0.15);
|
| 869 |
+
border-left: 3px solid #E8553D;
|
| 870 |
+
border-radius: 0 12px 12px 0;
|
| 871 |
+
padding: 14px 20px;
|
| 872 |
+
color: #555570;
|
| 873 |
+
font-size: 0.88em;
|
| 874 |
+
margin-bottom: 16px;
|
| 875 |
+
}
|
| 876 |
+
|
| 877 |
+
/* === Section headers === */
|
| 878 |
+
.section-label {
|
| 879 |
+
font-size: 0.7em;
|
| 880 |
+
font-weight: 600;
|
| 881 |
+
letter-spacing: 0.18em;
|
| 882 |
+
text-transform: uppercase;
|
| 883 |
+
color: #E8553D;
|
| 884 |
+
margin-bottom: 4px;
|
| 885 |
}
|
| 886 |
+
|
| 887 |
+
/* === Footer === */
|
| 888 |
+
.footer-section {
|
| 889 |
+
background: #FFF6EE;
|
| 890 |
+
border: 1px solid rgba(0,0,0,0.06);
|
| 891 |
+
border-radius: 12px;
|
| 892 |
+
padding: 24px 32px;
|
| 893 |
+
text-align: center;
|
| 894 |
+
margin-top: 16px;
|
| 895 |
+
}
|
| 896 |
+
.footer-section .collab-title {
|
| 897 |
+
font-size: 1.3em;
|
| 898 |
+
font-weight: 300;
|
| 899 |
+
color: #1A1A2E;
|
| 900 |
+
margin-bottom: 4px;
|
| 901 |
+
}
|
| 902 |
+
.footer-section .collab-title .coral { color: #E8553D; font-weight: 400; }
|
| 903 |
+
.footer-section .collab-title .indigo { color: #4338ca; font-weight: 400; }
|
| 904 |
+
.footer-section .tagline {
|
| 905 |
+
font-style: italic;
|
| 906 |
+
color: #555570;
|
| 907 |
+
font-size: 0.92em;
|
| 908 |
+
}
|
| 909 |
+
.footer-section .links {
|
| 910 |
+
font-size: 0.75em;
|
| 911 |
+
color: #8E8EA0;
|
| 912 |
+
margin-top: 8px;
|
| 913 |
+
}
|
| 914 |
+
|
| 915 |
+
/* === Tabs === */
|
| 916 |
+
.tab-nav button {
|
| 917 |
+
font-weight: 500 !important;
|
| 918 |
+
letter-spacing: 0.02em !important;
|
| 919 |
}
|
| 920 |
+
.tab-nav button.selected {
|
| 921 |
+
border-color: #E8553D !important;
|
| 922 |
+
color: #E8553D !important;
|
| 923 |
+
}
|
| 924 |
+
|
| 925 |
+
/* === Hide default footer === */
|
| 926 |
+
footer.svelte-1rjryqp { display: none !important; }
|
| 927 |
"""
|
| 928 |
|
| 929 |
+
# ---------------------------------------------------------------------------
|
| 930 |
+
# Build language showcase HTML
|
| 931 |
+
# ---------------------------------------------------------------------------
|
| 932 |
+
|
| 933 |
+
def build_lang_grid_html() -> str:
|
| 934 |
+
"""Build language showcase grouped by region, showing all 67 languages."""
|
| 935 |
+
# Group languages by region
|
| 936 |
+
by_region: dict[str, list[tuple[str, dict]]] = {}
|
| 937 |
+
for code, info in LANGUAGES.items():
|
| 938 |
+
region = info.get("region", "Other")
|
| 939 |
+
by_region.setdefault(region, []).append((code, info))
|
| 940 |
+
|
| 941 |
+
html_parts = []
|
| 942 |
+
for region in REGION_ORDER:
|
| 943 |
+
langs = by_region.get(region, [])
|
| 944 |
+
if not langs:
|
| 945 |
+
continue
|
| 946 |
+
html_parts.append(f'<div class="region-label">{region} ({len(langs)})</div>')
|
| 947 |
+
html_parts.append('<div class="lang-grid">')
|
| 948 |
+
for code, info in langs:
|
| 949 |
+
html_parts.append(
|
| 950 |
+
f'<div class="lang-card">'
|
| 951 |
+
f'<div class="greeting">{info["greeting"]}</div>'
|
| 952 |
+
f'<div class="name">{info["flag"]} {info["name"]}</div>'
|
| 953 |
+
f'<div class="family">{info["family"]}</div>'
|
| 954 |
+
f'</div>'
|
| 955 |
+
)
|
| 956 |
+
html_parts.append('</div>')
|
| 957 |
+
return "".join(html_parts)
|
| 958 |
+
|
| 959 |
+
|
| 960 |
+
# ---------------------------------------------------------------------------
|
| 961 |
+
# Gradio App
|
| 962 |
+
# ---------------------------------------------------------------------------
|
| 963 |
+
|
| 964 |
theme = gr.themes.Soft(
|
| 965 |
+
primary_hue="orange",
|
| 966 |
+
secondary_hue="indigo",
|
| 967 |
neutral_hue="slate",
|
| 968 |
font=gr.themes.GoogleFont("Inter"),
|
| 969 |
).set(
|
| 970 |
+
body_background_fill="#FFFBF7",
|
| 971 |
+
body_background_fill_dark="#FFFBF7",
|
| 972 |
+
block_background_fill="#FFFFFF",
|
| 973 |
+
block_background_fill_dark="#FFFFFF",
|
| 974 |
+
block_border_color="rgba(0,0,0,0.06)",
|
| 975 |
+
block_border_color_dark="rgba(0,0,0,0.06)",
|
| 976 |
+
input_background_fill="#FFFBF7",
|
| 977 |
+
input_background_fill_dark="#FFFBF7",
|
| 978 |
+
button_primary_background_fill="#E8553D",
|
| 979 |
+
button_primary_background_fill_dark="#E8553D",
|
| 980 |
+
button_primary_background_fill_hover="#d04830",
|
| 981 |
+
button_primary_background_fill_hover_dark="#d04830",
|
| 982 |
+
button_primary_text_color="white",
|
| 983 |
+
button_secondary_background_fill="white",
|
| 984 |
+
button_secondary_border_color="rgba(0,0,0,0.08)",
|
| 985 |
+
body_text_color="#1A1A2E",
|
| 986 |
+
body_text_color_dark="#1A1A2E",
|
| 987 |
)
|
| 988 |
|
| 989 |
last_updated = BENCHMARKS.get("metadata", {}).get("last_updated", "N/A")
|
| 990 |
|
| 991 |
+
with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="Aetheris — Polyglot AI") as demo:
|
| 992 |
|
| 993 |
+
# === Hero Banner ===
|
| 994 |
gr.HTML("""
|
| 995 |
+
<div class="hero-banner">
|
| 996 |
+
<div class="collab-line">
|
| 997 |
+
CohereLabs<span class="dot"></span>Wayy Research
|
| 998 |
+
</div>
|
| 999 |
+
<h1><span class="coral">Aetheris</span> Playground</h1>
|
| 1000 |
+
<p class="subtitle">
|
| 1001 |
+
A hybrid Mamba-MoE multilingual model — 800M parameters distilled from
|
| 1002 |
+
<strong>tiny-aya-global</strong> (3.35B) with 4.2x compression.
|
| 1003 |
+
Chat in any of 67 languages and it responds in kind.
|
| 1004 |
</p>
|
| 1005 |
+
<div class="stat-row">
|
| 1006 |
+
<span class="stat-pill"><span class="num">4.2x</span> compression</span>
|
| 1007 |
+
<span class="stat-pill blue"><span class="num">800M</span> parameters</span>
|
| 1008 |
+
<span class="stat-pill green"><span class="num">3.1x</span> faster inference</span>
|
| 1009 |
+
<span class="stat-pill amber"><span class="num">67</span> languages</span>
|
| 1010 |
+
<span class="stat-pill"><span class="num">24</span> hybrid layers (SSM+MoE)</span>
|
| 1011 |
+
</div>
|
| 1012 |
</div>
|
| 1013 |
""")
|
| 1014 |
|
| 1015 |
+
# === Tab 1: Chat ===
|
| 1016 |
+
with gr.Tab("Chat", id="chat"):
|
| 1017 |
+
|
| 1018 |
+
gr.HTML("""
|
| 1019 |
+
<p class="section-label">Polyglot Conversation</p>
|
| 1020 |
+
""")
|
| 1021 |
+
gr.HTML("""
|
| 1022 |
+
<div class="bench-notice" style="border-left-color: #2AAA5B;">
|
| 1023 |
+
<strong>🟢 Training in progress</strong> — Aetheris is currently training on RunPod
|
| 1024 |
+
(Stage 1: CKA layer alignment). The chat below runs in <strong>demo mode</strong>
|
| 1025 |
+
with scripted multilingual responses. Once training completes, live inference will
|
| 1026 |
+
activate automatically. Try chatting in different languages!
|
| 1027 |
+
</div>
|
| 1028 |
+
""")
|
| 1029 |
gr.Markdown(
|
| 1030 |
+
"Chat with Aetheris in any language — it detects your language and responds naturally. "
|
| 1031 |
+
"Try switching languages mid-conversation to see multilingual capabilities."
|
|
|
|
| 1032 |
)
|
| 1033 |
|
| 1034 |
with gr.Row():
|
| 1035 |
+
with gr.Column(scale=4):
|
| 1036 |
+
chatbot = gr.Chatbot(
|
| 1037 |
+
height=480,
|
| 1038 |
+
type="messages",
|
| 1039 |
+
label="Aetheris",
|
| 1040 |
+
show_label=True,
|
| 1041 |
+
avatar_images=(None, Path(__file__).parent / "avatar.svg"),
|
| 1042 |
+
elem_classes=["chat-container"],
|
| 1043 |
)
|
| 1044 |
with gr.Row():
|
| 1045 |
+
msg_input = gr.Textbox(
|
| 1046 |
+
placeholder="Type a message in any language...",
|
| 1047 |
+
show_label=False,
|
| 1048 |
+
scale=6,
|
| 1049 |
+
container=False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1050 |
)
|
| 1051 |
+
send_btn = gr.Button("Send", variant="primary", scale=1, min_width=80)
|
| 1052 |
|
| 1053 |
+
with gr.Accordion("Settings", open=False):
|
| 1054 |
+
with gr.Row():
|
| 1055 |
+
temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature")
|
| 1056 |
+
max_tokens = gr.Slider(32, 512, value=192, step=16, label="Max Tokens")
|
| 1057 |
+
clear_btn = gr.Button("Clear Chat", variant="secondary", size="sm")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1058 |
|
| 1059 |
+
with gr.Column(scale=2):
|
| 1060 |
+
gr.HTML('<p class="section-label">Try These</p>')
|
| 1061 |
+
gr.Markdown("Click any example to load it:")
|
| 1062 |
+
for text, lang in POLYGLOT_EXAMPLES[:8]:
|
| 1063 |
+
flag = LANGUAGES.get(lang, {}).get("flag", "")
|
| 1064 |
+
btn = gr.Button(
|
| 1065 |
+
f"{flag} {text[:50]}{'...' if len(text) > 50 else ''}",
|
| 1066 |
+
variant="secondary",
|
| 1067 |
+
size="sm",
|
| 1068 |
+
)
|
| 1069 |
+
btn.click(fn=lambda t=text: t, outputs=[msg_input])
|
| 1070 |
+
|
| 1071 |
+
# Wire up chat
|
| 1072 |
+
msg_input.submit(fn=chat_respond, inputs=[msg_input, chatbot, temperature, max_tokens],
|
| 1073 |
+
outputs=[chatbot, msg_input])
|
| 1074 |
+
send_btn.click(fn=chat_respond, inputs=[msg_input, chatbot, temperature, max_tokens],
|
| 1075 |
+
outputs=[chatbot, msg_input])
|
| 1076 |
+
clear_btn.click(fn=lambda: ([], ""), outputs=[chatbot, msg_input])
|
| 1077 |
+
|
| 1078 |
+
# === Tab 2: Languages ===
|
| 1079 |
+
with gr.Tab("Languages", id="languages"):
|
| 1080 |
+
gr.HTML('<p class="section-label">Full Language Coverage</p>')
|
| 1081 |
+
gr.Markdown(
|
| 1082 |
+
f"### {len(LANGUAGES)} Languages Across 6 Regions\n"
|
| 1083 |
+
"Aetheris inherits multilingual coverage from CohereLabs/tiny-aya-global, spanning "
|
| 1084 |
+
"13 language families, 15 scripts, and languages from every inhabited continent."
|
| 1085 |
)
|
| 1086 |
+
gr.HTML(build_lang_grid_html())
|
| 1087 |
|
| 1088 |
+
# Summary stats
|
| 1089 |
+
families = set(info["family"] for info in LANGUAGES.values())
|
| 1090 |
+
regions = set(info["region"] for info in LANGUAGES.values())
|
| 1091 |
+
gr.HTML(f"""
|
| 1092 |
+
<div class="stat-row" style="margin-top: 20px; justify-content: center;">
|
| 1093 |
+
<span class="stat-pill"><span class="num">{len(LANGUAGES)}</span> languages</span>
|
| 1094 |
+
<span class="stat-pill blue"><span class="num">{len(families)}</span> language families</span>
|
| 1095 |
+
<span class="stat-pill green"><span class="num">{len(regions)}</span> regions</span>
|
| 1096 |
+
<span class="stat-pill amber"><span class="num">15</span> scripts</span>
|
| 1097 |
+
</div>
|
| 1098 |
+
""")
|
| 1099 |
|
| 1100 |
+
# === Tab 3: Benchmarks ===
|
|
|
|
|
|
|
| 1101 |
with gr.Tab("Benchmarks", id="benchmarks"):
|
| 1102 |
|
| 1103 |
gr.HTML(f"""
|
| 1104 |
+
<div class="bench-notice">
|
| 1105 |
+
<strong>Projected values</strong> — benchmark data below uses research-validated projections.
|
| 1106 |
+
Real results will populate automatically as training completes.
|
| 1107 |
+
Last updated: <strong>{last_updated}</strong>
|
| 1108 |
</div>
|
| 1109 |
""")
|
| 1110 |
|
| 1111 |
+
# Model comparison table
|
| 1112 |
+
gr.HTML('<p class="section-label">Model Comparison</p>')
|
| 1113 |
+
gr.Markdown("### Teacher vs Student Architecture")
|
| 1114 |
model_table = make_model_info_table()
|
| 1115 |
+
gr.Dataframe(value=model_table, interactive=False, wrap=True)
|
| 1116 |
+
|
| 1117 |
+
# mGSM
|
| 1118 |
+
gr.HTML('<br><p class="section-label">Math Reasoning</p>')
|
| 1119 |
+
gr.Markdown("### mGSM — Multilingual Grade School Math")
|
| 1120 |
+
gr.Markdown("*Multi-step arithmetic reasoning across languages. 8-shot evaluation.*")
|
| 1121 |
+
gr.Plot(value=make_benchmark_chart("mgsm", "mGSM Accuracy by Language"))
|
| 1122 |
+
|
| 1123 |
+
# XCOPA
|
| 1124 |
+
gr.HTML('<p class="section-label">Causal Reasoning</p>')
|
| 1125 |
+
gr.Markdown("### XCOPA — Cross-lingual Causal Reasoning")
|
| 1126 |
+
gr.Markdown("*Choice of plausible alternatives — tests commonsense and causal reasoning.*")
|
| 1127 |
+
gr.Plot(value=make_benchmark_chart("xcopa", "XCOPA Accuracy by Language"))
|
| 1128 |
+
|
| 1129 |
+
# Retention
|
| 1130 |
+
gr.HTML('<p class="section-label">Compression Quality</p>')
|
| 1131 |
+
gr.Markdown("### Quality Retention Across Languages")
|
| 1132 |
+
gr.Markdown("*Percentage of teacher performance preserved at 4.2x compression. Target: >80%.*")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1133 |
gr.Plot(value=make_retention_chart())
|
| 1134 |
|
| 1135 |
+
# Throughput
|
| 1136 |
+
gr.HTML('<p class="section-label">Performance</p>')
|
| 1137 |
+
gr.Markdown("### Inference Throughput")
|
| 1138 |
+
gr.Markdown("*Batch size 1, sequence length 256. Student achieves ~3x speedup with 4x memory reduction.*")
|
|
|
|
|
|
|
| 1139 |
gr.Plot(value=make_throughput_chart())
|
| 1140 |
|
| 1141 |
+
# Equity
|
| 1142 |
+
gr.HTML('<p class="section-label">Multilingual Equity</p>')
|
| 1143 |
+
gr.Markdown("### Degradation Equity Score")
|
| 1144 |
gr.Markdown(
|
| 1145 |
+
"*DES measures fairness of quality loss across language families. "
|
| 1146 |
+
"0 = perfectly equitable, 1 = maximum inequity. "
|
| 1147 |
+
"Red = needs remediation.*"
|
| 1148 |
)
|
| 1149 |
gr.Plot(value=make_equity_chart())
|
| 1150 |
|
| 1151 |
equity_data = BENCHMARKS.get("equity", {})
|
| 1152 |
high_risk = equity_data.get("high_risk_languages", [])
|
| 1153 |
+
if high_risk:
|
| 1154 |
+
gr.Markdown(f"**High-risk languages** (targeted for extra training): `{'`, `'.join(high_risk)}`")
|
|
|
|
|
|
|
| 1155 |
|
| 1156 |
+
# Training progress
|
| 1157 |
+
gr.HTML('<p class="section-label">Training</p>')
|
| 1158 |
+
gr.Markdown("### Training Progress")
|
| 1159 |
+
gr.Markdown("*Key metrics tracked across the 3-stage distillation pipeline.*")
|
|
|
|
|
|
|
| 1160 |
gr.Plot(value=make_training_chart())
|
| 1161 |
|
| 1162 |
+
# === Footer ===
|
| 1163 |
+
gr.HTML(f"""
|
| 1164 |
+
<div class="footer-section">
|
| 1165 |
+
<div class="collab-title">
|
| 1166 |
+
<span class="coral">CohereLabs</span> x <span class="indigo">Wayy Research</span>
|
| 1167 |
+
</div>
|
| 1168 |
+
<div class="tagline">"People for research, research for people."</div>
|
| 1169 |
+
<div class="links">
|
| 1170 |
+
<a href="https://huggingface.co/wayyresearch/aetheris">Model</a> |
|
| 1171 |
+
<a href="https://github.com/Wayy-Research/project-aya">Training Code</a> |
|
| 1172 |
+
<a href="https://github.com/Wayy-Research/aetheris">Aetheris</a> |
|
| 1173 |
+
<a href="https://huggingface.co/CohereLabs/tiny-aya-global">Teacher Model</a>
|
| 1174 |
+
<br>Buffalo, NY | Est. 2024 | Last updated: {last_updated}
|
| 1175 |
+
</div>
|
| 1176 |
+
</div>
|
| 1177 |
+
""")
|
| 1178 |
|
| 1179 |
|
| 1180 |
if __name__ == "__main__":
|
requirements.txt
CHANGED
|
@@ -4,3 +4,5 @@ torch
|
|
| 4 |
plotly
|
| 5 |
pandas
|
| 6 |
numpy
|
|
|
|
|
|
|
|
|
| 4 |
plotly
|
| 5 |
pandas
|
| 6 |
numpy
|
| 7 |
+
huggingface_hub>=0.20
|
| 8 |
+
pyyaml
|