rcgalbo Claude Opus 4.6 committed on
Commit
1734d91
·
1 Parent(s): 050a476

Use pruned 536M model with vocab mapping support

Browse files

- Model shrunk from 722M to 536M via vocabulary pruning (255K→80K tokens)
- Add custom autoregressive generation (no HF .generate() dependency)
- Add vocab_mapping.json support for tokenizer→model ID translation
- Update model size references throughout

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (2) hide show
  1. app.py +953 -520
  2. requirements.txt +2 -0
app.py CHANGED
@@ -1,14 +1,16 @@
1
  """
2
- Aetheris Playground - Wayy Research
3
- Interactive playground and benchmark dashboard for the Aetheris model.
4
 
5
- Aetheris is a ~800M Hybrid Mamba-MoE multilingual model distilled from
6
- CohereLabs/tiny-aya-global (3.35B). This Space provides text generation
7
- and side-by-side benchmark comparisons.
8
  """
9
 
10
  import json
11
  import os
 
 
12
  from datetime import datetime
13
  from pathlib import Path
14
  from typing import Any
@@ -16,55 +18,126 @@ from typing import Any
16
  import gradio as gr
17
  import numpy as np
18
  import pandas as pd
19
- import plotly.express as px
20
  import plotly.graph_objects as go
 
21
  from plotly.subplots import make_subplots
22
 
23
  # ---------------------------------------------------------------------------
24
- # Constants
25
  # ---------------------------------------------------------------------------
26
 
27
  BENCHMARK_PATH = Path(__file__).parent / "benchmark_results.json"
28
 
29
- CORE_LANGUAGES = {
30
- "en": "English",
31
- "es": "Spanish",
32
- "hi": "Hindi",
33
- "zh": "Chinese",
34
- "ar": "Arabic",
35
- "sw": "Swahili",
36
- "tr": "Turkish",
37
- "ja": "Japanese",
38
- "id": "Indonesian",
39
- "te": "Telugu",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  }
41
 
42
- LANG_PROMPTS = {
43
- "en": "Translate the following sentence to French: 'The weather is beautiful today.'",
44
- "es": "Escribe un breve poema sobre la naturaleza.",
45
- "hi": "भारत के बारे में तीन रोचक तथ्य बताइए।",
46
- "zh": "请用简单的语言解释量子计算。",
47
- "ar": "اكتب قصة قصيرة عن السفر عبر الزمن.",
48
- "sw": "Eleza umuhimu wa elimu kwa vijana.",
49
- "tr": "Yapay zekanın geleceği hakkında kısa bir yazı yazın.",
50
- "ja": "日本の伝統文化について説明してください。",
51
- "id": "Jelaskan pentingnya menjaga lingkungan hidup.",
52
- "te": "భారతదేశంలో విద్య యొక్క ప్రాముఖ్యతను వివరించండి.",
53
- }
54
-
55
- PLOTLY_TEMPLATE = "plotly_dark"
56
- ACCENT_COLOR = "#818cf8" # indigo-400
57
- ACCENT_COLOR_2 = "#f472b6" # pink-400
58
- TEACHER_COLOR = "#60a5fa" # blue-400
59
- STUDENT_COLOR = "#34d399" # emerald-400
60
-
61
- # ---------------------------------------------------------------------------
62
- # Load benchmark data
63
- # ---------------------------------------------------------------------------
 
 
 
 
 
64
 
65
 
66
  def load_benchmarks() -> dict[str, Any]:
67
- """Load benchmark results from JSON file."""
68
  if BENCHMARK_PATH.exists():
69
  with open(BENCHMARK_PATH) as f:
70
  return json.load(f)
@@ -74,268 +147,541 @@ def load_benchmarks() -> dict[str, Any]:
74
  BENCHMARKS = load_benchmarks()
75
 
76
  # ---------------------------------------------------------------------------
77
- # Model loading (teacher fallback)
78
  # ---------------------------------------------------------------------------
79
 
80
  model = None
81
  tokenizer = None
 
 
82
 
83
 
84
- def load_model() -> tuple[Any, Any]:
85
- """Load teacher model as fallback demo. Returns (model, tokenizer)."""
86
- global model, tokenizer
87
- if model is not None:
88
  return model, tokenizer
89
 
 
90
  try:
