import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from typing import List
from dataclasses import dataclass, field
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from huggingface_hub import hf_hub_download
print("="*50)
print("SCLM Option B - Chargement...")
print("="*50)
# ============================================================
# CONFIGURATION
# ============================================================
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN secret not found!")
# INSTRUCT MODEL (conversation)
BASE_MODEL = "meta-llama/Llama-3.2-3B-Instruct"
SCLM_REPO = "amewebstudio/sclm-modelEarcp-optionB"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
# ============================================================
# SCLM CLASSES
# ============================================================
@dataclass
class SCLMConfigB:
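    """SCLM Option B hyperparameters; defaults match Llama-3.2-3B."""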
    vocab_size: int = 128256
    hidden_size: int = 3072
    num_hidden_layers: int = 28
    num_attention_heads: int = 24
    latent_state_dim: int = 384
    n_experts: int = 3
    n_coherence_heads: int = 8
    expert_intermediate: int = 1536
    state_injection_layers: List[int] = field(default_factory=lambda: [4, 8, 12, 16, 20, 24])
    alpha_inject: float = 0.03
    use_state_in_attention: bool = True
    use_state_in_ffn: bool = True
    state_fusion_method: str = "concat"
class StateFFNInjector(nn.Module):
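    """Gated, zero-initialized residual that injects the latent state into the hidden stream."""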
    def __init__(self, hidden_size: int, state_dim: int, intermediate_size: int):
        super().__init__()
        self.state_proj = nn.Linear(state_dim, intermediate_size)
        self.output_proj = nn.Linear(intermediate_size, hidden_size)
        self.gate = nn.Linear(hidden_size, 1)
        # Zero-init the output projection so the injection starts near zero
        nn.init.zeros_(self.output_proj.weight)

    def forward(self, hidden, state, alpha=0.03):
        state_proj = F.silu(self.state_proj(state))
        state_output = self.output_proj(state_proj)
        gate = torch.sigmoid(self.gate(hidden.mean(dim=1, keepdim=True)))
        return hidden + alpha * gate * state_output.unsqueeze(1)
class EncapsulationB(nn.Module):
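    """Updates the latent state from attention-pooled hidden states with GRU-style gating."""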
    def __init__(self, hidden_size: int, state_dim: int):
        super().__init__()
        self.state_dim = state_dim
        self.n_pool_heads = 4
        self.pool_proj = nn.Linear(hidden_size, state_dim * self.n_pool_heads)
        self.pool_combine = nn.Linear(state_dim * self.n_pool_heads, state_dim)
        self.update_gate = nn.Linear(state_dim * 2, state_dim)
        self.reset_gate = nn.Linear(state_dim * 2, state_dim)
        self.candidate = nn.Linear(state_dim * 2, state_dim)
        self.attn_query = nn.Linear(state_dim, hidden_size)

    def forward(self, hidden, state, attention_mask=None, edit_mode=False):
        if edit_mode:
            return state, {}
        B, T, H = hidden.shape
        # State-conditioned attention pooling over the sequence
        query = self.attn_query(state)
        attn_scores = torch.bmm(hidden, query.unsqueeze(-1)).squeeze(-1)
        if attention_mask is not None:
            attn_scores = attn_scores.masked_fill(attention_mask == 0, float("-inf"))
        attn_weights = F.softmax(attn_scores, dim=-1)
        h_pooled = torch.bmm(attn_weights.unsqueeze(1), hidden).squeeze(1)
        h_proj = F.silu(self.pool_proj(h_pooled))
        h_proj = self.pool_combine(h_proj)
        # GRU-style gated update of the state
        combined = torch.cat([h_proj, state], dim=-1)
        z = torch.sigmoid(self.update_gate(combined))
        r = torch.sigmoid(self.reset_gate(combined))
        h_cand = torch.tanh(self.candidate(torch.cat([h_proj, r * state], dim=-1)))
        new_state = (1 - z) * state + z * h_cand
        # Soft clamp keeps the state bounded in (-10, 10)
        new_state = torch.tanh(new_state / 10.0) * 10.0
        return new_state, {}
class CoherenceExpertsB(nn.Module):
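    """Small mixture-of-experts FFN with a soft router over mean-pooled hidden states."""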
    def __init__(self, hidden_size: int, intermediate_size: int, n_experts: int = 3):
        super().__init__()
        self.n_experts = n_experts
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, intermediate_size),
                nn.SiLU(),
                nn.Dropout(0.1),
                nn.Linear(intermediate_size, hidden_size)
            ) for _ in range(n_experts)
        ])
        self.router = nn.Sequential(
            nn.Linear(hidden_size, 128),
            nn.SiLU(),
            nn.Linear(128, n_experts)
        )
        # Zero-init each expert's output layer so the experts start as near no-ops
        for exp in self.experts:
            nn.init.zeros_(exp[-1].weight)

    def forward(self, hidden, temperature=1.0):
        # Route on the mean-pooled sequence representation
        router_logits = self.router(hidden.mean(dim=1)) / temperature
        weights = F.softmax(router_logits, dim=-1)
        expert_outputs = torch.stack([exp(hidden) for exp in self.experts], dim=0)
        w = weights.T.unsqueeze(-1).unsqueeze(-1)  # (n_experts, B, 1, 1)
        output = (w * expert_outputs).sum(dim=0)
        return output, {}
class EARCPModuleB(nn.Module):
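    """Groups the per-layer state injectors, the state encapsulation update, and the coherence experts."""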
    def __init__(self, config):
        super().__init__()
        self.config = config
        H = config.hidden_size
        S = config.latent_state_dim
        self.ffn_injectors = nn.ModuleDict({
            str(i): StateFFNInjector(H, S, config.expert_intermediate)
            for i in config.state_injection_layers
        })
        self.encapsulation = EncapsulationB(H, S)
        self.coherence = CoherenceExpertsB(H, config.expert_intermediate, config.n_experts)

    def update_state(self, hidden, state, attention_mask=None, edit_mode=False):
        new_state, _ = self.encapsulation(hidden, state, attention_mask, edit_mode)
        hidden, _ = self.coherence(hidden)
        return new_state, hidden, {}
class SCLMModelOptionB(nn.Module):
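    """Wraps the (quantized) base model and maintains a persistent latent state across forward calls."""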
    def __init__(self, config, base_model):
        super().__init__()
        self.config = config
        self.base_model = base_model
        self.model_device = next(base_model.parameters()).device
        self.model_dtype = next(base_model.parameters()).dtype
        self.earcp = EARCPModuleB(config).to(self.model_device).to(self.model_dtype)
        self.latent_state = torch.zeros(1, config.latent_state_dim, device=self.model_device, dtype=self.model_dtype)
        self.state_frozen = False
        self.edit_mode = False

    def reset_state(self):
        self.latent_state = torch.zeros(1, self.config.latent_state_dim, device=self.model_device, dtype=self.model_dtype)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        base_out = self.base_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, **kwargs)
        hidden = base_out.hidden_states[-1]
        B = hidden.size(0)
        # With device_map="auto" the last hidden state may live on a different device
        if next(self.earcp.encapsulation.parameters()).device != hidden.device:
            self.earcp = self.earcp.to(hidden.device)
        state = self.latent_state.to(hidden.device, hidden.dtype).expand(B, -1)
        # `enhanced` is computed but not fed back; this demo only tracks the state
        new_state, enhanced, _ = self.earcp.update_state(hidden, state, attention_mask, self.edit_mode)
        if not self.state_frozen:
            self.latent_state = new_state.mean(dim=0, keepdim=True).detach()
        return {"logits": base_out.logits, "state": self.latent_state.clone()}
# ============================================================
# LOADING
# ============================================================
print("1. Chargement Llama-3.2-3B-Instruct...")
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL, quantization_config=quant_config, device_map="auto", token=HF_TOKEN
)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=HF_TOKEN)
# Fix token IDs
if isinstance(tokenizer.eos_token_id, list):
    tokenizer.eos_token_id = tokenizer.eos_token_id[0]
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
if isinstance(base_model.config.eos_token_id, list):
    base_model.config.eos_token_id = base_model.config.eos_token_id[0]
base_model.config.pad_token_id = base_model.config.eos_token_id
print("2. Creation SCLM...")
config = SCLMConfigB(
    vocab_size=base_model.config.vocab_size,
    hidden_size=base_model.config.hidden_size,
    num_hidden_layers=base_model.config.num_hidden_layers,
    num_attention_heads=base_model.config.num_attention_heads,
)
sclm_model = SCLMModelOptionB(config, base_model)
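# Usage sketch: calling sclm_model(input_ids) returns {"logits", "state"} and, unless
# state_frozen is set, updates sclm_model.latent_state in place for the next turn.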
print("3. Chargement poids EARCP...")
try:
    weights_path = hf_hub_download(repo_id=SCLM_REPO, filename="earcp_weights.pt", token=HF_TOKEN)
    earcp_weights = torch.load(weights_path, map_location="cpu")
    sclm_model.earcp.load_state_dict(earcp_weights, strict=False)
    print("SCLM loaded!")
    USE_SCLM = True
except Exception as e:
    print(f"EARCP not loaded: {e}")
    USE_SCLM = False
# ============================================================
# CHAT
# ============================================================
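# Rolling history of {"role", "content"} messages, trimmed to the last 10 entries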
conversation_history = []
def chat(message, history, temperature, max_tokens):
    global conversation_history
    if not message.strip():
        return ""
    if temperature is None:
        temperature = 0.7
    if max_tokens is None:
        max_tokens = 150
    conversation_history.append({"role": "user", "content": message})
    # Llama 3.2 Instruct chat format
    prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful AI assistant.<|eot_id|>"
    for msg in conversation_history:
        role = msg["role"]
        prompt += f"<|start_header_id|>{role}<|end_header_id|>\n\n{msg['content']}<|eot_id|>"
    prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    inputs = tokenizer(prompt, return_tensors="pt").to(base_model.device)
    if USE_SCLM:
        # Update the latent state on the full prompt before generating
        with torch.no_grad():
            _ = sclm_model(inputs.input_ids, attention_mask=inputs.attention_mask)
    eos_id = tokenizer.eos_token_id
    if isinstance(eos_id, list):
        eos_id = eos_id[0]
    with torch.no_grad():
        outputs = base_model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=int(max_tokens),
            temperature=float(temperature),
            do_sample=True,
            top_p=0.9,
            repetition_penalty=1.15,
            pad_token_id=eos_id,
            eos_token_id=eos_id,
        )
    # Decode only the newly generated tokens, skipping the echoed prompt
    gen_tokens = outputs[0][inputs.input_ids.shape[1]:]
    response = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
    if not response:
        response = "..."
    conversation_history.append({"role": "assistant", "content": response})
    # Bound the prompt length by keeping only the last 10 messages
    if len(conversation_history) > 10:
        conversation_history = conversation_history[-10:]
    return response
def reset():
    global conversation_history
    conversation_history = []
    if USE_SCLM:
        sclm_model.reset_state()
    return [], "Reset!"
# ============================================================
# INTERFACE
# ============================================================
with gr.Blocks(title="SCLM Chat") as demo:
gr.Markdown("# SCLM - Stateful Coherent Language Model\n**sclm-modelEarcp-optionB** | Llama-3.2-3B-Instruct")
with gr.Row():
with gr.Column(scale=3):
chatbot = gr.Chatbot(height=400)
msg = gr.Textbox(label="Message", placeholder="Ecris ici...")
with gr.Row():
send = gr.Button("Envoyer", variant="primary")
clear = gr.Button("Reset")
with gr.Column(scale=1):
temp = gr.Slider(0.1, 1.5, value=0.7, label="Temperature")
tokens = gr.Slider(50, 300, value=150, label="Max tokens")
state = gr.Textbox(label="Info", value="SCLM actif" if USE_SCLM else "Base only")
def respond(message, history, t, m):
response = chat(message, history, t, m)
history.append((message, response))
info = f"State: {sclm_model.latent_state.norm().item():.2f}" if USE_SCLM else "Base"
return "", history, info
send.click(respond, [msg, chatbot, temp, tokens], [msg, chatbot, state])
msg.submit(respond, [msg, chatbot, temp, tokens], [msg, chatbot, state])
clear.click(reset, outputs=[chatbot, state])
demo.launch()