pentanet-124m / scripts /pentanet_analysis.py
Kyworn's picture
Upload folder using huggingface_hub
80603a0 verified
#!/usr/bin/env python3
"""
PentaNet Phase 1 — Proof of Concept (Offline)
===============================================
Compare la perte d'information entre :
- Binaire {-1, +1} (1.0 bit)
- Ternaire {-1, 0, +1} (1.58 bit) — BitNet b1.58
- Pentanaire {-2, -1, 0, +1, +2} (2.32 bit) — PentaNet
- INT4 [-7, +7] (4.0 bit)
Sur des poids simulés réalistes (distribution gaussienne d'un transformer).
"""
import torch
import numpy as np
import json
import os
torch.manual_seed(42)
np.random.seed(42)
# ============================================================
# 1. GÉNÉRATION DE POIDS RÉALISTES
# ============================================================
print("=" * 60)
print("🧬 PentaNet Phase 1 — Analyse de Quantification")
print("=" * 60)
# Simuler les poids d'un modèle 124M params (GPT-2 scale)
# Distribution : Gaussienne N(0, σ) avec σ variant selon le type de couche
total_params = 124_000_000
print(f"\n📥 Génération de poids simulés (distribution transformer réaliste)...")
print(f" Total paramètres : {total_params:,} ({total_params/1e6:.1f}M)")
# Les poids d'un transformer réel :
# - Attention : σ ≈ 0.02
# - FFN/MLP : σ ≈ 0.02-0.04
# - LayerNorm: σ ≈ 0.1 (autour de 1.0)
# - Embeddings: σ ≈ 0.02
attention_params = int(total_params * 0.33) # ~33% attention
ffn_params = int(total_params * 0.50) # ~50% FFN
layernorm_params = int(total_params * 0.01) # ~1% LayerNorm
embedding_params = total_params - attention_params - ffn_params - layernorm_params
layer_configs = {
"attention": (attention_params, 0.02),
"mlp/ffn": (ffn_params, 0.03),
"layernorm": (layernorm_params, 0.1),
"embedding": (embedding_params, 0.02),
}
all_weights_parts = []
layer_weights = {}
for name, (n, sigma) in layer_configs.items():
w = torch.randn(n) * sigma
all_weights_parts.append(w)
layer_weights[name] = w
print(f" {name:>12}: {n:>12,} params, σ={sigma}")
all_weights = torch.cat(all_weights_parts)
print(f" {'TOTAL':>12}: {len(all_weights):>12,} params")
# ============================================================
# 2. STATISTIQUES DE LA DISTRIBUTION DES POIDS
# ============================================================
print("\n📊 Distribution des poids originaux (FP32) :")
mean = all_weights.mean().item()
std = all_weights.std().item()
abs_mean = all_weights.abs().mean().item()
min_w = all_weights.min().item()
max_w = all_weights.max().item()
print(f" Moyenne : {mean:.6f}")
print(f" Écart-type : {std:.6f}")
print(f" |Moyenne| : {abs_mean:.6f}")
print(f" Min / Max : [{min_w:.4f}, {max_w:.4f}]")
# ============================================================
# 3. FONCTIONS DE QUANTIFICATION
# ============================================================
def quantize_ternary(weights, block_size=64):
"""Ternaire {-1, 0, +1} style BitNet b1.58 — absmean scaling."""
n = len(weights)
quantized = torch.zeros_like(weights)
scales = []
for i in range(0, n, block_size):
block = weights[i:i+block_size]
scale = block.abs().mean()
if scale == 0:
scales.append(0.0)
continue
normalized = block / scale
q = torch.clamp(torch.round(normalized), -1, 1)
quantized[i:i+block_size] = q * scale
scales.append(scale.item())
return quantized, scales
def quantize_pentanary(weights, block_size=64):
"""Pentanaire {-2, -1, 0, +1, +2} — PentaNet, absmean scaling ×2."""
n = len(weights)
quantized = torch.zeros_like(weights)
scales = []
for i in range(0, n, block_size):
block = weights[i:i+block_size]
# Scale : absmean / 1 pour que la majorité tombe dans [-2, +2]
scale = block.abs().mean()
if scale == 0:
scales.append(0.0)
continue
normalized = block / scale
q = torch.clamp(torch.round(normalized), -2, 2)
quantized[i:i+block_size] = q * scale
scales.append(scale.item())
return quantized, scales
def quantize_pentanary_maxscale(weights, block_size=64):
"""Pentanaire {-2, -1, 0, +1, +2} — scaling par max/2."""
n = len(weights)
quantized = torch.zeros_like(weights)
scales = []
for i in range(0, n, block_size):
block = weights[i:i+block_size]
scale = block.abs().max() / 2.0
if scale == 0:
scales.append(0.0)
continue
normalized = block / scale
q = torch.clamp(torch.round(normalized), -2, 2)
quantized[i:i+block_size] = q * scale
scales.append(scale.item())
return quantized, scales
def quantize_binary(weights, block_size=64):
"""Binaire {-1, +1} — baseline 1-bit."""
n = len(weights)
quantized = torch.zeros_like(weights)
scales = []
for i in range(0, n, block_size):
block = weights[i:i+block_size]
scale = block.abs().mean()
if scale == 0:
scales.append(0.0)
continue
q = torch.sign(block)
q[q == 0] = 1
quantized[i:i+block_size] = q * scale
scales.append(scale.item())
return quantized, scales
def quantize_int4(weights, block_size=64):
"""INT4 symétrique [-7, +7] — baseline 4-bit."""
n = len(weights)
quantized = torch.zeros_like(weights)
scales = []
for i in range(0, n, block_size):
block = weights[i:i+block_size]
scale = block.abs().max() / 7.0
if scale == 0:
scales.append(0.0)
continue
normalized = block / scale
q = torch.clamp(torch.round(normalized), -7, 7)
quantized[i:i+block_size] = q * scale
scales.append(scale.item())
return quantized, scales
# ============================================================
# 4. EXÉCUTION DES QUANTIFICATIONS
# ============================================================
print("\n⚙️ Quantification en cours...")
methods = {
"Binaire {-1,+1} (1.0 bit)": quantize_binary,
"Ternaire {-1,0,+1} (1.58 bit) — BitNet": quantize_ternary,
"PentaNet-absmean (2.32 bit)": quantize_pentanary,
"PentaNet-maxscale (2.32 bit)": quantize_pentanary_maxscale,
"INT4 [-7,+7] (4.0 bit)": quantize_int4,
}
results = {}
for name, func in methods.items():
q_weights, scales = func(all_weights.clone())
mse = ((all_weights - q_weights) ** 2).mean().item()
rmse = mse ** 0.5
cos_sim = torch.nn.functional.cosine_similarity(
all_weights.unsqueeze(0), q_weights.unsqueeze(0)
).item()
signal_power = (all_weights ** 2).mean().item()
noise_power = mse
snr_db = 10 * np.log10(signal_power / noise_power) if noise_power > 0 else float('inf')
# Relative error
rel_err = (((all_weights - q_weights).abs()) / (all_weights.abs() + 1e-10)).mean().item()
results[name] = {
"mse": mse,
"rmse": rmse,
"cosine_similarity": cos_sim,
"snr_db": snr_db,
"relative_error": rel_err,
}
print(f"\n 📐 {name}")
print(f" MSE : {mse:.10f}")
print(f" RMSE : {rmse:.10f}")
print(f" Cosine Sim : {cos_sim:.8f}")
print(f" SNR : {snr_db:.2f} dB")
print(f" Err. relative : {rel_err*100:.2f}%")
# ============================================================
# 5. COMPARAISON DIRECTE
# ============================================================
print("\n" + "=" * 60)
print("📊 TABLEAU COMPARATIF")
print("=" * 60)
print(f"\n {'Méthode':<35} {'MSE':>12} {'Cosine':>10} {'SNR(dB)':>8} {'Err%':>8}")
print(f" {'─'*35} {'─'*12} {'─'*10} {'─'*8} {'─'*8}")
for name, r in results.items():
short = name.split("(")[0].strip()
print(f" {short:<35} {r['mse']:>12.10f} {r['cosine_similarity']:>10.8f} {r['snr_db']:>8.2f} {r['relative_error']*100:>7.2f}%")
# ============================================================
# 6. PENTANET VS BITNET
# ============================================================
print("\n" + "=" * 60)
print("🔬 PENTANET vs BITNET — Analyse détaillée")
print("=" * 60)
# Utiliser le meilleur scaling pour PentaNet
penta_absmean = results["PentaNet-absmean (2.32 bit)"]
penta_maxscale = results["PentaNet-maxscale (2.32 bit)"]
ternary = results["Ternaire {-1,0,+1} (1.58 bit) — BitNet"]
int4 = results["INT4 [-7,+7] (4.0 bit)"]
# Choisir le meilleur PentaNet
if penta_absmean["mse"] < penta_maxscale["mse"]:
penta = penta_absmean
penta_best = "absmean"
else:
penta = penta_maxscale
penta_best = "maxscale"
print(f"\n Meilleur scaling PentaNet : {penta_best}")
mse_improvement = (1 - penta["mse"] / ternary["mse"]) * 100
snr_improvement = penta["snr_db"] - ternary["snr_db"]
cos_improvement = penta["cosine_similarity"] - ternary["cosine_similarity"]
print(f"\n PentaNet vs BitNet Ternaire :")
print(f" ├── MSE réduite de : {mse_improvement:+.1f}%")
print(f" ├── SNR supérieur de : {snr_improvement:+.2f} dB")
print(f" ├── Cosine Sim gain : {cos_improvement:+.10f}")
print(f" └── Erreur relative PentaNet : {penta['relative_error']*100:.2f}% (vs {ternary['relative_error']*100:.2f}% BitNet)")
print(f"\n Coût mémoire (pour 124M params) :")
print(f" ├── BitNet : ~{total_params * 1.58 / 8 / 1e6:.1f} Mo (1.58 bit/param)")
print(f" ├── PentaNet : ~{total_params * 2.66 / 8 / 1e6:.1f} Mo (2.66 bit/param) → +{((2.66/1.58)-1)*100:.0f}% mémoire")
print(f" └── INT4 : ~{total_params * 4 / 8 / 1e6:.1f} Mo (4.0 bit/param)")
print(f"\n Ratio qualité/mémoire :")
mem_penta = 2.66
mem_tern = 1.58
mem_int4 = 4.0
quality_penta = penta["snr_db"]
quality_tern = ternary["snr_db"]
quality_int4 = int4["snr_db"]
print(f" ├── BitNet : {quality_tern/mem_tern:.2f} dB/bit")
print(f" ├── PentaNet : {quality_penta/mem_penta:.2f} dB/bit")
print(f" └── INT4 : {quality_int4/mem_int4:.2f} dB/bit")
# ============================================================
# 7. ANALYSE PAR TYPE DE COUCHE
# ============================================================
print("\n" + "=" * 60)
print("🔬 ANALYSE PAR TYPE DE COUCHE")
print("=" * 60)
for cat, w in layer_weights.items():
q_tern, _ = quantize_ternary(w.clone())
q_penta, _ = quantize_pentanary(w.clone())
mse_t = ((w - q_tern) ** 2).mean().item()
mse_p = ((w - q_penta) ** 2).mean().item()
improvement = (1 - mse_p / mse_t) * 100 if mse_t > 0 else 0
cos_t = torch.nn.functional.cosine_similarity(w.unsqueeze(0), q_tern.unsqueeze(0)).item()
cos_p = torch.nn.functional.cosine_similarity(w.unsqueeze(0), q_penta.unsqueeze(0)).item()
print(f"\n {cat.upper()} ({len(w):,} params, σ={w.std():.4f})")
print(f" ├── BitNet MSE: {mse_t:.10f} cos: {cos_t:.8f}")
print(f" ├── PentaNet MSE: {mse_p:.10f} cos: {cos_p:.8f}")
print(f" └── Gain PentaNet : {improvement:+.1f}% MSE")
# ============================================================
# 8. DISTRIBUTION DES VALEURS PENTANAIRES
# ============================================================
print("\n" + "=" * 60)
print("📊 DISTRIBUTION DES VALEURS PENTANAIRES")
print("=" * 60)
block_size = 64
value_counts = {-2: 0, -1: 0, 0: 0, 1: 0, 2: 0}
for i in range(0, len(all_weights), block_size):
block = all_weights[i:i+block_size]
scale = block.abs().mean()
if scale == 0:
value_counts[0] += len(block)
continue
normalized = block / scale
q = torch.clamp(torch.round(normalized), -2, 2).int()
for v in [-2, -1, 0, 1, 2]:
value_counts[v] += (q == v).sum().item()
total_q = sum(value_counts.values())
print(f"\n Valeur | Nombre | Pourcentage | Barre")
print(f" -------|----------------|-------------|------")
for v in [-2, -1, 0, 1, 2]:
count = value_counts[v]
pct = count / total_q * 100
bar = "█" * int(pct)
label = {-2: "-2", -1: "-1", 0: " 0", 1: "+1", 2: "+2"}[v]
print(f" {label} | {count:>14,} | {pct:>9.1f}% | {bar}")
# Calcul de l'entropie effective
probs = np.array([value_counts[v] / total_q for v in [-2, -1, 0, 1, 2]])
probs = probs[probs > 0]
entropy = -np.sum(probs * np.log2(probs))
print(f"\n Entropie effective : {entropy:.3f} bits (max théorique: {np.log2(5):.3f} bits)")
print(f" Efficacité : {entropy / np.log2(5) * 100:.1f}%")
# ============================================================
# 9. VERDICT
# ============================================================
print("\n" + "=" * 60)
print("🏆 VERDICT")
print("=" * 60)
if mse_improvement > 40:
verdict = "TRÈS PROMETTEUR"
emoji = "🔥"
detail = f"PentaNet réduit la MSE de {mse_improvement:.1f}% vs BitNet. Le gain justifie le surcoût mémoire de 68%."
elif mse_improvement > 20:
verdict = "PROMETTEUR"
emoji = "✅"
detail = f"PentaNet réduit la MSE de {mse_improvement:.1f}% vs BitNet. À valider avec un entraînement natif."
elif mse_improvement > 5:
verdict = "MARGINAL"
emoji = "⚠️"
detail = f"PentaNet réduit la MSE de seulement {mse_improvement:.1f}% vs BitNet. Le surcoût mémoire de 68% est difficilement justifiable."
else:
verdict = "NON CONCLUANT"
emoji = "❌"
detail = f"Gain de {mse_improvement:.1f}%. Les 2 états supplémentaires n'apportent pas assez vs BitNet."
print(f"\n {emoji} {verdict}")
print(f" {detail}")
print(f"\n ⚠️ RAPPELS IMPORTANTS :")
print(f" 1. Ceci est une analyse PTQ sur des poids SIMULÉS (pas un vrai modèle)")
print(f" 2. BitNet b1.58 perd en PTQ mais récupère en entraînement natif")
print(f" 3. Le vrai test = entraîner un modèle nativement pentanaire")
print(f" 4. Le ratio qualité/bit est plus important que la MSE brute")
# Sauvegarder
output = {
"model": "Simulated Transformer 124M (Gaussian weights)",
"total_params": total_params,
"results": {k: v for k, v in results.items()},
"pentanet_vs_bitnet": {
"best_scaling": penta_best,
"mse_reduction_pct": mse_improvement,
"snr_gain_db": snr_improvement,
"memory_overhead_pct": ((2.66/1.58)-1)*100,
},
"pentanary_distribution": {str(k): v for k, v in value_counts.items()},
"entropy_bits": entropy,
"verdict": verdict,
}
output_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "pentanet_results.json")
with open(output_path, "w") as f:
json.dump(output, f, indent=2)
print(f"\n 💾 Résultats sauvegardés dans pentanet_results.json")
print("=" * 60)