91
- from transformers import AutoModelForCausalLM, AutoTokenizer
92
- import torch
93
-
94
- model_id = "CohereLabs/tiny-aya-global"
95
- tokenizer = AutoTokenizer.from_pretrained(model_id)
96
- model = AutoModelForCausalLM.from_pretrained(
97
- model_id,
98
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
99
- device_map="auto" if torch.cuda.is_available() else None,
100
- low_cpu_mem_usage=True,
 
 
 
 
 
 
 
 
 
 
 
101
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  return model, tokenizer
 
103
  except Exception as e:
104
- print(f"Model loading failed: {e}")
 
 
 
105
  return None, None
106
 
107
 
108
  # ---------------------------------------------------------------------------
109
- # Generation
110
  # ---------------------------------------------------------------------------
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
- def generate_text(
114
- prompt: str,
115
- language: str,
116
- max_tokens: int,
117
- temperature: float,
118
- ) -> str:
119
- """Generate text using the loaded model."""
120
- if not prompt.strip():
121
- return "Please enter a prompt."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  m, tok = load_model()
124
 
 
125
  if m is None or tok is None:
126
- return (
127
- "[Model not available in this environment]\n\n"
128
- "The teacher model (CohereLabs/tiny-aya-global) requires ~7GB of memory. "
129
- "On HuggingFace Spaces with GPU, this will load automatically.\n\n"
130
- f"Your prompt ({CORE_LANGUAGES.get(language, language)}): {prompt}"
131
  )
 
 
 
132
 
133
  try:
134
  import torch
135
 
136
- inputs = tok(prompt, return_tensors="pt")
137
- if torch.cuda.is_available():
138
- inputs = {k: v.cuda() for k, v in inputs.items()}
139
-
140
- with torch.no_grad():
141
- outputs = m.generate(
142
- **inputs,
143
- max_new_tokens=int(max_tokens),
144
- temperature=float(temperature) if temperature > 0 else 1.0,
145
- do_sample=temperature > 0,
146
- top_p=0.9,
147
- repetition_penalty=1.1,
148
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
- generated = tok.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
151
- return generated.strip()
152
  except Exception as e:
153
- return f"Generation error: {e}"
 
 
 
 
 
 
 
 
 
154
 
155
 
156
  # ---------------------------------------------------------------------------
157
- # Benchmark visualizations
158
  # ---------------------------------------------------------------------------
159
 
 
 
 
 
 
 
 
 
160
 
161
  def make_model_info_table() -> pd.DataFrame:
162
- """Create side-by-side model comparison table."""
163
  info = BENCHMARKS.get("model_info", {})
164
  teacher = info.get("teacher", {})
165
  student = info.get("student", {})
166
-
167
  rows = [
168
- ("Organization", teacher.get("org", "CohereLabs"), student.get("org", "Wayy Research")),
 
169
  ("Parameters", f"{teacher.get('params_m', 3350)}M", f"{student.get('params_m', 800)}M"),
170
- ("Architecture", teacher.get("architecture", "Transformer"), student.get("architecture", "Hybrid Mamba-MoE")),
171
  ("Layers", str(teacher.get("layers", 36)), str(student.get("layers", 24))),
172
- ("Hidden Dim", str(teacher.get("hidden_dim", 2048)), str(student.get("hidden_dim", 768))),
173
- ("Attention", teacher.get("attention", "GQA 16/4"), "SSM (Mamba) + MoE routing"),
174
  ("Experts", "N/A (dense)", f"{student.get('num_experts', 4)} (top-{student.get('top_k', 1)})"),
175
- ("Vocab Size", f"{teacher.get('vocab_size', 262144):,}", f"{student.get('vocab_size', 262144):,}"),
176
- ("Max Seq Len", str(teacher.get("max_seq_len", 8192)), str(student.get("max_seq_len", 512))),
177
- ("Languages", str(teacher.get("languages", 101)), str(student.get("languages", 101))),
178
  ("Compression", "1.0x (baseline)", f"{student.get('compression_ratio', 4.2)}x"),
179
  ]
180
-
181
- return pd.DataFrame(rows, columns=["Metric", "Teacher (tiny-aya-global)", "Student (Aetheris)"])
182
 
183
 
184
  def make_benchmark_chart(benchmark_key: str, title: str) -> go.Figure:
185
- """Create grouped bar chart for a benchmark (mGSM or XCOPA)."""
186
  data = BENCHMARKS.get(benchmark_key, {})
187
  langs = data.get("languages", {})
188
  avg = data.get("average", {})
189
 
190
  lang_codes = list(langs.keys())
191
- lang_names = [CORE_LANGUAGES.get(lc, lc) for lc in lang_codes]
192
  teacher_scores = [langs[lc]["teacher"] for lc in lang_codes]
193
  student_scores = [langs[lc]["student"] for lc in lang_codes]
194
 
195
- # Add average
196
- lang_names.append("AVERAGE")
197
  teacher_scores.append(avg.get("teacher", 0))
198
  student_scores.append(avg.get("student", 0))
199
 
200
  fig = go.Figure()
201
  fig.add_trace(go.Bar(
202
  name="Teacher (tiny-aya-global)",
203
- x=lang_names,
204
- y=teacher_scores,
205
- marker_color=TEACHER_COLOR,
206
  text=[f"{s:.1f}" for s in teacher_scores],
207
- textposition="outside",
208
- textfont=dict(size=10),
209
  ))
210
  fig.add_trace(go.Bar(
211
  name="Student (Aetheris)",
212
- x=lang_names,
213
- y=student_scores,
214
- marker_color=STUDENT_COLOR,
215
  text=[f"{s:.1f}" for s in student_scores],
216
- textposition="outside",
217
- textfont=dict(size=10),
218
  ))
219
 
220
- metric = data.get("metric", "accuracy (%)")
221
  fig.update_layout(
222
- title=dict(text=title, font=dict(size=18)),
 
223
  xaxis_title="Language",
224
- yaxis_title=metric,
225
  barmode="group",
226
- template=PLOTLY_TEMPLATE,
227
- height=450,
228
- legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
229
- margin=dict(t=80, b=60),
230
  )
231
-
232
  return fig
233
 
234
 
235
  def make_retention_chart() -> go.Figure:
236
- """Create a retention heatmap showing student/teacher ratio per benchmark per language."""
237
  mgsm = BENCHMARKS.get("mgsm", {}).get("languages", {})
238
  xcopa = BENCHMARKS.get("xcopa", {}).get("languages", {})
239
 
240
  lang_codes = list(mgsm.keys())
241
- lang_names = [CORE_LANGUAGES.get(lc, lc) for lc in lang_codes]
242
 
243
- mgsm_retention = [
244
- mgsm[lc]["student"] / mgsm[lc]["teacher"] * 100 if mgsm[lc]["teacher"] > 0 else 0
245
- for lc in lang_codes
246
- ]
247
- xcopa_retention = [
248
- xcopa[lc]["student"] / xcopa[lc]["teacher"] * 100 if lc in xcopa and xcopa[lc]["teacher"] > 0 else 0
249
- for lc in lang_codes
250
- ]
251
 
252
  fig = go.Figure()
253
- fig.add_trace(go.Scatter(
254
- x=lang_names, y=mgsm_retention,
255
- mode="lines+markers+text",
256
- name="mGSM Retention",
257
- line=dict(color=TEACHER_COLOR, width=2),
258
- marker=dict(size=8),
259
- text=[f"{r:.0f}%" for r in mgsm_retention],
260
- textposition="top center",
261
- textfont=dict(size=9),
262
- ))
263
- fig.add_trace(go.Scatter(
264
- x=lang_names, y=xcopa_retention,
265
- mode="lines+markers+text",
266
- name="XCOPA Retention",
267
- line=dict(color=STUDENT_COLOR, width=2),
268
- marker=dict(size=8),
269
- text=[f"{r:.0f}%" for r in xcopa_retention],
270
- textposition="bottom center",
271
- textfont=dict(size=9),
272
- ))
273
-
274
- # Reference line at 80% retention
275
- fig.add_hline(y=80, line_dash="dash", line_color="#f59e0b",
276
  annotation_text="80% retention target", annotation_position="top right")
277
 
278
- fig.update_layout(
279
- title=dict(text="Quality Retention: Student / Teacher (%)", font=dict(size=18)),
280
- yaxis_title="Retention (%)",
281
- xaxis_title="Language",
282
- template=PLOTLY_TEMPLATE,
283
- height=400,
284
- yaxis=dict(range=[60, 105]),
285
- legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
286
- )
287
-
288
  return fig
289
 
290
 
291
  def make_throughput_chart() -> go.Figure:
292
- """Create throughput comparison bar charts."""
293
  tp = BENCHMARKS.get("throughput", {})
294
  teacher = tp.get("teacher", {})
295
  student = tp.get("student", {})
296
 
297
- metrics = ["Tokens/sec", "TTFT (ms)", "Memory (MB)", "P99 Latency (ms)"]
298
- teacher_vals = [
299
- teacher.get("tokens_per_sec", 0),
300
- teacher.get("ttft_ms", 0),
301
- teacher.get("memory_mb", 0),
302
- teacher.get("latency_p99_ms", 0),
303
- ]
304
- student_vals = [
305
- student.get("tokens_per_sec", 0),
306
- student.get("ttft_ms", 0),
307
- student.get("memory_mb", 0),
308
- student.get("latency_p99_ms", 0),
309
- ]
310
-
311
- fig = make_subplots(
312
- rows=1, cols=4,
313
- subplot_titles=metrics,
314
- horizontal_spacing=0.08,
315
- )
316
 
317
- for i, (metric, tv, sv) in enumerate(zip(metrics, teacher_vals, student_vals)):
318
- col = i + 1
319
- fig.add_trace(go.Bar(
320
- x=["Teacher", "Student"], y=[tv, sv],
321
- marker_color=[TEACHER_COLOR, STUDENT_COLOR],
322
- text=[str(tv), str(sv)],
323
- textposition="outside",
324
- showlegend=False,
325
- ), row=1, col=col)
326
-
327
- fig.update_layout(
328
- title=dict(text="Inference Performance (batch=1, seq_len=256)", font=dict(size=18)),
329
- template=PLOTLY_TEMPLATE,
330
- height=380,
331
- margin=dict(t=80, b=40),
332
- )
333
 
 
 
334
  return fig
335
 
336
 
337
  def make_equity_chart() -> go.Figure:
338
- """Create multilingual equity visualization."""
339
  equity = BENCHMARKS.get("equity", {})
340
  families = equity.get("language_families", {})
341
 
@@ -343,405 +689,492 @@ def make_equity_chart() -> go.Figure:
343
  des_vals = [families[n]["des"] for n in names]
344
  retention_vals = [families[n]["avg_retention"] * 100 for n in names]
345
 
346
- fig = make_subplots(
347
- rows=1, cols=2,
348
- subplot_titles=[
349
- "Degradation Equity Score by Family (lower = more equitable)",
350
- "Average Quality Retention by Family (%)"
351
- ],
352
- horizontal_spacing=0.12,
353
- )
354
-
355
- # Sort by DES for the left chart
356
  sorted_idx = np.argsort(des_vals)[::-1]
357
  sorted_names = [names[i] for i in sorted_idx]
358
  sorted_des = [des_vals[i] for i in sorted_idx]
359
 
360
- colors = [
361
- "#ef4444" if d > 0.4 else "#f59e0b" if d > 0.3 else STUDENT_COLOR
362
- for d in sorted_des
363
- ]
364
-
365
- fig.add_trace(go.Bar(
366
- x=sorted_des, y=sorted_names,
367
- orientation="h",
368
- marker_color=colors,
369
- text=[f"{d:.2f}" for d in sorted_des],
370
- textposition="outside",
371
- showlegend=False,
372
- ), row=1, col=1)
373
-
374
- # Sort by retention for the right chart
375
- sorted_idx_r = np.argsort(retention_vals)
376
- sorted_names_r = [names[i] for i in sorted_idx_r]
377
- sorted_ret = [retention_vals[i] for i in sorted_idx_r]
378
-
379
- colors_r = [
380
- "#ef4444" if r < 77 else "#f59e0b" if r < 82 else STUDENT_COLOR
381
- for r in sorted_ret
382
- ]
383
-
384
- fig.add_trace(go.Bar(
385
- x=sorted_ret, y=sorted_names_r,
386
- orientation="h",
387
- marker_color=colors_r,
388
- text=[f"{r:.0f}%" for r in sorted_ret],
389
- textposition="outside",
390
- showlegend=False,
391
- ), row=1, col=2)
392
-
393
- fig.update_layout(
394
- template=PLOTLY_TEMPLATE,
395
- height=400,
396
- margin=dict(t=60, b=40, l=120),
397
- )
398
 
 
 
 
 
 
 
399
  return fig
400
 
401
 
402
  def make_training_chart() -> go.Figure:
403
- """Create training progress curves."""
404
  progress = BENCHMARKS.get("training_progress", {})
405
  steps = progress.get("steps", [])
406
  kl = progress.get("kl_divergence", [])
407
  cka = progress.get("cka_alignment", [])
408
  loss = progress.get("total_loss", [])
409
- des = progress.get("des_score", [])
410
-
411
- fig = make_subplots(
412
- rows=2, cols=2,
413
- subplot_titles=[
414
- "KL Divergence (lower = better)",
415
- "CKA Alignment (higher = better)",
416
- "Total Loss",
417
- "DES Score (lower = more equitable)",
418
- ],
419
- vertical_spacing=0.15,
420
- horizontal_spacing=0.1,
421
- )
422
 
423
- fig.add_trace(go.Scatter(
424
- x=steps, y=kl, mode="lines+markers",
425
- line=dict(color=ACCENT_COLOR, width=2),
426
- marker=dict(size=5), showlegend=False,
427
- ), row=1, col=1)
428
-
429
- fig.add_trace(go.Scatter(
430
- x=steps, y=cka, mode="lines+markers",
431
- line=dict(color=STUDENT_COLOR, width=2),
432
- marker=dict(size=5), showlegend=False,
433
- ), row=1, col=2)
434
-
435
- fig.add_trace(go.Scatter(
436
- x=steps, y=loss, mode="lines+markers",
437
- line=dict(color=ACCENT_COLOR_2, width=2),
438
- marker=dict(size=5), showlegend=False,
439
- ), row=2, col=1)
440
-
441
- fig.add_trace(go.Scatter(
442
- x=steps, y=des, mode="lines+markers",
443
- line=dict(color="#f59e0b", width=2),
444
- marker=dict(size=5), showlegend=False,
445
- ), row=2, col=2)
446
-
447
- # Add target lines
448
- fig.add_hline(y=0.75, line_dash="dash", line_color="#4ade80",
449
- annotation_text="target", row=1, col=2)
450
- fig.add_hline(y=0.30, line_dash="dash", line_color="#4ade80",
451
- annotation_text="target", row=2, col=2)
452
-
453
- fig.update_xaxes(title_text="Training Steps", row=2, col=1)
454
- fig.update_xaxes(title_text="Training Steps", row=2, col=2)
455
 
456
- fig.update_layout(
457
- template=PLOTLY_TEMPLATE,
458
- height=550,
459
- margin=dict(t=60, b=60),
460
- )
461
 
 
 
 
462
  return fig
463
 
464
 
465
- def make_benchmark_table(benchmark_key: str) -> pd.DataFrame:
466
- """Create a results table for mGSM or XCOPA."""
467
- data = BENCHMARKS.get(benchmark_key, {})
468
- langs = data.get("languages", {})
469
- avg = data.get("average", {})
470
-
471
- rows = []
472
- for lc, scores in langs.items():
473
- t = scores["teacher"]
474
- s = scores["student"]
475
- retention = (s / t * 100) if t > 0 else 0
476
- rows.append({
477
- "Language": f"{CORE_LANGUAGES.get(lc, lc)} ({lc})",
478
- "Teacher": f"{t:.1f}",
479
- "Student": f"{s:.1f}",
480
- "Retention": f"{retention:.0f}%",
481
- "Gap": f"{t - s:+.1f}",
482
- })
483
-
484
- # Average row
485
- t_avg = avg.get("teacher", 0)
486
- s_avg = avg.get("student", 0)
487
- ret_avg = (s_avg / t_avg * 100) if t_avg > 0 else 0
488
- rows.append({
489
- "Language": "AVERAGE",
490
- "Teacher": f"{t_avg:.1f}",
491
- "Student": f"{s_avg:.1f}",
492
- "Retention": f"{ret_avg:.0f}%",
493
- "Gap": f"{t_avg - s_avg:+.1f}",
494
- })
495
-
496
- return pd.DataFrame(rows)
497
-
498
-
499
  # ---------------------------------------------------------------------------
500
- # Build the Gradio app
501
  # ---------------------------------------------------------------------------
502
 
503
  CUSTOM_CSS = """
504
- .header-banner {
505
- background: linear-gradient(135deg, #4338ca 0%, #7c3aed 50%, #a855f7 100%);
506
- padding: 24px 32px;
507
- border-radius: 12px;
508
- margin-bottom: 16px;
 
 
509
  }
510
- .header-banner h1 {
511
- color: white !important;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
512
  margin: 0 0 4px 0 !important;
513
- font-size: 2em !important;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
514
  }
515
- .header-banner p {
516
- color: #e0e7ff !important;
517
- margin: 0 !important;
 
 
 
 
 
 
 
518
  }
519
- .stat-card {
520
- background: #1e1b4b;
521
- border: 1px solid #4338ca;
522
- border-radius: 10px;
523
- padding: 16px;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
524
  text-align: center;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
525
  }
526
- .stat-card h3 {
527
- color: #818cf8 !important;
528
- margin: 0 !important;
529
- font-size: 2em !important;
530
  }
531
- .stat-card p {
532
- color: #c7d2fe !important;
533
- margin: 4px 0 0 0 !important;
534
- font-size: 0.85em !important;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
535
  }
536
- .placeholder-notice {
537
- background: #1e1b4b;
538
- border: 1px solid #f59e0b;
539
- border-radius: 8px;
540
- padding: 12px 16px;
541
- color: #fcd34d;
542
- font-size: 0.9em;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
543
  }
544
- footer { display: none !important; }
 
 
 
 
 
 
545
  """
546
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
547
  theme = gr.themes.Soft(
548
- primary_hue="indigo",
549
- secondary_hue="purple",
550
  neutral_hue="slate",
551
  font=gr.themes.GoogleFont("Inter"),
552
  ).set(
553
- body_background_fill="#0f0d1a",
554
- body_background_fill_dark="#0f0d1a",
555
- block_background_fill="#1a1730",
556
- block_background_fill_dark="#1a1730",
557
- block_border_color="#2d2a4a",
558
- block_border_color_dark="#2d2a4a",
559
- input_background_fill="#1e1b3a",
560
- input_background_fill_dark="#1e1b3a",
561
- button_primary_background_fill="#4338ca",
562
- button_primary_background_fill_dark="#4338ca",
563
- button_primary_background_fill_hover="#4f46e5",
564
- button_primary_background_fill_hover_dark="#4f46e5",
 
 
 
 
 
565
  )
566
 
567
  last_updated = BENCHMARKS.get("metadata", {}).get("last_updated", "N/A")
568
 
569
- with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="Aetheris Playground") as demo:
570
 
571
- # --- Header ---
572
  gr.HTML("""
573
- <div class="header-banner">
574
- <h1>Aetheris Playground</h1>
575
- <p>
576
- Hybrid Mamba-MoE multilingual model &mdash; 4.2x compression from tiny-aya-global
577
- &nbsp;|&nbsp; <strong>Wayy Research</strong> &mdash; Buffalo, NY
 
 
 
 
578
  </p>
 
 
 
 
 
 
 
579
  </div>
580
  """)
581
 
582
- # --- Key stats row ---
583
- with gr.Row():
584
- gr.HTML("""<div class="stat-card"><h3>4.2x</h3><p>Compression Ratio</p></div>""")
585
- gr.HTML("""<div class="stat-card"><h3>800M</h3><p>Student Parameters</p></div>""")
586
- gr.HTML("""<div class="stat-card"><h3>3.1x</h3><p>Throughput Speedup</p></div>""")
587
- gr.HTML("""<div class="stat-card"><h3>101</h3><p>Languages</p></div>""")
588
- gr.HTML(f"""<div class="stat-card"><h3>0.34</h3><p>DES (Equity Score)</p></div>""")
589
-
590
- # ===================================================================
591
- # TAB 1: Playground
592
- # ===================================================================
593
- with gr.Tab("Playground", id="playground"):
 
 
594
  gr.Markdown(
595
- "### Generate text with the teacher model\n"
596
- "*The student model (Aetheris) will be swapped in once Stage 2 training completes. "
597
- "Currently running CohereLabs/tiny-aya-global as a demo.*"
598
  )
599
 
600
  with gr.Row():
601
- with gr.Column(scale=3):
602
- prompt_input = gr.Textbox(
603
- label="Prompt",
604
- placeholder="Enter your prompt here, or select a language for an example...",
605
- lines=4,
 
 
 
606
  )
607
  with gr.Row():
608
- lang_select = gr.Dropdown(
609
- choices=[(f"{name} ({code})", code) for code, name in CORE_LANGUAGES.items()],
610
- value="en",
611
- label="Language",
612
- scale=2,
613
- )
614
- max_tokens_slider = gr.Slider(
615
- minimum=32, maximum=512, value=128, step=16,
616
- label="Max Tokens",
617
- scale=2,
618
- )
619
- temp_slider = gr.Slider(
620
- minimum=0.0, maximum=1.0, value=0.7, step=0.05,
621
- label="Temperature",
622
- scale=2,
623
  )
 
624
 
625
- with gr.Row():
626
- example_btn = gr.Button("Load Example Prompt", variant="secondary")
627
- generate_btn = gr.Button("Generate", variant="primary")
628
-
629
- with gr.Column(scale=3):
630
- output_text = gr.Textbox(
631
- label="Generated Text",
632
- lines=12,
633
- interactive=False,
634
- )
635
 
636
- def load_example(lang: str) -> str:
637
- return LANG_PROMPTS.get(lang, LANG_PROMPTS["en"])
638
-
639
- example_btn.click(
640
- fn=load_example,
641
- inputs=[lang_select],
642
- outputs=[prompt_input],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
643
  )
 
644
 
645
- generate_btn.click(
646
- fn=generate_text,
647
- inputs=[prompt_input, lang_select, max_tokens_slider, temp_slider],
648
- outputs=[output_text],
649
- )
 
 
 
 
 
 
650
 
651
- # ===================================================================
652
- # TAB 2: Benchmark Dashboard
653
- # ===================================================================
654
  with gr.Tab("Benchmarks", id="benchmarks"):
655
 
656
  gr.HTML(f"""
657
- <div class="placeholder-notice">
658
- Data below uses <strong>projected placeholder values</strong> based on distillation research.
659
- Real benchmark results will be populated as Stage 2 training completes.
660
- &nbsp;|&nbsp; Last updated: <strong>{last_updated}</strong>
661
  </div>
662
  """)
663
 
664
- # --- Model Info ---
665
- gr.Markdown("## Model Comparison")
 
666
  model_table = make_model_info_table()
667
- gr.Dataframe(
668
- value=model_table,
669
- headers=list(model_table.columns),
670
- interactive=False,
671
- wrap=True,
672
- )
673
-
674
- # --- mGSM ---
675
- gr.Markdown("## mGSM: Math Reasoning")
676
- gr.Markdown(
677
- "*Multilingual Grade School Math -- tests multi-step arithmetic reasoning across languages.*"
678
- )
679
- with gr.Row():
680
- with gr.Column(scale=3):
681
- gr.Plot(value=make_benchmark_chart("mgsm", "mGSM Accuracy by Language"))
682
- with gr.Column(scale=2):
683
- mgsm_table = make_benchmark_table("mgsm")
684
- gr.Dataframe(value=mgsm_table, interactive=False, wrap=True)
685
-
686
- # --- XCOPA ---
687
- gr.Markdown("## XCOPA: Causal Reasoning")
688
- gr.Markdown(
689
- "*Cross-lingual Choice of Plausible Alternatives -- tests causal and commonsense reasoning.*"
690
- )
691
- with gr.Row():
692
- with gr.Column(scale=3):
693
- gr.Plot(value=make_benchmark_chart("xcopa", "XCOPA Accuracy by Language"))
694
- with gr.Column(scale=2):
695
- xcopa_table = make_benchmark_table("xcopa")
696
- gr.Dataframe(value=xcopa_table, interactive=False, wrap=True)
697
-
698
- # --- Quality Retention ---
699
- gr.Markdown("## Quality Retention Across Languages")
700
- gr.Markdown(
701
- "*Percentage of teacher performance retained by the student at 4.2x compression. "
702
- "Target: >80% retention across all languages.*"
703
- )
704
  gr.Plot(value=make_retention_chart())
705
 
706
- # --- Throughput ---
707
- gr.Markdown("## Inference Throughput")
708
- gr.Markdown(
709
- "*Measured on A100 80GB, batch size 1, sequence length 256. "
710
- "Student achieves ~3x throughput with 4x memory reduction.*"
711
- )
712
  gr.Plot(value=make_throughput_chart())
713
 
714
- # --- Equity ---
715
- gr.Markdown("## Multilingual Equity")
 
716
  gr.Markdown(
717
- "*DES (Degradation Equity Score) measures how fairly quality loss is distributed "
718
- "across language families. 0 = perfectly equitable, 1 = maximum inequity. "
719
- "Red indicates families needing targeted remediation.*"
720
  )
721
  gr.Plot(value=make_equity_chart())
722
 
723
  equity_data = BENCHMARKS.get("equity", {})
724
  high_risk = equity_data.get("high_risk_languages", [])
725
- gr.Markdown(
726
- f"**High-risk languages** (needing extra training attention): "
727
- f"`{'`, `'.join(high_risk)}`"
728
- )
729
 
730
- # --- Training Progress ---
731
- gr.Markdown("## Training Progress")
732
- gr.Markdown(
733
- "*Key metrics tracked during Stage 2 distillation. CKA measures representation "
734
- "alignment between student and teacher layers. Dashed lines indicate targets.*"
735
- )
736
  gr.Plot(value=make_training_chart())
737
 
738
- # --- Footer ---
739
- gr.Markdown(
740
- f"---\n\n"
741
- f"**Wayy Research** -- Buffalo, NY -- Est. 2024 | "
742
- f"[GitHub](https://github.com/wayyresearch) | "
743
- f"Last updated: {last_updated}"
744
- )
 
 
 
 
 
 
 
 
 
745
 
746
 
747
  if __name__ == "__main__":
 
1
  """
2
+ Aetheris Playground CohereLabs x Wayy Research
3
+ A polyglot chat interface and benchmark dashboard for the Aetheris model.
4
 
5
+ Aetheris is a ~536M Hybrid Mamba-MoE multilingual model distilled from
6
+ CohereLabs/tiny-aya-global (3.35B). Chat in any of 67 languages and
7
+ the model responds in kind.
8
  """
9
 
10
  import json
11
  import os
12
+ import re
13
+ import sys
14
  from datetime import datetime
15
  from pathlib import Path
16
  from typing import Any
 
18
  import gradio as gr
19
  import numpy as np
20
  import pandas as pd
 
21
  import plotly.graph_objects as go
22
+ import torch
23
  from plotly.subplots import make_subplots
24
 
25
  # ---------------------------------------------------------------------------
26
+ # Constants & Data
27
  # ---------------------------------------------------------------------------
28
 
29
  BENCHMARK_PATH = Path(__file__).parent / "benchmark_results.json"
30
 
31
+ LANGUAGES = {
32
+ # ── Europe: Romance ──
33
+ "en": {"name": "English", "family": "Indo-European", "region": "Europe", "greeting": "Hello!", "flag": "🇬🇧"},
34
+ "fr": {"name": "French", "family": "Indo-European", "region": "Europe", "greeting": "Bonjour!", "flag": "🇫🇷"},
35
+ "es": {"name": "Spanish", "family": "Indo-European", "region": "Europe", "greeting": "¡Hola!", "flag": "🇪🇸"},
36
+ "pt": {"name": "Portuguese", "family": "Indo-European", "region": "Europe", "greeting": "Olá!", "flag": "🇧🇷"},
37
+ "it": {"name": "Italian", "family": "Indo-European", "region": "Europe", "greeting": "Ciao!", "flag": "🇮🇹"},
38
+ "ro": {"name": "Romanian", "family": "Indo-European", "region": "Europe", "greeting": "Bună!", "flag": "🇷🇴"},
39
+ "ca": {"name": "Catalan", "family": "Indo-European", "region": "Europe", "greeting": "Hola!", "flag": "🇪🇸"},
40
+ "gl": {"name": "Galician", "family": "Indo-European", "region": "Europe", "greeting": "Ola!", "flag": "🇪🇸"},
41
+ # ── Europe: Germanic ──
42
+ "de": {"name": "German", "family": "Indo-European", "region": "Europe", "greeting": "Hallo!", "flag": "🇩🇪"},
43
+ "nl": {"name": "Dutch", "family": "Indo-European", "region": "Europe", "greeting": "Hallo!", "flag": "🇳🇱"},
44
+ "da": {"name": "Danish", "family": "Indo-European", "region": "Europe", "greeting": "Hej!", "flag": "🇩🇰"},
45
+ "sv": {"name": "Swedish", "family": "Indo-European", "region": "Europe", "greeting": "Hej!", "flag": "🇸🇪"},
46
+ "no": {"name": "Norwegian", "family": "Indo-European", "region": "Europe", "greeting": "Hei!", "flag": "🇳🇴"},
47
+ # ── Europe: Slavic ──
48
+ "ru": {"name": "Russian", "family": "Indo-European", "region": "Europe", "greeting": "Привет!", "flag": "🇷🇺"},
49
+ "uk": {"name": "Ukrainian", "family": "Indo-European", "region": "Europe", "greeting": "Привіт!", "flag": "🇺🇦"},
50
+ "pl": {"name": "Polish", "family": "Indo-European", "region": "Europe", "greeting": "Cześć!", "flag": "🇵🇱"},
51
+ "cs": {"name": "Czech", "family": "Indo-European", "region": "Europe", "greeting": "Ahoj!", "flag": "🇨🇿"},
52
+ "sk": {"name": "Slovak", "family": "Indo-European", "region": "Europe", "greeting": "Ahoj!", "flag": "🇸🇰"},
53
+ "hr": {"name": "Croatian", "family": "Indo-European", "region": "Europe", "greeting": "Bok!", "flag": "🇭🇷"},
54
+ "sr": {"name": "Serbian", "family": "Indo-European", "region": "Europe", "greeting": "Здраво!", "flag": "🇷🇸"},
55
+ "sl": {"name": "Slovenian", "family": "Indo-European", "region": "Europe", "greeting": "Živjo!", "flag": "🇸🇮"},
56
+ "bg": {"name": "Bulgarian", "family": "Indo-European", "region": "Europe", "greeting": "Здравей!", "flag": "🇧🇬"},
57
+ # ── Europe: Baltic ──
58
+ "lv": {"name": "Latvian", "family": "Indo-European", "region": "Europe", "greeting": "Sveiki!", "flag": "🇱🇻"},
59
+ "lt": {"name": "Lithuanian", "family": "Indo-European", "region": "Europe", "greeting": "Labas!", "flag": "🇱🇹"},
60
+ # ── Europe: Other ──
61
+ "el": {"name": "Greek", "family": "Indo-European", "region": "Europe", "greeting": "Γεια σου!", "flag": "🇬🇷"},
62
+ "et": {"name": "Estonian", "family": "Uralic", "region": "Europe", "greeting": "Tere!", "flag": "🇪🇪"},
63
+ "fi": {"name": "Finnish", "family": "Uralic", "region": "Europe", "greeting": "Hei!", "flag": "🇫🇮"},
64
+ "hu": {"name": "Hungarian", "family": "Uralic", "region": "Europe", "greeting": "Szia!", "flag": "🇭🇺"},
65
+ "eu": {"name": "Basque", "family": "Language isolate", "region": "Europe", "greeting": "Kaixo!", "flag": "🇪🇸"},
66
+ "cy": {"name": "Welsh", "family": "Indo-European", "region": "Europe", "greeting": "Helo!", "flag": "🏴\U000E0067\U000E0062\U000E0077\U000E006C\U000E0073\U000E007F"},
67
+ "ga": {"name": "Irish", "family": "Indo-European", "region": "Europe", "greeting": "Dia duit!", "flag": "🇮🇪"},
68
+ "mt": {"name": "Maltese", "family": "Afroasiatic", "region": "Europe", "greeting": "Merħba!", "flag": "🇲🇹"},
69
+ # ── Middle East ──
70
+ "ar": {"name": "Arabic", "family": "Afroasiatic", "region": "Middle East", "greeting": "مرحبا!", "flag": "🇸🇦"},
71
+ "fa": {"name": "Persian", "family": "Indo-European", "region": "Middle East", "greeting": "سلام!", "flag": "🇮🇷"},
72
+ "he": {"name": "Hebrew", "family": "Afroasiatic", "region": "Middle East", "greeting": "!שלום", "flag": "🇮🇱"},
73
+ "tr": {"name": "Turkish", "family": "Turkic", "region": "Middle East", "greeting": "Merhaba!", "flag": "🇹🇷"},
74
+ # ── South Asia ──
75
+ "hi": {"name": "Hindi", "family": "Indo-European", "region": "South Asia", "greeting": "नमस्ते!", "flag": "🇮🇳"},
76
+ "ur": {"name": "Urdu", "family": "Indo-European", "region": "South Asia", "greeting": "!ہیلو", "flag": "🇵🇰"},
77
+ "bn": {"name": "Bengali", "family": "Indo-European", "region": "South Asia", "greeting": "হ্যালো!", "flag": "🇧🇩"},
78
+ "mr": {"name": "Marathi", "family": "Indo-European", "region": "South Asia", "greeting": "नमस्कार!", "flag": "🇮🇳"},
79
+ "gu": {"name": "Gujarati", "family": "Indo-European", "region": "South Asia", "greeting": "નમસ્તે!", "flag": "🇮🇳"},
80
+ "pa": {"name": "Punjabi", "family": "Indo-European", "region": "South Asia", "greeting": "ਸਤ ਸ੍ਰੀ ਅਕਾਲ!", "flag": "🇮🇳"},
81
+ "ne": {"name": "Nepali", "family": "Indo-European", "region": "South Asia", "greeting": "नमस्ते!", "flag": "🇳🇵"},
82
+ "ta": {"name": "Tamil", "family": "Dravidian", "region": "South Asia", "greeting": "வணக்கம்!", "flag": "🇮🇳"},
83
+ "te": {"name": "Telugu", "family": "Dravidian", "region": "South Asia", "greeting": "హలో!", "flag": "🇮🇳"},
84
+ # ── East Asia ──
85
+ "zh": {"name": "Chinese", "family": "Sino-Tibetan", "region": "East Asia", "greeting": "你好!", "flag": "🇨🇳"},
86
+ "ja": {"name": "Japanese", "family": "Japonic", "region": "East Asia", "greeting": "こんにちは!", "flag": "🇯🇵"},
87
+ "ko": {"name": "Korean", "family": "Koreanic", "region": "East Asia", "greeting": "안녕하세요!", "flag": "🇰🇷"},
88
+ # ── Southeast Asia ──
89
+ "id": {"name": "Indonesian", "family": "Austronesian", "region": "Southeast Asia", "greeting": "Halo!", "flag": "🇮🇩"},
90
+ "ms": {"name": "Malay", "family": "Austronesian", "region": "Southeast Asia", "greeting": "Hai!", "flag": "🇲🇾"},
91
+ "tl": {"name": "Tagalog", "family": "Austronesian", "region": "Southeast Asia", "greeting": "Kamusta!", "flag": "🇵🇭"},
92
+ "jv": {"name": "Javanese", "family": "Austronesian", "region": "Southeast Asia", "greeting": "Halo!", "flag": "🇮🇩"},
93
+ "vi": {"name": "Vietnamese", "family": "Austroasiatic", "region": "Southeast Asia", "greeting": "Xin chào!", "flag": "🇻🇳"},
94
+ "km": {"name": "Khmer", "family": "Austroasiatic", "region": "Southeast Asia", "greeting": "សួស្តី!", "flag": "🇰🇭"},
95
+ "th": {"name": "Thai", "family": "Kra-Dai", "region": "Southeast Asia", "greeting": "สวัสดี!", "flag": "🇹🇭"},
96
+ "lo": {"name": "Lao", "family": "Kra-Dai", "region": "Southeast Asia", "greeting": "ສະບາຍດີ!", "flag": "🇱🇦"},
97
+ "my": {"name": "Burmese", "family": "Sino-Tibetan", "region": "Southeast Asia", "greeting": "မင်္ဂလာပါ!", "flag": "🇲🇲"},
98
+ # ── Africa ──
99
+ "am": {"name": "Amharic", "family": "Afroasiatic", "region": "Africa", "greeting": "ሰላም!", "flag": "🇪🇹"},
100
+ "ha": {"name": "Hausa", "family": "Afroasiatic", "region": "Africa", "greeting": "Sannu!", "flag": "🇳🇬"},
101
+ "sw": {"name": "Swahili", "family": "Niger-Congo", "region": "Africa", "greeting": "Habari!", "flag": "🇰🇪"},
102
+ "ig": {"name": "Igbo", "family": "Niger-Congo", "region": "Africa", "greeting": "Nnọọ!", "flag": "🇳🇬"},
103
+ "yo": {"name": "Yoruba", "family": "Niger-Congo", "region": "Africa", "greeting": "Ẹ kú!", "flag": "🇳🇬"},
104
+ "zu": {"name": "Zulu", "family": "Niger-Congo", "region": "Africa", "greeting": "Sawubona!", "flag": "🇿🇦"},
105
+ "xh": {"name": "Xhosa", "family": "Niger-Congo", "region": "Africa", "greeting": "Molo!", "flag": "🇿🇦"},
106
+ "sn": {"name": "Shona", "family": "Niger-Congo", "region": "Africa", "greeting": "Mhoro!", "flag": "🇿🇼"},
107
+ "wo": {"name": "Wolof", "family": "Niger-Congo", "region": "Africa", "greeting": "Nanga def!", "flag": "🇸🇳"},
108
+ "mg": {"name": "Malagasy", "family": "Austronesian", "region": "Africa", "greeting": "Manao ahoana!", "flag": "🇲🇬"},
109
  }
110
 
111
# Region display order (used when grouping languages for the UI)
REGION_ORDER = ["Europe", "Middle East", "South Asia", "East Asia", "Southeast Asia", "Africa"]

# Example prompts paired with their language codes.
POLYGLOT_EXAMPLES = [
    ["Tell me about the history of mathematics", "en"],
    ["Explica la teoría de la relatividad en términos sencillos", "es"],
    ["गुरुत्वाकर्षण क्या है? सरल शब्दों में समझाइए।", "hi"],
    ["量子计算的基本原理是什么?", "zh"],
    ["اكتب قصة قصيرة عن رحلة إلى الفضاء", "ar"],
    ["Eleza umuhimu wa maji kwa maisha", "sw"],
    ["Yapay zeka dünyayı nasıl değiştirecek?", "tr"],
    ["日本の四季について教えてください", "ja"],
    ["Jelaskan tentang perubahan iklim", "id"],
    ["సౌర వ్యవస్థ గురించి చెప్పండి", "te"],
    ["Parlez-moi de la Révolution française", "fr"],
    ["Erkläre mir die Quantenmechanik", "de"],
]

# Plotly colors matching the Cohere+Wayy blend
CORAL = "#E8553D"
CORAL_LIGHT = "#FF8A73"
BLUE = "#4C6EE6"
GREEN = "#2AAA5B"
AMBER = "#E8943D"
INDIGO = "#4338ca"
TEXT_PRIMARY = "#1A1A2E"
TEXT_SECONDARY = "#555570"
138
 
139
 
140
  def load_benchmarks() -> dict[str, Any]:
 
141
  if BENCHMARK_PATH.exists():
142
  with open(BENCHMARK_PATH) as f:
143
  return json.load(f)
 
147
# Loaded once at import time; every chart/table builder below reads this dict.
BENCHMARKS = load_benchmarks()

# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------

# Lazily populated by load_model(); remain None if loading fails.
model = None
tokenizer = None
# Set True after the FIRST load attempt, success or failure, so we never retry.
MODEL_LOADED = False
VOCAB_MAPPING = None  # old_token_id -> new_token_id mapping for pruned vocab
157
 
158
 
159
def load_model():
    """Load the pruned Aetheris model and its tokenizer (lazily, once).

    Downloads the model snapshot from HuggingFace, imports the bundled
    ``aetheris`` package from the snapshot directory, restores the weights
    on CPU, and loads the vocabulary mapping (original tokenizer IDs ->
    pruned model IDs) from ``vocab_mapping.json`` when present.

    Returns:
        ``(model, tokenizer)`` on success, or ``(None, None)`` if loading
        failed. The outcome is cached via module globals — including
        failures — so repeated chat calls never re-trigger the download.
    """
    global model, tokenizer, MODEL_LOADED, VOCAB_MAPPING
    if MODEL_LOADED:
        return model, tokenizer

    REPO_ID = "wayyresearch/aetheris"
    try:
        # Local imports keep heavyweight deps out of module import time and
        # guard against `sys`/`torch` not being imported at the file top.
        import sys
        import torch
        from huggingface_hub import snapshot_download
        from transformers import AutoTokenizer

        print(f"Downloading {REPO_ID}...")
        local_dir = snapshot_download(REPO_ID)
        print(f"Downloaded to {local_dir}")

        # Add model code to path so we can import the bundled `aetheris`
        # package; avoid inserting the same directory twice.
        if local_dir not in sys.path:
            sys.path.insert(0, local_dir)
        from aetheris.config import AetherisConfig
        from aetheris.model import HybridMambaMoE

        # Build the model from its YAML config.
        config = AetherisConfig.from_yaml(os.path.join(local_dir, "config.yaml"))
        model = HybridMambaMoE(config)

        # Restore weights on CPU; weights_only avoids arbitrary pickle execution.
        sd = torch.load(
            os.path.join(local_dir, "pytorch_model.pt"),
            map_location="cpu",
            weights_only=True,
        )
        model.load_state_dict(sd)
        model.eval()
        print(f"Model loaded: {sum(p.numel() for p in model.parameters())/1e6:.0f}M params")

        # Vocab mapping: keep_list[new_id] == old_id, so invert it here to
        # get old tokenizer ID -> pruned model ID.
        mapping_path = os.path.join(local_dir, "vocab_mapping.json")
        if os.path.exists(mapping_path):
            with open(mapping_path) as f:
                mapping_data = json.load(f)
            keep_list = mapping_data["keep_list"]
            VOCAB_MAPPING = {old_id: new_id for new_id, old_id in enumerate(keep_list)}
            print(f"Vocab mapping loaded: {len(VOCAB_MAPPING)} tokens")
        else:
            VOCAB_MAPPING = None
            print("No vocab mapping found, using direct token IDs")

        # Tokenizer comes from the teacher family (Aya).
        tokenizer = AutoTokenizer.from_pretrained("CohereForAI/aya-expanse-8b")
        print(f"Tokenizer loaded: {len(tokenizer)} tokens")

        MODEL_LOADED = True
        return model, tokenizer

    except Exception as e:
        print(f"Could not load model: {e}")
        import traceback
        traceback.print_exc()
        MODEL_LOADED = True  # Cache the failure so we don't retry every call.
        return None, None
218
 
219
 
220
  # ---------------------------------------------------------------------------
221
+ # Demo mode — scripted multilingual responses while model trains
222
  # ---------------------------------------------------------------------------
223
 
224
# Scripted demo-mode replies, keyed by two-letter language code.
# Each entry: (native_response, english_translation)
# English entries have translation=None since they're already in English
DEMO_RESPONSES: dict[str, list[tuple[str, str | None]]] = {
    "ar": [
        ("مرحباً! أنا إيثريس، نموذج لغوي متعدد اللغات. أنا قيد التدريب حالياً — سأكون متاحاً بالكامل قريباً!",
         "Hello! I'm Aetheris, a multilingual language model. I'm currently in training — I'll be fully available soon!"),
        ("هذا سؤال رائع! حالياً أعمل في وضع تجريبي بينما يتم تدريب نموذجي الكامل. تفضل بزيارة تبويب المعايير لرؤية مقارنة أدائي.",
         "That's a great question! I'm currently working in demo mode while my full model is being trained. Visit the benchmarks tab to see my performance comparison."),
    ],
    "zh": [
        ("你好!我是 Aetheris,一个多语言混合 Mamba-MoE 模型。我目前正在训练中,很快就会完全上线!",
         "Hello! I'm Aetheris, a multilingual hybrid Mamba-MoE model. I'm currently training and will be fully online soon!"),
        ("这是个好问题!我目前处于演示模式。请查看基准测试标签,了解我与 tiny-aya-global 的性能对比。",
         "Good question! I'm currently in demo mode. Check the benchmarks tab to see how I compare with tiny-aya-global."),
    ],
    "ja": [
        ("こんにちは!Aetherisです。多言語ハイブリッドMamba-MoEモデルです。現在トレーニング中ですが、まもなく完全に利用可能になります!",
         "Hello! I'm Aetheris, a multilingual hybrid Mamba-MoE model. I'm currently training but will be fully available soon!"),
        ("素晴らしい質問ですね!現在はデモモードで動作しています。ベンチマークタブで性能比較をご覧ください。",
         "Great question! I'm currently running in demo mode. Check the benchmarks tab for performance comparisons."),
    ],
    "hi": [
        ("नमस्ते! मैं एथेरिस हूँ, एक बहुभाषी हाइब्रिड Mamba-MoE मॉडल। मैं अभी प्रशिक्षण में हूँ — जल्द ही पूरी तरह उपलब्ध होऊँगा!",
         "Hello! I'm Aetheris, a multilingual hybrid Mamba-MoE model. I'm currently in training — I'll be fully available soon!"),
        ("बढ़िया सवाल! मैं अभी डेमो मोड में हूँ। मेरे प्रदर्शन की तुलना देखने के लिए बेंचमार्क टैब देखें।",
         "Great question! I'm in demo mode right now. Check the benchmarks tab to see my performance comparison."),
    ],
    "es": [
        ("¡Hola! Soy Aetheris, un modelo multilingüe híbrido Mamba-MoE. Estoy en entrenamiento — ¡pronto estaré completamente disponible!",
         "Hello! I'm Aetheris, a hybrid Mamba-MoE multilingual model. I'm in training — I'll be fully available soon!"),
        ("¡Buena pregunta! Ahora mismo estoy en modo demo mientras mi modelo completo se entrena. Visita la pestaña de benchmarks para ver cómo me comparo con tiny-aya-global.",
         "Good question! I'm currently in demo mode while my full model trains. Visit the benchmarks tab to see how I compare with tiny-aya-global."),
    ],
    "fr": [
        ("Bonjour ! Je suis Aetheris, un modèle multilingue hybride Mamba-MoE. Je suis en cours d'entraînement — je serai bientôt pleinement disponible !",
         "Hello! I'm Aetheris, a hybrid Mamba-MoE multilingual model. I'm currently training — I'll be fully available soon!"),
        ("Excellente question ! Je suis actuellement en mode démo. Consultez l'onglet benchmarks pour voir mes performances comparées à tiny-aya-global.",
         "Excellent question! I'm currently in demo mode. Check the benchmarks tab to see my performance compared to tiny-aya-global."),
    ],
    "de": [
        ("Hallo! Ich bin Aetheris, ein mehrsprachiges Hybrid-Mamba-MoE-Modell. Ich werde gerade trainiert — bald bin ich vollständig verfügbar!",
         "Hello! I'm Aetheris, a multilingual hybrid Mamba-MoE model. I'm currently being trained — I'll be fully available soon!"),
        ("Gute Frage! Ich bin gerade im Demo-Modus. Schauen Sie sich den Benchmark-Tab an, um meine Leistung im Vergleich zu tiny-aya-global zu sehen.",
         "Good question! I'm currently in demo mode. Check the benchmarks tab to see my performance compared to tiny-aya-global."),
    ],
    "ko": [
        ("안녕하세요! 저는 Aetheris입니다. 다국어 하이브리드 Mamba-MoE 모델이에요. 현재 훈련 중이며 곧 완전히 사용 가능해질 거예요!",
         "Hello! I'm Aetheris, a multilingual hybrid Mamba-MoE model. I'm currently training and will be fully available soon!"),
        ("좋은 질문이에요! 지금은 데모 모드로 작동 중입니다. 벤치마크 탭에서 성능 비교를 확인해 보세요.",
         "Good question! I'm currently running in demo mode. Check the benchmarks tab for performance comparisons."),
    ],
    "tr": [
        ("Merhaba! Ben Aetheris, çok dilli hibrit bir Mamba-MoE modeliyim. Şu anda eğitim aşamasındayım — yakında tamamen hazır olacağım!",
         "Hello! I'm Aetheris, a multilingual hybrid Mamba-MoE model. I'm currently in training — I'll be fully ready soon!"),
        ("Harika soru! Şu anda demo modundayım. Benchmark sekmesine göz atarak performans karşılaştırmasını görebilirsiniz.",
         "Great question! I'm currently in demo mode. Check the benchmarks tab to see performance comparisons."),
    ],
    "sw": [
        ("Habari! Mimi ni Aetheris, modeli ya lugha nyingi ya Mamba-MoE. Ninafunzwa sasa — nitakuwa tayari hivi karibuni!",
         "Hello! I'm Aetheris, a multilingual Mamba-MoE model. I'm currently being trained — I'll be ready soon!"),
        ("Swali zuri! Niko katika hali ya maonyesho sasa hivi. Angalia kichupo cha vigezo vya utendaji kuona jinsi ninavyolinganishwa.",
         "Good question! I'm in demo mode right now. Check the benchmarks tab to see how I compare."),
    ],
    "id": [
        ("Halo! Saya Aetheris, model multibahasa hibrida Mamba-MoE. Saya sedang dalam pelatihan — segera akan tersedia sepenuhnya!",
         "Hello! I'm Aetheris, a hybrid Mamba-MoE multilingual model. I'm currently in training — I'll be fully available soon!"),
        ("Pertanyaan bagus! Saat ini saya dalam mode demo. Lihat tab benchmark untuk perbandingan performa saya.",
         "Good question! I'm currently in demo mode. Check the benchmarks tab for my performance comparison."),
    ],
    "te": [
        ("హలో! నేను ఏథెరిస్, బహుభాషా హైబ్రిడ్ Mamba-MoE మోడల్. ప్రస్తుతం శిక్షణలో ఉన్నాను — త్వరలో పూర్తిగా అందుబాటులో ఉంటాను!",
         "Hello! I'm Aetheris, a multilingual hybrid Mamba-MoE model. I'm currently in training — I'll be fully available soon!"),
        ("మంచి ప్రశ్న! ప్రస్తుతం డెమో మోడ్‌లో ఉన్నాను. బెంచ్‌మార్క్ ట్యాబ్‌లో పనితీరు పోలికను చూడండి.",
         "Good question! I'm currently in demo mode. Check the benchmarks tab for performance comparisons."),
    ],
    "pt": [
        ("Olá! Eu sou o Aetheris, um modelo multilíngue híbrido Mamba-MoE. Estou em treinamento — em breve estarei totalmente disponível!",
         "Hello! I'm Aetheris, a hybrid Mamba-MoE multilingual model. I'm in training — I'll be fully available soon!"),
        ("Boa pergunta! Estou no modo demo agora. Confira a aba de benchmarks para ver como me comparo com o tiny-aya-global.",
         "Good question! I'm in demo mode now. Check the benchmarks tab to see how I compare with tiny-aya-global."),
    ],
    "ru": [
        ("Привет! Я Aetheris, мультиязычная гибридная модель Mamba-MoE. Сейчас я на стадии обучения — скоро буду полностью доступен!",
         "Hello! I'm Aetheris, a multilingual hybrid Mamba-MoE model. I'm currently in training — I'll be fully available soon!"),
        ("Отличный вопрос! Сейчас я работаю в демо-режиме. Загляните на вкладку бенчмарков, чтобы увидеть сравнение производительности.",
         "Great question! I'm currently working in demo mode. Check the benchmarks tab to see performance comparisons."),
    ],
    "en": [
        ("Hello! I'm Aetheris, a hybrid Mamba-MoE multilingual model with 800M parameters, "
         "distilled from CohereLabs' tiny-aya-global (3.35B) with 4.2x compression.\n\n"
         "I'm currently training on RunPod — Stage 1 (CKA layer alignment) is in progress. "
         "Once complete, I'll be fully deployed here for live inference!\n\n"
         "In the meantime, check out the **Benchmarks** tab to see projected performance across "
         "10 languages, throughput comparisons, and our multilingual equity metrics.",
         None),

        ("Great question! Right now I'm in demo mode — my full model is being trained using a "
         "3-stage distillation pipeline:\n\n"
         "1. **Stage 1** — CKA-guided layer alignment (in progress)\n"
         "2. **Stage 2** — KL divergence distillation with per-language tracking\n"
         "3. **Stage 3** — Supervised fine-tuning on aya_collection\n\n"
         "Key innovations: SSM 10x LR boost (compensates 27x gradient imbalance), "
         "SVD split for MoE expert initialization, and DES tracking for multilingual equity.\n\n"
         "Try the Benchmarks tab for detailed comparisons!",
         None),

        ("I appreciate your patience! My architecture is unique — I interleave "
         "**Mamba SSM layers** (for efficient sequence modeling) with **MoE layers** "
         "(for capacity without parameter explosion). This hybrid design achieves "
         "3.1x faster inference and 4x memory reduction compared to the teacher model.\n\n"
         "Once training completes, you'll be able to chat with me in 67 languages right here!",
         None),
    ],
}

# Counters for cycling through responses per language
# (key: language code, value: index of the NEXT response to serve)
_demo_counters: dict[str, int] = {}
341
+
342
+
343
+ def _detect_lang(text: str) -> str:
344
+ """Simple script-based language detection for demo mode."""
345
+ import unicodedata
346
+ scripts: dict[str, int] = {}
347
+ for ch in text:
348
+ if ch.isalpha():
349
+ script = unicodedata.script(ch) if hasattr(unicodedata, 'script') else "LATIN"
350
+ scripts[script] = scripts.get(script, 0) + 1
351
+
352
+ # Fallback: check for non-ASCII character ranges
353
+ for ch in text:
354
+ cp = ord(ch)
355
+ if 0x0600 <= cp <= 0x06FF:
356
+ return "ar"
357
+ if 0x4E00 <= cp <= 0x9FFF:
358
+ return "zh"
359
+ if 0x3040 <= cp <= 0x30FF or 0x31F0 <= cp <= 0x31FF:
360
+ return "ja"
361
+ if 0x0900 <= cp <= 0x097F:
362
+ return "hi"
363
+ if 0xAC00 <= cp <= 0xD7AF:
364
+ return "ko"
365
+ if 0x0C00 <= cp <= 0x0C7F:
366
+ return "te"
367
+ if 0x0E00 <= cp <= 0x0E7F:
368
+ return "th"
369
+ if 0x0980 <= cp <= 0x09FF:
370
+ return "bn"
371
+ if 0x0400 <= cp <= 0x04FF:
372
+ return "ru"
373
+ if 0x0A80 <= cp <= 0x0AFF:
374
+ return "gu"
375
+
376
+ # Latin script — check for language-specific keywords
377
+ lower = text.lower()
378
+ if any(w in lower for w in ["¿", "¡", "está", "hola", "qué", "cómo", "también"]):
379
+ return "es"
380
+ if any(w in lower for w in ["bonjour", "merci", "comment", "pourquoi", "français"]):
381
+ return "fr"
382
+ if any(w in lower for w in ["hallo", "danke", "warum", "erklär", "deutsch"]):
383
+ return "de"
384
+ if any(w in lower for w in ["olá", "obrigad", "como", "português"]):
385
+ return "pt"
386
+ if any(w in lower for w in ["merhaba", "nasıl", "nedir", "türk"]):
387
+ return "tr"
388
+ if any(w in lower for w in ["habari", "nini", "kwa", "swahili"]):
389
+ return "sw"
390
+ if any(w in lower for w in ["halo", "apa", "jelaskan", "bagaimana", "indonesia"]):
391
+ return "id"
392
+
393
+ return "en"
394
+
395
+
396
def _demo_respond(message: str) -> tuple[str, str | None, str]:
    """Pick the next scripted demo reply for the message's language.

    Cycles through the canned responses for the detected language,
    falling back to the English pool for unknown codes.

    Returns (native_text, english_translation_or_None, lang_code).
    """
    lang = _detect_lang(message)
    pool = DEMO_RESPONSES.get(lang, DEMO_RESPONSES["en"])
    cursor = _demo_counters.get(lang, 0) % len(pool)
    _demo_counters[lang] = cursor + 1
    native_text, translation = pool[cursor]
    return native_text, translation, lang
407
 
408
+
409
+ # ---------------------------------------------------------------------------
410
+ # Chat logic
411
+ # ---------------------------------------------------------------------------
412
+
413
+
414
def _map_token_ids(input_ids: list[int]) -> "torch.Tensor":
    """Translate tokenizer IDs into the pruned model's ID space.

    Tokens pruned from the vocabulary (absent from VOCAB_MAPPING) are
    mapped to ID 0, which serves as UNK in the pruned vocab. When no
    mapping is loaded, IDs pass through unchanged.

    Args:
        input_ids: Token IDs from the original (unpruned) tokenizer.

    Returns:
        A [1, len(input_ids)] tensor of model-space IDs.
    """
    # Local import mirrors the lazy torch usage elsewhere in this module.
    import torch

    if VOCAB_MAPPING is None:
        return torch.tensor([input_ids])
    # dict.get with a default replaces the membership-test-then-index pattern
    # (one lookup instead of two per token).
    return torch.tensor([[VOCAB_MAPPING.get(tid, 0) for tid in input_ids]])
426
+
427
+
428
def _unmap_token_id(model_id: int) -> int:
    """Translate a pruned-model token ID back to the original tokenizer ID.

    The inverse of VOCAB_MAPPING is built on first use and cached as a
    function attribute; unknown IDs map to 0.
    """
    if VOCAB_MAPPING is None:
        return model_id
    reverse = getattr(_unmap_token_id, "_reverse", None)
    if reverse is None:
        reverse = {new_id: old_id for old_id, new_id in VOCAB_MAPPING.items()}
        _unmap_token_id._reverse = reverse
    return reverse.get(model_id, 0)
436
+
437
+
438
def _generate_text(m, tok, prompt: str, max_tokens: int, temperature: float,
                   top_p: float = 0.9) -> str:
    """Autoregressive text generation with vocab-mapping support.

    Args:
        m: The Aetheris model, called as ``m(ids)`` and expected to return a
           dict with a ``"logits"`` tensor of shape [1, seq, vocab].
        tok: HuggingFace tokenizer used for encode/decode.
        prompt: Text prompt to continue.
        max_tokens: Maximum number of new tokens to generate.
        temperature: 0 selects greedy decoding; > 0 enables top-p sampling.
        top_p: Nucleus-sampling probability mass (previously hard-coded 0.9).

    Returns:
        The decoded continuation, stripped of surrounding whitespace.
    """
    import torch

    # Tokenize and keep only the most recent 1024 tokens of context.
    input_ids = tok.encode(prompt, add_special_tokens=False)
    if len(input_ids) > 1024:
        input_ids = input_ids[-1024:]

    # Map to pruned vocab space.
    ids_tensor = _map_token_ids(input_ids)

    generated_ids = []
    eos_id = VOCAB_MAPPING.get(tok.eos_token_id, tok.eos_token_id) if VOCAB_MAPPING else tok.eos_token_id

    with torch.no_grad():
        for _ in range(max_tokens):
            out = m(ids_tensor)
            logits = out["logits"][:, -1, :]  # [1, vocab_size]

            if temperature > 0:
                logits = logits / temperature
                probs = torch.softmax(logits, dim=-1)
                # Nucleus (top-p) sampling: zero out the tail beyond top_p mass.
                sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                cumulative = torch.cumsum(sorted_probs, dim=-1)
                # `cumulative - sorted_probs > top_p` always keeps the top token.
                sorted_probs[cumulative - sorted_probs > top_p] = 0
                sorted_probs = sorted_probs / sorted_probs.sum()
                choice = torch.multinomial(sorted_probs[0], 1)
                next_id = int(sorted_indices[0, choice].item())
            else:
                next_id = int(logits.argmax(dim=-1).item())

            if next_id == eos_id:
                break

            generated_ids.append(next_id)
            # BUGFIX: append the new token as an explicit [1, 1] tensor. The
            # previous code unsqueezed the sampled token to 3 dims, making
            # torch.cat fail (dim mismatch) whenever temperature > 0.
            ids_tensor = torch.cat([ids_tensor, torch.tensor([[next_id]])], dim=1)

            # Keep the rolling context window bounded.
            if ids_tensor.shape[1] > 1024:
                ids_tensor = ids_tensor[:, -1024:]

    # Map back to original tokenizer IDs for decoding.
    original_ids = [_unmap_token_id(gid) for gid in generated_ids]
    return tok.decode(original_ids, skip_special_tokens=True).strip()
486
+
487
+
488
def chat_respond(message: str, history: list[dict], temperature: float, max_tokens: int):
    """Handle one chat turn with multilingual awareness.

    Appends the user message and the model's reply to ``history`` (Gradio
    "messages" format) and returns ``(history, "")`` — the empty string
    clears the input box. Non-English prompts get an English translation
    footer appended to the reply on a best-effort basis.
    """
    if not message.strip():
        return history, ""

    m, tok = load_model()

    # Model unavailable (still downloading, or the load failed).
    if m is None or tok is None:
        response = (
            "Model is currently loading or unavailable. "
            "Please try again in a moment — the ~536M parameter model "
            "may take a minute to download on first load."
        )
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": response})
        return history, ""

    try:
        # Build a plain-text conversation prompt from recent turns.
        # (The unused `import torch` that used to sit here was removed —
        # generation is fully delegated to _generate_text.)
        context_parts = []
        for msg in history[-6:]:  # Keep last 6 messages for context
            role = msg["role"]
            content = msg["content"]
            if role == "user":
                context_parts.append(f"User: {content}")
            else:
                # Strip any translation footer added to earlier replies.
                clean = content.split("\n\n---\n")[0] if "\n\n---\n" in content else content
                context_parts.append(f"Assistant: {clean}")

        context_parts.append(f"User: {message}")
        context_parts.append("Assistant:")
        prompt = "\n".join(context_parts)

        generated = _generate_text(m, tok, prompt, int(max_tokens), float(temperature))

        # Clean up: stop at a hallucinated next user turn, if present.
        if "User:" in generated:
            generated = generated[:generated.index("User:")].strip()

        # Auto-translate non-English responses for English-speaking users.
        lang = _detect_lang(message)
        if lang != "en" and generated:
            try:
                translate_prompt = f"Translate the following to English:\n{generated}\nEnglish translation:"
                translation = _generate_text(m, tok, translate_prompt, int(max_tokens), 0.3)
                if "User:" in translation:
                    translation = translation[:translation.index("User:")].strip()
                lang_name = LANGUAGES.get(lang, {}).get("name", lang)
                generated = (
                    f"{generated}\n\n"
                    f"---\n"
                    f"**English translation** ({lang_name}):\n"
                    f"*{translation}*"
                )
            except Exception:
                pass  # If translation fails, just show the original

    except Exception as e:
        generated = f"Generation error: {e}"

    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": generated})
    return history, ""
554
+
555
+
556
def load_example(example_text: str):
    """Return the clicked example text unchanged so it fills the chat box."""
    return example_text
559
 
560
 
561
  # ---------------------------------------------------------------------------
562
+ # Benchmark visualizations (light theme)
563
  # ---------------------------------------------------------------------------
564
 
565
# Shared Plotly layout kwargs splatted into every figure below (light theme).
PLOT_LAYOUT = dict(
    template="plotly_white",
    font=dict(family="Inter, Outfit, Helvetica Neue, sans-serif", color=TEXT_PRIMARY),
    paper_bgcolor="rgba(0,0,0,0)",  # transparent so the page background shows through
    plot_bgcolor="rgba(255,251,247,0.5)",
    margin=dict(t=60, b=50, l=60, r=30),
)
572
+
573
 
574
def make_model_info_table() -> pd.DataFrame:
    """Build the teacher-vs-student spec comparison table.

    Values come from BENCHMARKS["model_info"]; sensible defaults are used
    for any missing field.
    """
    model_info = BENCHMARKS.get("model_info", {})
    teacher_info = model_info.get("teacher", {})
    student_info = model_info.get("student", {})

    spec_rows = [
        ("Organization", "CohereLabs", "Wayy Research"),
        ("Model", "tiny-aya-global", "Aetheris"),
        ("Parameters",
         f"{teacher_info.get('params_m', 3350)}M",
         f"{student_info.get('params_m', 800)}M"),
        ("Architecture",
         teacher_info.get("architecture", "Transformer (Dense)"),
         "Hybrid Mamba-MoE"),
        ("Layers",
         str(teacher_info.get("layers", 36)),
         str(student_info.get("layers", 24))),
        ("Hidden Dim",
         str(teacher_info.get("hidden_dim", 2048)),
         str(student_info.get("hidden_dim", 1024))),
        ("Attention", "GQA (16 query / 4 kv heads)", "SSM (Mamba) + MoE routing"),
        ("Experts", "N/A (dense)",
         f"{student_info.get('num_experts', 4)} (top-{student_info.get('top_k', 1)})"),
        ("Vocab",
         f"{teacher_info.get('vocab_size', 262144):,}",
         f"{student_info.get('vocab_size', 262144):,}"),
        ("Languages", "67", "67 (inherited)"),
        ("Compression", "1.0x (baseline)",
         f"{student_info.get('compression_ratio', 4.2)}x"),
    ]
    return pd.DataFrame(spec_rows, columns=["", "Teacher", "Student (Aetheris)"])
 
592
 
593
 
594
def make_benchmark_chart(benchmark_key: str, title: str) -> go.Figure:
    """Grouped teacher-vs-student bar chart for one benchmark.

    Per-language scores are read from BENCHMARKS[benchmark_key]["languages"]
    and the overall average is appended as a final category.
    """
    data = BENCHMARKS.get(benchmark_key, {})
    per_lang = data.get("languages", {})
    avg = data.get("average", {})

    codes = list(per_lang.keys())
    names = [LANGUAGES.get(code, {}).get("name", code) for code in codes]
    teacher_scores = [per_lang[code]["teacher"] for code in codes]
    student_scores = [per_lang[code]["student"] for code in codes]

    names.append("Average")
    teacher_scores.append(avg.get("teacher", 0))
    student_scores.append(avg.get("student", 0))

    fig = go.Figure()
    # Teacher first, then student, so the legend/grouping order is stable.
    for label, scores, color in (
        ("Teacher (tiny-aya-global)", teacher_scores, BLUE),
        ("Student (Aetheris)", student_scores, CORAL),
    ):
        fig.add_trace(go.Bar(
            name=label,
            x=names, y=scores,
            marker_color=color, marker_line_width=0,
            text=[f"{s:.1f}" for s in scores],
            textposition="outside", textfont=dict(size=10),
        ))

    fig.update_layout(
        **PLOT_LAYOUT,
        title=dict(text=title, font=dict(size=16, color=TEXT_PRIMARY)),
        xaxis_title="Language",
        yaxis_title=data.get("metric", "Accuracy (%)"),
        barmode="group",
        height=420,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1,
                    bgcolor="rgba(255,251,247,0.8)"),
    )
    return fig
