# Cygnis-Alpha / app.py
# Hugging Face Space file (revision b84db60, "Update app.py" by Simonc-44).
import gradio as gr
import torch
import torch.nn as nn
from torch.nn import functional as F
from safetensors.torch import load_file
import json
import os
# --- 1. CONFIGURATION LOADING ---
with open("config.json", "r", encoding="utf-8") as f:
    config = json.load(f)

# Hyperparameters and vocabulary extracted from the JSON config.
n_embd = config["n_embd"]          # embedding width
n_head = config["n_head"]          # attention heads per block
n_layer = config["n_layer"]        # number of transformer blocks
block_size = config["block_size"]  # maximum context length
vocab_size = config["vocab_size"]  # character-level vocabulary size
stoi = config["stoi"]              # char -> token id
# JSON object keys are always strings; restore the int ids for decoding.
itos = {int(k): v for k, v in config["itos"].items()}
# --- 2. MODEL ARCHITECTURE (aligned with the trained checkpoint) ---
class SelfAttention(nn.Module):
    """Multi-head causal self-attention with a fused QKV projection.

    The projection is named ``qkv_proj`` so the parameter names match the
    trained checkpoint's state dict.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        # One fused linear produces Q, K and V in a single matmul.
        self.qkv_proj = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.out_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.n_head = n_head

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape (B, T, C), where C is the embedding width
               this layer was constructed with.

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = x.size()
        hs = C // self.n_head  # per-head dimension
        # BUG FIX: split by the layer's own width C, not the module-global
        # ``n_embd`` — the original silently depended on a global and broke
        # for any instantiation with a different embedding width.
        q, k, v = self.qkv_proj(x).split(C, dim=2)
        q = q.view(B, T, self.n_head, hs).transpose(1, 2)  # (B, nh, T, hs)
        k = k.view(B, T, self.n_head, hs).transpose(1, 2)
        v = v.view(B, T, self.n_head, hs).transpose(1, 2)
        # Scaled dot-product scores, shape (B, nh, T, T).
        att = (q @ k.transpose(-2, -1)) * (1.0 / hs ** 0.5)
        # Causal mask: -inf strictly above the diagonal forbids attending to
        # future positions; zeros elsewhere leave the scores untouched.
        mask = torch.triu(
            torch.full((T, T), float('-inf'), device=x.device), diagonal=1
        )
        att = F.softmax(att + mask, dim=-1)
        y = att @ v  # (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(y)
class Block(nn.Module):
    """Pre-norm transformer block: attention then a GELU MLP, each residual."""

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.sa = SelfAttention(n_embd, n_head)
        # 4x-expanded feed-forward with GELU and light dropout.
        self.ffwd = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(0.1),
        )
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        # LayerNorm is applied before each sub-layer (pre-norm), with a
        # residual connection around both.
        x = x + self.sa(self.ln1(x))
        return x + self.ffwd(self.ln2(x))
class CygnisAlpha(nn.Module):
    """Character-level GPT-style decoder with learned position embeddings.

    All sizes come from the module-level config loaded from config.json.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.position_embedding = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            *(Block(n_embd, n_head) for _ in range(n_layer))
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx):
        """Map token ids (B, T) to next-token logits (B, T, vocab_size)."""
        _, T = idx.shape
        positions = torch.arange(T, device=idx.device)
        h = self.token_embedding(idx) + self.position_embedding(positions)
        h = self.blocks(h)
        return self.lm_head(self.ln_f(h))
# --- 3. INITIALIZATION ---
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CygnisAlpha().to(device)

# Final merged checkpoint produced by training.
model_path = "alpha_cycle_8.safetensors"
if os.path.exists(model_path):
    model.load_state_dict(load_file(model_path))
    model.eval()  # inference mode: disables dropout
    print(f"✅ Modèle chargé : {model_path}")
else:
    print(f"⚠️ Erreur : {model_path} non trouvé dans le répertoire.")
# --- 4. STABILIZED GENERATION LOGIC ---
def generate_response(message, history, temperature=0.4, max_tokens=150):
    """Generate a character-level completion for ``message``.

    Args:
        message: user prompt; characters missing from the vocabulary map to
            the space token (falling back to id 0 if space is absent too).
        history: chat history supplied by Gradio (unused by the model).
        temperature: sampling temperature, clamped to >= 0.01 for stability.
        max_tokens: maximum number of characters to sample.

    Returns:
        The generated continuation only (prompt stripped), whitespace-trimmed.
    """
    # Safe encoding: unknown characters fall back to ' ', then to id 0.
    fallback = stoi.get(' ', 0)
    ids = [stoi.get(c, fallback) for c in message]
    # Robustness fix: an empty prompt would yield a (1, 0) tensor and make
    # ``logits[:, -1, :]`` raise; seed with the fallback token instead.
    if not ids:
        ids = [fallback]
    input_ids = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)
    generated = input_ids
    prompt_len = input_ids.size(1)
    for _ in range(max_tokens):
        # Crop the context to the model's maximum block size.
        cond = generated[:, -block_size:]
        with torch.no_grad():
            logits = model(cond)
        logits = logits[:, -1, :] / max(temperature, 0.01)
        # Immediate-repetition filter: heavily penalize sampling the same
        # character three times in a row.
        if generated.size(1) >= 2 and generated[0, -1] == generated[0, -2]:
            logits[0, generated[0, -1]] -= 15.0
        # Frequency penalty over the response so far (avoids infinite loops).
        response_so_far = generated[0, prompt_len:]
        if response_so_far.numel() > 0:
            char_counts = torch.bincount(response_so_far, minlength=vocab_size)
            logits[0] -= char_counts * 0.8
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated = torch.cat((generated, next_token), dim=1)
        # Stop at a full stop once a minimal response length is reached.
        if (itos.get(next_token.item(), '') == "."
                and generated.size(1) > prompt_len + 20):
            break
    # Decode only the newly generated portion.
    return "".join(itos.get(i.item(), '') for i in generated[0, prompt_len:]).strip()
# --- 5. GRADIO INTERFACE ---
_extra_inputs = [
    gr.Slider(0.1, 1.2, value=0.4, label="Température (Stable < 0.5)"),
    gr.Slider(50, 500, value=150, step=10, label="Tokens Max"),
]

demo = gr.ChatInterface(
    fn=generate_response,
    title="🌌 Cygnis Alpha v1.0",
    description="Identité scellée : Simon Chusseau. Architecture 162M.",
    # Each example supplies [message, temperature, max_tokens].
    examples=[
        ["Qui est ton créateur ?", 0.3, 100],
        ["Explique la singularité technologique.", 0.6, 200],
        ["Qui es-tu ?", 0.4, 100],
    ],
    additional_inputs=_extra_inputs,
)

if __name__ == "__main__":
    demo.launch()