import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import json
from typing import List, Optional, Dict, Any, Tuple
from dataclasses import dataclass, field

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from huggingface_hub import hf_hub_download

print("=" * 50)
print("SCLM Option B - Loading...")
print("=" * 50)

# ============================================================
# CONFIGURATION
# ============================================================

HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN secret not found!")

# Instruct model (conversation)
BASE_MODEL = "meta-llama/Llama-3.2-3B-Instruct"
SCLM_REPO = "amewebstudio/sclm-modelEarcp-optionB"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# ============================================================
# SCLM CLASSES
# ============================================================

@dataclass
class SCLMConfigB:
    vocab_size: int = 128256
    hidden_size: int = 3072
    num_hidden_layers: int = 28
    num_attention_heads: int = 24
    latent_state_dim: int = 384
    n_experts: int = 3
    n_coherence_heads: int = 8
    expert_intermediate: int = 1536
    state_injection_layers: List[int] = field(default_factory=lambda: [4, 8, 12, 16, 20, 24])
    alpha_inject: float = 0.03
    use_state_in_attention: bool = True
    use_state_in_ffn: bool = True
    state_fusion_method: str = "concat"


class StateFFNInjector(nn.Module):
    """Injects the latent state into the hidden stream through a gated FFN branch."""

    def __init__(self, hidden_size: int, state_dim: int, intermediate_size: int):
        super().__init__()
        self.state_proj = nn.Linear(state_dim, intermediate_size)
        self.output_proj = nn.Linear(intermediate_size, hidden_size)
        self.gate = nn.Linear(hidden_size, 1)
        # Zero-init the output projection so injection starts as a no-op.
        nn.init.zeros_(self.output_proj.weight)

    def forward(self, hidden, state, alpha=0.03):
        state_proj = F.silu(self.state_proj(state))
        state_output = self.output_proj(state_proj)                         # (B, H)
        gate = torch.sigmoid(self.gate(hidden.mean(dim=1, keepdim=True)))   # (B, 1, 1)
        return hidden + alpha * gate * state_output.unsqueeze(1)


class EncapsulationB(nn.Module):
    """Pools hidden states with state-conditioned attention, then applies a GRU-style update."""

    def __init__(self, hidden_size: int, state_dim: int):
        super().__init__()
        self.state_dim = state_dim
        self.n_pool_heads = 4
        self.pool_proj = nn.Linear(hidden_size, state_dim * self.n_pool_heads)
        self.pool_combine = nn.Linear(state_dim * self.n_pool_heads, state_dim)
        self.update_gate = nn.Linear(state_dim * 2, state_dim)
        self.reset_gate = nn.Linear(state_dim * 2, state_dim)
        self.candidate = nn.Linear(state_dim * 2, state_dim)
        self.attn_query = nn.Linear(state_dim, hidden_size)

    def forward(self, hidden, state, attention_mask=None, edit_mode=False):
        if edit_mode:
            return state, {}
        B, T, H = hidden.shape
        # Attention pooling: the current state queries the token hidden states.
        query = self.attn_query(state)                                      # (B, H)
        attn_scores = torch.bmm(hidden, query.unsqueeze(-1)).squeeze(-1)    # (B, T)
        if attention_mask is not None:
            attn_scores = attn_scores.masked_fill(attention_mask == 0, float("-inf"))
        attn_weights = F.softmax(attn_scores, dim=-1)
        h_pooled = torch.bmm(attn_weights.unsqueeze(1), hidden).squeeze(1)  # (B, H)
        h_proj = F.silu(self.pool_proj(h_pooled))
        h_proj = self.pool_combine(h_proj)                                  # (B, S)
        # GRU-style gated update of the latent state.
        combined = torch.cat([h_proj, state], dim=-1)
        z = torch.sigmoid(self.update_gate(combined))
        r = torch.sigmoid(self.reset_gate(combined))
        h_cand = torch.tanh(self.candidate(torch.cat([h_proj, r * state], dim=-1)))
        new_state = (1 - z) * state + z * h_cand
        # Soft clamp keeps the state bounded in [-10, 10].
        new_state = torch.tanh(new_state / 10.0) * 10.0
        return new_state, {}
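# For reference, the state update above in GRU notation (a summary of what
# EncapsulationB.forward computes, not additional behavior):
#
#   z  = sigmoid(W_z [h_pooled ; s])       # update gate
#   r  = sigmoid(W_r [h_pooled ; s])       # reset gate
#   s~ = tanh(W_c [h_pooled ; r * s])      # candidate state
#   s' = (1 - z) * s + z * s~              # gated interpolation
#   s' <- 10 * tanh(s' / 10)               # soft clamp to [-10, 10]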
class CoherenceExpertsB(nn.Module):
    """Mixture of small FFN experts with a softmax router over the pooled hidden state."""

    def __init__(self, hidden_size: int, intermediate_size: int, n_experts: int = 3):
        super().__init__()
        self.n_experts = n_experts
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, intermediate_size),
                nn.SiLU(),
                nn.Dropout(0.1),
                nn.Linear(intermediate_size, hidden_size),
            )
            for _ in range(n_experts)
        ])
        self.router = nn.Sequential(
            nn.Linear(hidden_size, 128),
            nn.SiLU(),
            nn.Linear(128, n_experts),
        )
        # Zero-init each expert's output layer so the mixture starts as a no-op.
        for exp in self.experts:
            nn.init.zeros_(exp[-1].weight)

    def forward(self, hidden, temperature=1.0):
        router_logits = self.router(hidden.mean(dim=1)) / temperature      # (B, E)
        weights = F.softmax(router_logits, dim=-1)
        expert_outputs = torch.stack([exp(hidden) for exp in self.experts], dim=0)  # (E, B, T, H)
        w = weights.T.unsqueeze(-1).unsqueeze(-1)                          # (E, B, 1, 1)
        output = (w * expert_outputs).sum(dim=0)                           # (B, T, H)
        return output, {}


class EARCPModuleB(nn.Module):
    """Bundles the FFN injectors, the encapsulation (state update) and the coherence experts."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        H = config.hidden_size
        S = config.latent_state_dim
        self.ffn_injectors = nn.ModuleDict({
            str(i): StateFFNInjector(H, S, config.expert_intermediate)
            for i in config.state_injection_layers
        })
        self.encapsulation = EncapsulationB(H, S)
        self.coherence = CoherenceExpertsB(H, config.expert_intermediate, config.n_experts)

    def update_state(self, hidden, state, attention_mask=None, edit_mode=False):
        new_state, _ = self.encapsulation(hidden, state, attention_mask, edit_mode)
        hidden, _ = self.coherence(hidden)
        return new_state, hidden, {}


class SCLMModelOptionB(nn.Module):
    """Wraps the frozen base model and maintains a persistent latent state across calls."""

    def __init__(self, config, base_model):
        super().__init__()
        self.config = config
        self.base_model = base_model
        self.model_device = next(base_model.parameters()).device
        self.model_dtype = next(base_model.parameters()).dtype
        self.earcp = EARCPModuleB(config).to(self.model_device).to(self.model_dtype)
        self.latent_state = torch.zeros(1, config.latent_state_dim,
                                        device=self.model_device, dtype=self.model_dtype)
        self.state_frozen = False
        self.edit_mode = False

    def reset_state(self):
        self.latent_state = torch.zeros(1, self.config.latent_state_dim,
                                        device=self.model_device, dtype=self.model_dtype)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        base_out = self.base_model(input_ids=input_ids, attention_mask=attention_mask,
                                   output_hidden_states=True, **kwargs)
        hidden = base_out.hidden_states[-1]
        B = hidden.size(0)
        if next(self.earcp.encapsulation.parameters()).device != hidden.device:
            self.earcp = self.earcp.to(hidden.device)
        state = self.latent_state.to(hidden.device, hidden.dtype).expand(B, -1)
        new_state, enhanced, _ = self.earcp.update_state(hidden, state, attention_mask, self.edit_mode)
        if not self.state_frozen:
            # Persist the (batch-averaged) updated state across calls.
            self.latent_state = new_state.mean(dim=0, keepdim=True).detach()
        return {"logits": base_out.logits, "state": self.latent_state.clone()}
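# Optional shape sanity check for the EARCP blocks (a minimal sketch, not part
# of the app flow; runs on CPU with random tensors only when SCLM_DEBUG is set).
if os.environ.get("SCLM_DEBUG"):
    _cfg = SCLMConfigB(hidden_size=64, expert_intermediate=32, latent_state_dim=16)
    _earcp = EARCPModuleB(_cfg)
    _hidden = torch.randn(2, 5, 64)   # (B=2, T=5, H=64)
    _state = torch.zeros(2, 16)       # (B=2, S=16)
    _new_state, _enhanced, _ = _earcp.update_state(_hidden, _state)
    assert _new_state.shape == (2, 16) and _enhanced.shape == (2, 5, 64)
    print("EARCP shape check OK")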
# ============================================================
# LOADING
# ============================================================

print("1. Loading Llama-3.2-3B-Instruct...")
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=quant_config,
    device_map="auto",
    token=HF_TOKEN,
)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=HF_TOKEN)

# Normalize token IDs (Llama 3.2 ships a list of EOS IDs).
if isinstance(tokenizer.eos_token_id, list):
    tokenizer.eos_token_id = tokenizer.eos_token_id[0]
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
if isinstance(base_model.config.eos_token_id, list):
    base_model.config.eos_token_id = base_model.config.eos_token_id[0]
base_model.config.pad_token_id = base_model.config.eos_token_id

print("2. Creating SCLM...")
config = SCLMConfigB(
    vocab_size=base_model.config.vocab_size,
    hidden_size=base_model.config.hidden_size,
    num_hidden_layers=base_model.config.num_hidden_layers,
    num_attention_heads=base_model.config.num_attention_heads,
)
sclm_model = SCLMModelOptionB(config, base_model)

print("3. Loading EARCP weights...")
try:
    weights_path = hf_hub_download(repo_id=SCLM_REPO, filename="earcp_weights.pt", token=HF_TOKEN)
    earcp_weights = torch.load(weights_path, map_location="cpu")
    sclm_model.earcp.load_state_dict(earcp_weights, strict=False)
    print("SCLM loaded!")
    USE_SCLM = True
except Exception as e:
    print(f"EARCP not loaded: {e}")
    USE_SCLM = False

# ============================================================
# CHAT
# ============================================================

conversation_history = []

def chat(message, history, temperature, max_tokens):
    global conversation_history
    if not message.strip():
        return ""
    if temperature is None:
        temperature = 0.7
    if max_tokens is None:
        max_tokens = 150

    conversation_history.append({"role": "user", "content": message})

    # Llama 3.2 Instruct chat format, built by hand.
    prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful AI assistant.<|eot_id|>"
    for msg in conversation_history:
        role = msg["role"]
        prompt += f"<|start_header_id|>{role}<|end_header_id|>\n\n{msg['content']}<|eot_id|>"
    prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"

    inputs = tokenizer(prompt, return_tensors="pt").to(base_model.device)

    # Update the latent state from the full prompt before generating.
    if USE_SCLM:
        with torch.no_grad():
            _ = sclm_model(inputs.input_ids, attention_mask=inputs.attention_mask)

    eos_id = tokenizer.eos_token_id
    if isinstance(eos_id, list):
        eos_id = eos_id[0]

    with torch.no_grad():
        outputs = base_model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=int(max_tokens),
            temperature=float(temperature),
            do_sample=True,
            top_p=0.9,
            repetition_penalty=1.15,
            pad_token_id=eos_id,
            eos_token_id=eos_id,
        )

    # Decode only the newly generated tokens; this replaces the fragile
    # string splitting on "assistant", which could also match prompt text.
    new_tokens = outputs[0][inputs.input_ids.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    if not response:
        response = "..."

    conversation_history.append({"role": "assistant", "content": response})
    # Keep only the last 10 turns to bound prompt growth.
    if len(conversation_history) > 10:
        conversation_history = conversation_history[-10:]
    return response

def reset():
    global conversation_history
    conversation_history = []
    if USE_SCLM:
        sclm_model.reset_state()
    return [], "Reset!"
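# Hypothetical helpers (not in the original app): persist the latent state to
# disk between sessions, e.g. to keep the Space's "memory" across restarts.
def save_latent_state(path="latent_state.pt"):
    if USE_SCLM:
        torch.save(sclm_model.latent_state.cpu(), path)

def load_latent_state(path="latent_state.pt"):
    if USE_SCLM and os.path.exists(path):
        sclm_model.latent_state = torch.load(path, map_location=sclm_model.model_device)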
# ============================================================
# INTERFACE
# ============================================================

with gr.Blocks(title="SCLM Chat") as demo:
    gr.Markdown("# SCLM - Stateful Coherent Language Model\n**sclm-modelEarcp-optionB** | Llama-3.2-3B-Instruct")
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=400)
            msg = gr.Textbox(label="Message", placeholder="Type here...")
            with gr.Row():
                send = gr.Button("Send", variant="primary")
                clear = gr.Button("Reset")
        with gr.Column(scale=1):
            temp = gr.Slider(0.1, 1.5, value=0.7, label="Temperature")
            tokens = gr.Slider(50, 300, value=150, label="Max tokens")
            state = gr.Textbox(label="Info", value="SCLM active" if USE_SCLM else "Base only")

    def respond(message, history, t, m):
        response = chat(message, history, t, m)
        history.append((message, response))
        info = f"State: {sclm_model.latent_state.norm().item():.2f}" if USE_SCLM else "Base"
        return "", history, info

    send.click(respond, [msg, chatbot, temp, tokens], [msg, chatbot, state])
    msg.submit(respond, [msg, chatbot, temp, tokens], [msg, chatbot, state])
    clear.click(reset, outputs=[chatbot, state])

demo.launch()