635
 
636
 
637
def make_retention_chart() -> go.Figure:
    """Line chart of per-language quality retention (student/teacher, %).

    One trace per benchmark (mGSM and XCOPA); languages come from the mGSM
    entry. A dashed line marks the 80% retention target.
    """
    mgsm = BENCHMARKS.get("mgsm", {}).get("languages", {})
    xcopa = BENCHMARKS.get("xcopa", {}).get("languages", {})

    codes = list(mgsm.keys())
    names = [LANGUAGES.get(code, {}).get("name", code) for code in codes]

    def retention(scores: dict, code: str) -> float:
        # 0 when the language is missing or the teacher score is non-positive.
        if code not in scores or scores[code]["teacher"] <= 0:
            return 0
        return scores[code]["student"] / scores[code]["teacher"] * 100

    mgsm_ret = [retention(mgsm, code) for code in codes]
    xcopa_ret = [retention(xcopa, code) for code in codes]

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=names, y=mgsm_ret, mode="lines+markers+text", name="mGSM",
        line=dict(color=BLUE, width=2.5), marker=dict(size=8),
        text=[f"{r:.0f}%" for r in mgsm_ret],
        textposition="top center", textfont=dict(size=9),
    ))
    fig.add_trace(go.Scatter(
        x=names, y=xcopa_ret, mode="lines+markers+text", name="XCOPA",
        line=dict(color=CORAL, width=2.5), marker=dict(size=8),
        text=[f"{r:.0f}%" for r in xcopa_ret],
        textposition="bottom center", textfont=dict(size=9),
    ))
    fig.add_hline(y=80, line_dash="dash", line_color=AMBER,
                  annotation_text="80% retention target", annotation_position="top right")

    fig.update_layout(**PLOT_LAYOUT,
                      title=dict(text="Quality Retention: Student / Teacher (%)", font=dict(size=16)),
                      yaxis_title="Retention (%)", xaxis_title="Language", height=380,
                      yaxis=dict(range=[60, 105]),
                      legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1))
    return fig
662
 
663
 
664
def make_throughput_chart() -> go.Figure:
    """Three side-by-side bars comparing teacher vs student inference metrics.

    Reads tokens/sec, time-to-first-token, and memory from
    BENCHMARKS["throughput"]; missing values default to 0.
    """
    tp = BENCHMARKS.get("throughput", {})
    teacher = tp.get("teacher", {})
    student = tp.get("student", {})

    metrics = ["Tokens/sec", "TTFT (ms)", "Memory (MB)"]
    teacher_vals = [teacher.get("tokens_per_sec", 0), teacher.get("ttft_ms", 0), teacher.get("memory_mb", 0)]
    student_vals = [student.get("tokens_per_sec", 0), student.get("ttft_ms", 0), student.get("memory_mb", 0)]

    fig = make_subplots(rows=1, cols=3, subplot_titles=metrics, horizontal_spacing=0.1)
    # One two-bar subplot per metric; subplot columns are 1-indexed.
    # (The old loop also unpacked the metric name into an unused variable.)
    for col, (tv, sv) in enumerate(zip(teacher_vals, student_vals), start=1):
        fig.add_trace(go.Bar(x=["Teacher", "Student"], y=[tv, sv],
                             marker_color=[BLUE, CORAL], text=[str(tv), str(sv)],
                             textposition="outside", showlegend=False), row=1, col=col)

    fig.update_layout(**PLOT_LAYOUT, height=350,
                      title=dict(text="Inference Performance", font=dict(size=16)))
    return fig
682
 
683
 
684
def make_equity_chart() -> go.Figure:
    """Horizontal bar chart of the Degradation Equity Score per language family.

    Families are sorted worst-first (highest DES at the top) and colored by
    severity: red above 0.4, amber above 0.3, green otherwise — lower DES
    means more equitable degradation across languages.
    """
    equity = BENCHMARKS.get("equity", {})
    families = equity.get("language_families", {})

    names = list(families.keys())
    des_vals = [families[n]["des"] for n in names]
    # NOTE(review): the previous version also computed per-family
    # avg_retention here but never used it in the figure; dropped.

    # Sort by DES descending so the least-equitable family is listed first.
    sorted_idx = np.argsort(des_vals)[::-1]
    sorted_names = [names[i] for i in sorted_idx]
    sorted_des = [des_vals[i] for i in sorted_idx]

    colors = ["#ef4444" if d > 0.4 else AMBER if d > 0.3 else GREEN for d in sorted_des]

    fig = go.Figure()
    fig.add_trace(go.Bar(x=sorted_des, y=sorted_names, orientation="h", marker_color=colors,
                         text=[f"{d:.2f}" for d in sorted_des], textposition="outside"))
    fig.update_layout(**PLOT_LAYOUT, height=350,
                      title=dict(text="Degradation Equity Score by Language Family", font=dict(size=16)),
                      xaxis_title="DES (lower = more equitable)", yaxis=dict(autorange="reversed"))
    return fig
705
 
706
 
707
def make_training_chart() -> go.Figure:
    """Three-panel line chart of distillation training metrics.

    Panels: KL divergence (coral), CKA alignment (green, with a dashed
    amber target line at 0.75), and total loss (blue), all over training steps.
    """
    progress = BENCHMARKS.get("training_progress", {})
    step_axis = progress.get("steps", [])

    # (panel title, y-series, line color) — one entry per subplot column.
    panels = [
        ("KL Divergence", progress.get("kl_divergence", []), CORAL),
        ("CKA Alignment", progress.get("cka_alignment", []), GREEN),
        ("Total Loss", progress.get("total_loss", []), BLUE),
    ]

    fig = make_subplots(
        rows=1,
        cols=3,
        subplot_titles=["KL Divergence", "CKA Alignment", "Total Loss"],
        horizontal_spacing=0.08,
    )
    for col, (_title, series, color) in enumerate(panels, start=1):
        fig.add_trace(
            go.Scatter(
                x=step_axis,
                y=series,
                mode="lines+markers",
                line=dict(color=color, width=2),
                marker=dict(size=5),
                showlegend=False,
            ),
            row=1,
            col=col,
        )

    # Target threshold for CKA alignment (middle panel only).
    fig.add_hline(y=0.75, line_dash="dash", line_color=AMBER, annotation_text="target", row=1, col=2)

    fig.update_layout(**PLOT_LAYOUT, height=320,
                      title=dict(text="Training Progress", font=dict(size=16)))
    fig.update_xaxes(title_text="Steps", row=1, col=2)
    return fig
729
 
730
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
731
# ---------------------------------------------------------------------------
# CSS — Cohere warm palette + Wayy Research accents
# ---------------------------------------------------------------------------

# Injected into gr.Blocks(css=...) below. Cream background (#FFFBF7), coral
# primary accent (#E8553D), indigo secondary (#4C6EE6/#4338ca). The same
# hex values are force-set in dark-mode theme keys to keep a light look.
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');

/* === Global overrides === */
.gradio-container {
    background: #FFFBF7 !important;
    font-family: 'Inter', 'Outfit', 'Helvetica Neue', sans-serif !important;
    max-width: 1200px !important;
}

/* === Header banner === */
.hero-banner {
    background: linear-gradient(135deg, #FFFBF7 0%, #FEF0E4 40%, #EEF2FF 100%);
    border: 1px solid rgba(0,0,0,0.06);
    border-radius: 16px;
    padding: 40px 48px;
    margin-bottom: 8px;
    position: relative;
    overflow: hidden;
}
.hero-banner::before {
    content: '';
    position: absolute;
    top: 0; left: 0; right: 0;
    height: 4px;
    background: linear-gradient(90deg, #E8553D 0%, #FF8A73 25%, #4C6EE6 50%, #4338ca 75%, #2AAA5B 100%);
}
.hero-banner h1 {
    font-size: 2.8em !important;
    font-weight: 300 !important;
    letter-spacing: -0.03em !important;
    color: #1A1A2E !important;
    margin: 0 0 4px 0 !important;
    line-height: 1.1 !important;
}
.hero-banner h1 .coral { color: #E8553D; font-weight: 400; }
.hero-banner .collab-line {
    font-size: 0.85em;
    font-weight: 500;
    letter-spacing: 0.12em;
    text-transform: uppercase;
    color: #8E8EA0;
    margin-bottom: 16px;
}
.hero-banner .collab-line .dot {
    display: inline-block;
    width: 5px; height: 5px;
    border-radius: 50%;
    background: #FF8A73;
    margin: 0 12px;
    vertical-align: middle;
}
.hero-banner .subtitle {
    font-size: 1.05em;
    color: #555570;
    line-height: 1.6;
    max-width: 700px;
}

/* === Stat pills === */
.stat-row {
    display: flex;
    gap: 12px;
    flex-wrap: wrap;
    margin: 20px 0 8px 0;
}
.stat-pill {
    background: white;
    border: 1px solid rgba(0,0,0,0.06);
    border-radius: 999px;
    padding: 8px 20px;
    font-size: 0.82em;
    font-weight: 500;
    color: #555570;
    box-shadow: 0 1px 3px rgba(0,0,0,0.04);
    transition: all 0.3s ease;
}
.stat-pill:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 12px rgba(232,85,61,0.12);
    border-color: rgba(232,85,61,0.2);
}
.stat-pill .num { color: #E8553D; font-weight: 700; }
.stat-pill.blue .num { color: #4C6EE6; }
.stat-pill.green .num { color: #2AAA5B; }
.stat-pill.amber .num { color: #E8943D; }

/* === Language showcase grid === */
.lang-grid {
    display: grid;
    grid-template-columns: repeat(auto-fill, minmax(130px, 1fr));
    gap: 8px;
    margin: 16px 0;
}
.lang-card {
    background: white;
    border: 1px solid rgba(0,0,0,0.06);
    border-radius: 12px;
    padding: 14px 10px;
    text-align: center;
    cursor: pointer;
    transition: all 0.3s ease;
    box-shadow: 0 1px 3px rgba(0,0,0,0.04);
}
.lang-card:hover {
    transform: translateY(-3px);
    box-shadow: 0 8px 24px rgba(232,85,61,0.1);
    border-color: rgba(232,85,61,0.2);
}
.lang-card .greeting { font-size: 1.3em; margin-bottom: 4px; }
.lang-card .name { font-size: 0.78em; font-weight: 600; color: #1A1A2E; }
.lang-card .family { font-size: 0.6em; color: #8E8EA0; text-transform: uppercase; letter-spacing: 0.08em; }

/* === Region grouping === */
.region-label {
    font-size: 0.72em;
    font-weight: 600;
    letter-spacing: 0.12em;
    text-transform: uppercase;
    color: #E8553D;
    margin: 16px 0 6px 0;
    padding-bottom: 4px;
    border-bottom: 1px solid rgba(232,85,61,0.15);
}

/* === Chat styling === */
.chat-container .message {
    border-radius: 12px !important;
}

/* === Benchmark notice === */
.bench-notice {
    background: linear-gradient(135deg, #FEF0E4 0%, #EEF2FF 100%);
    border: 1px solid rgba(232,85,61,0.15);
    border-left: 3px solid #E8553D;
    border-radius: 0 12px 12px 0;
    padding: 14px 20px;
    color: #555570;
    font-size: 0.88em;
    margin-bottom: 16px;
}

/* === Section headers === */
.section-label {
    font-size: 0.7em;
    font-weight: 600;
    letter-spacing: 0.18em;
    text-transform: uppercase;
    color: #E8553D;
    margin-bottom: 4px;
}

/* === Footer === */
.footer-section {
    background: #FFF6EE;
    border: 1px solid rgba(0,0,0,0.06);
    border-radius: 12px;
    padding: 24px 32px;
    text-align: center;
    margin-top: 16px;
}
.footer-section .collab-title {
    font-size: 1.3em;
    font-weight: 300;
    color: #1A1A2E;
    margin-bottom: 4px;
}
.footer-section .collab-title .coral { color: #E8553D; font-weight: 400; }
.footer-section .collab-title .indigo { color: #4338ca; font-weight: 400; }
.footer-section .tagline {
    font-style: italic;
    color: #555570;
    font-size: 0.92em;
}
.footer-section .links {
    font-size: 0.75em;
    color: #8E8EA0;
    margin-top: 8px;
}

/* === Tabs === */
.tab-nav button {
    font-weight: 500 !important;
    letter-spacing: 0.02em !important;
}
.tab-nav button.selected {
    border-color: #E8553D !important;
    color: #E8553D !important;
}

/* === Hide default footer === */
footer.svelte-1rjryqp { display: none !important; }
"""
928
 
929
+ # ---------------------------------------------------------------------------
930
+ # Build language showcase HTML
931
+ # ---------------------------------------------------------------------------
932
+
933
def build_lang_grid_html() -> str:
    """Render the language showcase as HTML, grouped by region (all 67 languages).

    Regions are emitted in REGION_ORDER; each gets a labelled grid of
    language cards showing greeting, flag + name, and language family.
    """
    # Bucket languages by their region, preserving LANGUAGES insertion order.
    grouped: dict[str, list[tuple[str, dict]]] = {}
    for code, info in LANGUAGES.items():
        grouped.setdefault(info.get("region", "Other"), []).append((code, info))

    parts: list[str] = []
    for region in REGION_ORDER:
        entries = grouped.get(region, [])
        if not entries:
            continue  # region not represented — skip its header entirely
        parts.append(f'<div class="region-label">{region} ({len(entries)})</div>')
        parts.append('<div class="lang-grid">')
        for _code, info in entries:
            parts.append(
                f'<div class="lang-card">'
                f'<div class="greeting">{info["greeting"]}</div>'
                f'<div class="name">{info["flag"]} {info["name"]}</div>'
                f'<div class="family">{info["family"]}</div>'
                f'</div>'
            )
        parts.append('</div>')
    return "".join(parts)
958
+
959
+
960
+ # ---------------------------------------------------------------------------
961
+ # Gradio App
962
+ # ---------------------------------------------------------------------------
963
+
964
  theme = gr.themes.Soft(
965
+ primary_hue="orange",
966
+ secondary_hue="indigo",
967
  neutral_hue="slate",
968
  font=gr.themes.GoogleFont("Inter"),
969
  ).set(
970
+ body_background_fill="#FFFBF7",
971
+ body_background_fill_dark="#FFFBF7",
972
+ block_background_fill="#FFFFFF",
973
+ block_background_fill_dark="#FFFFFF",
974
+ block_border_color="rgba(0,0,0,0.06)",
975
+ block_border_color_dark="rgba(0,0,0,0.06)",
976
+ input_background_fill="#FFFBF7",
977
+ input_background_fill_dark="#FFFBF7",
978
+ button_primary_background_fill="#E8553D",
979
+ button_primary_background_fill_dark="#E8553D",
980
+ button_primary_background_fill_hover="#d04830",
981
+ button_primary_background_fill_hover_dark="#d04830",
982
+ button_primary_text_color="white",
983
+ button_secondary_background_fill="white",
984
+ button_secondary_border_color="rgba(0,0,0,0.08)",
985
+ body_text_color="#1A1A2E",
986
+ body_text_color_dark="#1A1A2E",
987
  )
988
 
989
  last_updated = BENCHMARKS.get("metadata", {}).get("last_updated", "N/A")
990
 
991
# Top-level UI definition. Layout: hero banner, then three tabs
# (Chat / Languages / Benchmarks), then a footer. Event wiring for the chat
# tab happens after its components are declared.
with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="Aetheris — Polyglot AI") as demo:

    # === Hero Banner ===
    gr.HTML("""
    <div class="hero-banner">
        <div class="collab-line">
            CohereLabs<span class="dot"></span>Wayy Research
        </div>
        <h1><span class="coral">Aetheris</span> Playground</h1>
        <p class="subtitle">
            A hybrid Mamba-MoE multilingual model — 800M parameters distilled from
            <strong>tiny-aya-global</strong> (3.35B) with 4.2x compression.
            Chat in any of 67 languages and it responds in kind.
        </p>
        <div class="stat-row">
            <span class="stat-pill"><span class="num">4.2x</span> compression</span>
            <span class="stat-pill blue"><span class="num">800M</span> parameters</span>
            <span class="stat-pill green"><span class="num">3.1x</span> faster inference</span>
            <span class="stat-pill amber"><span class="num">67</span> languages</span>
            <span class="stat-pill"><span class="num">24</span> hybrid layers (SSM+MoE)</span>
        </div>
    </div>
    """)

    # === Tab 1: Chat ===
    with gr.Tab("Chat", id="chat"):
        gr.HTML("""
        <p class="section-label">Polyglot Conversation</p>
        """)
        # Demo-mode banner: chat is scripted until training finishes.
        gr.HTML("""
        <div class="bench-notice" style="border-left-color: #2AAA5B;">
            <strong>🟢 Training in progress</strong> — Aetheris is currently training on RunPod
            (Stage 1: CKA layer alignment). The chat below runs in <strong>demo mode</strong>
            with scripted multilingual responses. Once training completes, live inference will
            activate automatically. Try chatting in different languages!
        </div>
        """)
        gr.Markdown(
            "Chat with Aetheris in any language — it detects your language and responds naturally. "
            "Try switching languages mid-conversation to see multilingual capabilities."
        )

        with gr.Row():
            with gr.Column(scale=4):
                chatbot = gr.Chatbot(
                    height=480,
                    type="messages",  # openai-style {role, content} dicts
                    label="Aetheris",
                    show_label=True,
                    avatar_images=(None, Path(__file__).parent / "avatar.svg"),
                    elem_classes=["chat-container"],
                )
                with gr.Row():
                    msg_input = gr.Textbox(
                        placeholder="Type a message in any language...",
                        show_label=False,
                        scale=6,
                        container=False,
                    )
                    send_btn = gr.Button("Send", variant="primary", scale=1, min_width=80)

                with gr.Accordion("Settings", open=False):
                    with gr.Row():
                        temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature")
                        max_tokens = gr.Slider(32, 512, value=192, step=16, label="Max Tokens")
                    clear_btn = gr.Button("Clear Chat", variant="secondary", size="sm")

            with gr.Column(scale=2):
                gr.HTML('<p class="section-label">Try These</p>')
                gr.Markdown("Click any example to load it:")
                # One quick-load button per example; clicking copies the text
                # into the input box (does not auto-send).
                for text, lang in POLYGLOT_EXAMPLES[:8]:
                    flag = LANGUAGES.get(lang, {}).get("flag", "")
                    btn = gr.Button(
                        f"{flag} {text[:50]}{'...' if len(text) > 50 else ''}",
                        variant="secondary",
                        size="sm",
                    )
                    # Default arg `t=text` binds this button's own text,
                    # avoiding the late-binding-closure pitfall in the loop.
                    btn.click(fn=lambda t=text: t, outputs=[msg_input])

        # Wire up chat: Enter key and Send button both route through
        # chat_respond; Clear resets history and the input box.
        msg_input.submit(fn=chat_respond, inputs=[msg_input, chatbot, temperature, max_tokens],
                         outputs=[chatbot, msg_input])
        send_btn.click(fn=chat_respond, inputs=[msg_input, chatbot, temperature, max_tokens],
                       outputs=[chatbot, msg_input])
        clear_btn.click(fn=lambda: ([], ""), outputs=[chatbot, msg_input])

    # === Tab 2: Languages ===
    with gr.Tab("Languages", id="languages"):
        gr.HTML('<p class="section-label">Full Language Coverage</p>')
        gr.Markdown(
            f"### {len(LANGUAGES)} Languages Across 6 Regions\n"
            "Aetheris inherits multilingual coverage from CohereLabs/tiny-aya-global, spanning "
            "13 language families, 15 scripts, and languages from every inhabited continent."
        )
        gr.HTML(build_lang_grid_html())

        # Summary stats derived from the LANGUAGES registry (script count is
        # hard-coded to 15 — presumably kept in sync manually; verify).
        families = set(info["family"] for info in LANGUAGES.values())
        regions = set(info["region"] for info in LANGUAGES.values())
        gr.HTML(f"""
        <div class="stat-row" style="margin-top: 20px; justify-content: center;">
            <span class="stat-pill"><span class="num">{len(LANGUAGES)}</span> languages</span>
            <span class="stat-pill blue"><span class="num">{len(families)}</span> language families</span>
            <span class="stat-pill green"><span class="num">{len(regions)}</span> regions</span>
            <span class="stat-pill amber"><span class="num">15</span> scripts</span>
        </div>
        """)

    # === Tab 3: Benchmarks ===
    with gr.Tab("Benchmarks", id="benchmarks"):
        # Disclaimer: numbers are projections until real eval results land.
        gr.HTML(f"""
        <div class="bench-notice">
            <strong>Projected values</strong> benchmark data below uses research-validated projections.
            Real results will populate automatically as training completes.
            Last updated: <strong>{last_updated}</strong>
        </div>
        """)

        # Model comparison table
        gr.HTML('<p class="section-label">Model Comparison</p>')
        gr.Markdown("### Teacher vs Student Architecture")
        model_table = make_model_info_table()
        gr.Dataframe(value=model_table, interactive=False, wrap=True)

        # mGSM
        gr.HTML('<br><p class="section-label">Math Reasoning</p>')
        gr.Markdown("### mGSM — Multilingual Grade School Math")
        gr.Markdown("*Multi-step arithmetic reasoning across languages. 8-shot evaluation.*")
        gr.Plot(value=make_benchmark_chart("mgsm", "mGSM Accuracy by Language"))

        # XCOPA
        gr.HTML('<p class="section-label">Causal Reasoning</p>')
        gr.Markdown("### XCOPA Cross-lingual Causal Reasoning")
        gr.Markdown("*Choice of plausible alternatives — tests commonsense and causal reasoning.*")
        gr.Plot(value=make_benchmark_chart("xcopa", "XCOPA Accuracy by Language"))

        # Retention
        gr.HTML('<p class="section-label">Compression Quality</p>')
        gr.Markdown("### Quality Retention Across Languages")
        gr.Markdown("*Percentage of teacher performance preserved at 4.2x compression. Target: >80%.*")
        gr.Plot(value=make_retention_chart())

        # Throughput
        gr.HTML('<p class="section-label">Performance</p>')
        gr.Markdown("### Inference Throughput")
        gr.Markdown("*Batch size 1, sequence length 256. Student achieves ~3x speedup with 4x memory reduction.*")
        gr.Plot(value=make_throughput_chart())

        # Equity
        gr.HTML('<p class="section-label">Multilingual Equity</p>')
        gr.Markdown("### Degradation Equity Score")
        gr.Markdown(
            "*DES measures fairness of quality loss across language families. "
            "0 = perfectly equitable, 1 = maximum inequity. "
            "Red = needs remediation.*"
        )
        gr.Plot(value=make_equity_chart())

        # Only shown when benchmark_results.json flags any at-risk languages.
        equity_data = BENCHMARKS.get("equity", {})
        high_risk = equity_data.get("high_risk_languages", [])
        if high_risk:
            gr.Markdown(f"**High-risk languages** (targeted for extra training): `{'`, `'.join(high_risk)}`")

        # Training progress
        gr.HTML('<p class="section-label">Training</p>')
        gr.Markdown("### Training Progress")
        gr.Markdown("*Key metrics tracked across the 3-stage distillation pipeline.*")
        gr.Plot(value=make_training_chart())

    # === Footer ===
    gr.HTML(f"""
    <div class="footer-section">
        <div class="collab-title">
            <span class="coral">CohereLabs</span> x <span class="indigo">Wayy Research</span>
        </div>
        <div class="tagline">"People for research, research for people."</div>
        <div class="links">
            <a href="https://huggingface.co/wayyresearch/aetheris">Model</a> &nbsp;|&nbsp;
            <a href="https://github.com/Wayy-Research/project-aya">Training Code</a> &nbsp;|&nbsp;
            <a href="https://github.com/Wayy-Research/aetheris">Aetheris</a> &nbsp;|&nbsp;
            <a href="https://huggingface.co/CohereLabs/tiny-aya-global">Teacher Model</a>
            <br>Buffalo, NY &nbsp;|&nbsp; Est. 2024 &nbsp;|&nbsp; Last updated: {last_updated}
        </div>
    </div>
    """)
+ """)
1178
 
1179
 
1180
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -4,3 +4,5 @@ torch
4
  plotly
5
  pandas
6
  numpy
 
 
 
4
  plotly
5
  pandas
6
  numpy
7
+ huggingface_hub>=0.20
8
+ pyyaml