import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import json
from typing import List, Optional, Dict, Any, Tuple
from dataclasses import dataclass, field
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from huggingface_hub import hf_hub_download
| print("="*50) | |
| print("SCLM Option B - Chargement...") | |
| print("="*50) | |
# ============================================================
# CONFIGURATION
# ============================================================
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN secret not found!")

# Instruct model (conversation)
BASE_MODEL = "meta-llama/Llama-3.2-3B-Instruct"
SCLM_REPO = "amewebstudio/sclm-modelEarcp-optionB"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
# ============================================================
# SCLM CLASSES
# ============================================================
@dataclass
class SCLMConfigB:
    vocab_size: int = 128256
    hidden_size: int = 3072
    num_hidden_layers: int = 28
    num_attention_heads: int = 24
    latent_state_dim: int = 384
    n_experts: int = 3
    n_coherence_heads: int = 8
    expert_intermediate: int = 1536
    state_injection_layers: List[int] = field(default_factory=lambda: [4, 8, 12, 16, 20, 24])
    alpha_inject: float = 0.03
    use_state_in_attention: bool = True
    use_state_in_ffn: bool = True
    state_fusion_method: str = "concat"
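# Field cheat-sheet (defaults match Llama-3.2-3B; the constructor call further
# down overrides the first four fields from the loaded base model's config):
#   latent_state_dim       -- width S of the persistent latent state vector
#   state_injection_layers -- decoder layers assigned a StateFFNInjector
#   alpha_inject           -- small scale (0.03) on the injected state signal
#   n_experts / expert_intermediate -- shape of the coherence expert mixture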
class StateFFNInjector(nn.Module):
    def __init__(self, hidden_size: int, state_dim: int, intermediate_size: int):
        super().__init__()
        self.state_proj = nn.Linear(state_dim, intermediate_size)
        self.output_proj = nn.Linear(intermediate_size, hidden_size)
        self.gate = nn.Linear(hidden_size, 1)
        nn.init.zeros_(self.output_proj.weight)

    def forward(self, hidden, state, alpha=0.03):
        state_proj = F.silu(self.state_proj(state))
        state_output = self.output_proj(state_proj)
        gate = torch.sigmoid(self.gate(hidden.mean(dim=1, keepdim=True)))
        return hidden + alpha * gate * state_output.unsqueeze(1)
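# Shape walk-through for the gated injection above (B=batch, T=tokens,
# H=hidden_size, S=state_dim):
#   state (B, S) -> silu(state_proj) (B, intermediate) -> output_proj (B, H)
#   gate = sigmoid(Linear(mean over T of hidden)), shape (B, 1, 1), in (0, 1)
#   result = hidden + alpha * gate * state_output.unsqueeze(1), broadcast over T
# output_proj.weight is zero-initialized, so the injection starts as a no-op
# (up to the layer's bias) and only takes effect once trained weights are loaded.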
class EncapsulationB(nn.Module):
    def __init__(self, hidden_size: int, state_dim: int):
        super().__init__()
        self.state_dim = state_dim
        self.n_pool_heads = 4
        self.pool_proj = nn.Linear(hidden_size, state_dim * self.n_pool_heads)
        self.pool_combine = nn.Linear(state_dim * self.n_pool_heads, state_dim)
        self.update_gate = nn.Linear(state_dim * 2, state_dim)
        self.reset_gate = nn.Linear(state_dim * 2, state_dim)
        self.candidate = nn.Linear(state_dim * 2, state_dim)
        self.attn_query = nn.Linear(state_dim, hidden_size)

    def forward(self, hidden, state, attention_mask=None, edit_mode=False):
        if edit_mode:
            return state, {}
        B, T, H = hidden.shape
        query = self.attn_query(state)
        attn_scores = torch.bmm(hidden, query.unsqueeze(-1)).squeeze(-1)
        if attention_mask is not None:
            attn_scores = attn_scores.masked_fill(attention_mask == 0, float("-inf"))
        attn_weights = F.softmax(attn_scores, dim=-1)
        h_pooled = torch.bmm(attn_weights.unsqueeze(1), hidden).squeeze(1)
        h_proj = F.silu(self.pool_proj(h_pooled))
        h_proj = self.pool_combine(h_proj)
        combined = torch.cat([h_proj, state], dim=-1)
        z = torch.sigmoid(self.update_gate(combined))
        r = torch.sigmoid(self.reset_gate(combined))
        h_cand = torch.tanh(self.candidate(torch.cat([h_proj, r * state], dim=-1)))
        new_state = (1 - z) * state + z * h_cand
        new_state = torch.tanh(new_state / 10.0) * 10.0
        return new_state, {}
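# The update above is a GRU cell driven by an attention-pooled summary of the
# token hiddens (the current state provides the pooling query):
#   z = sigmoid(W_z [h; s])        update gate
#   r = sigmoid(W_r [h; s])        reset gate
#   s_cand = tanh(W_c [h; r * s])  candidate state
#   s' = (1 - z) * s + z * s_cand
# The final tanh(s'/10)*10 is a soft clamp keeping each coordinate in (-10, 10),
# so the state cannot drift unboundedly across conversation turns.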
class CoherenceExpertsB(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int, n_experts: int = 3):
        super().__init__()
        self.n_experts = n_experts
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, intermediate_size),
                nn.SiLU(),
                nn.Dropout(0.1),
                nn.Linear(intermediate_size, hidden_size)
            ) for _ in range(n_experts)
        ])
        self.router = nn.Sequential(
            nn.Linear(hidden_size, 128),
            nn.SiLU(),
            nn.Linear(128, n_experts)
        )
        for exp in self.experts:
            nn.init.zeros_(exp[-1].weight)

    def forward(self, hidden, temperature=1.0):
        router_logits = self.router(hidden.mean(dim=1)) / temperature
        weights = F.softmax(router_logits, dim=-1)
        expert_outputs = torch.stack([exp(hidden) for exp in self.experts], dim=0)
        w = weights.T.unsqueeze(-1).unsqueeze(-1)
        output = (w * expert_outputs).sum(dim=0)
        return output, {}
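# This is dense (soft) mixture-of-experts routing: the router scores the mean
# hidden over the sequence, softmax yields weights of shape (B, n_experts), and
# every expert processes the full sequence before the weighted sum -- unlike
# the sparse top-k routing of large MoE LLMs. Zero-initializing each expert's
# final Linear weight makes the mix start near zero (bias terms aside). Note
# the mixed output is returned in place of `hidden`, not added to it.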
class EARCPModuleB(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        H = config.hidden_size
        S = config.latent_state_dim
        self.ffn_injectors = nn.ModuleDict({
            str(i): StateFFNInjector(H, S, config.expert_intermediate)
            for i in config.state_injection_layers
        })
        self.encapsulation = EncapsulationB(H, S)
        self.coherence = CoherenceExpertsB(H, config.expert_intermediate, config.n_experts)

    def update_state(self, hidden, state, attention_mask=None, edit_mode=False):
        new_state, _ = self.encapsulation(hidden, state, attention_mask, edit_mode)
        hidden, _ = self.coherence(hidden)
        return new_state, hidden, {}
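# The ffn_injectors are registered here, presumably so their weights can load
# from the EARCP checkpoint, but this inference path never calls them:
# update_state only runs the encapsulation (state update) and the coherence experts.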
class SCLMModelOptionB(nn.Module):
    def __init__(self, config, base_model):
        super().__init__()
        self.config = config
        self.base_model = base_model
        self.model_device = next(base_model.parameters()).device
        self.model_dtype = next(base_model.parameters()).dtype
        self.earcp = EARCPModuleB(config).to(self.model_device).to(self.model_dtype)
        self.latent_state = torch.zeros(1, config.latent_state_dim, device=self.model_device, dtype=self.model_dtype)
        self.state_frozen = False
        self.edit_mode = False

    def reset_state(self):
        self.latent_state = torch.zeros(1, self.config.latent_state_dim, device=self.model_device, dtype=self.model_dtype)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        base_out = self.base_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, **kwargs)
        hidden = base_out.hidden_states[-1]
        B = hidden.size(0)
        if next(self.earcp.encapsulation.parameters()).device != hidden.device:
            self.earcp = self.earcp.to(hidden.device)
        state = self.latent_state.to(hidden.device, hidden.dtype).expand(B, -1)
        new_state, enhanced, _ = self.earcp.update_state(hidden, state, attention_mask, self.edit_mode)
        if not self.state_frozen:
            self.latent_state = new_state.mean(dim=0, keepdim=True).detach()
        return {"logits": base_out.logits, "state": self.latent_state.clone()}
# ============================================================
# LOADING
# ============================================================
print("1. Loading Llama-3.2-3B-Instruct...")
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL, quantization_config=quant_config, device_map="auto", token=HF_TOKEN
)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=HF_TOKEN)

# Fix token IDs
if isinstance(tokenizer.eos_token_id, list):
    tokenizer.eos_token_id = tokenizer.eos_token_id[0]
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
if isinstance(base_model.config.eos_token_id, list):
    base_model.config.eos_token_id = base_model.config.eos_token_id[0]
base_model.config.pad_token_id = base_model.config.eos_token_id
| print("2. Creation SCLM...") | |
| config = SCLMConfigB( | |
| vocab_size=base_model.config.vocab_size, | |
| hidden_size=base_model.config.hidden_size, | |
| num_hidden_layers=base_model.config.num_hidden_layers, | |
| num_attention_heads=base_model.config.num_attention_heads, | |
| ) | |
| sclm_model = SCLMModelOptionB(config, base_model) | |
| print("3. Chargement poids EARCP...") | |
| try: | |
| weights_path = hf_hub_download(repo_id=SCLM_REPO, filename="earcp_weights.pt", token=HF_TOKEN) | |
| earcp_weights = torch.load(weights_path, map_location="cpu") | |
| sclm_model.earcp.load_state_dict(earcp_weights, strict=False) | |
| print("SCLM charge!") | |
| USE_SCLM = True | |
| except Exception as e: | |
| print(f"EARCP non charge: {e}") | |
| USE_SCLM = False | |
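# strict=False tolerates missing or unexpected keys (e.g. injector weights for
# layers this wrapper never calls). A minimal check of what actually matched,
# if you want to verify the load, would be:
#   result = sclm_model.earcp.load_state_dict(earcp_weights, strict=False)
#   print(f"missing: {len(result.missing_keys)}, unexpected: {len(result.unexpected_keys)}")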
# ============================================================
# CHAT
# ============================================================
conversation_history = []

def chat(message, history, temperature, max_tokens):
    global conversation_history
    if not message.strip():
        return ""
    if temperature is None:
        temperature = 0.7
    if max_tokens is None:
        max_tokens = 150
    conversation_history.append({"role": "user", "content": message})
    # Llama 3.2 Instruct prompt format
    prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful AI assistant.<|eot_id|>"
    for msg in conversation_history:
        role = msg["role"]
        prompt += f"<|start_header_id|>{role}<|end_header_id|>\n\n{msg['content']}<|eot_id|>"
    prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    inputs = tokenizer(prompt, return_tensors="pt").to(base_model.device)
    if USE_SCLM:
        with torch.no_grad():
            _ = sclm_model(inputs.input_ids, attention_mask=inputs.attention_mask)
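    # Note: this extra forward pass only updates sclm_model.latent_state; the
    # reply below comes straight from base_model.generate, so in this script
    # the latent state is tracked and displayed but does not steer generation.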
    eos_id = tokenizer.eos_token_id
    if isinstance(eos_id, list):
        eos_id = eos_id[0]
    with torch.no_grad():
        outputs = base_model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=int(max_tokens),
            temperature=float(temperature),
            do_sample=True,
            top_p=0.9,
            repetition_penalty=1.15,
            pad_token_id=eos_id,
            eos_token_id=eos_id,
        )
    # Decode only the newly generated tokens; outputs[0] echoes the prompt, and
    # skip_special_tokens already drops the <|eot_id|> / header markers, so no
    # fragile string surgery on the full decoded text is needed.
    response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()
    if not response:
        response = "..."
    conversation_history.append({"role": "assistant", "content": response})
    if len(conversation_history) > 10:
        conversation_history = conversation_history[-10:]
    return response
def reset():
    global conversation_history
    conversation_history = []
    if USE_SCLM:
        sclm_model.reset_state()
    return [], "Reset!"
# ============================================================
# INTERFACE
# ============================================================
with gr.Blocks(title="SCLM Chat") as demo:
    gr.Markdown("# SCLM - Stateful Coherent Language Model\n**sclm-modelEarcp-optionB** | Llama-3.2-3B-Instruct")
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=400)
            msg = gr.Textbox(label="Message", placeholder="Type here...")
            with gr.Row():
                send = gr.Button("Send", variant="primary")
                clear = gr.Button("Reset")
        with gr.Column(scale=1):
            temp = gr.Slider(0.1, 1.5, value=0.7, label="Temperature")
            tokens = gr.Slider(50, 300, value=150, label="Max tokens")
            state = gr.Textbox(label="Info", value="SCLM active" if USE_SCLM else "Base only")
    def respond(message, history, t, m):
        response = chat(message, history, t, m)
        history.append((message, response))
        info = f"State: {sclm_model.latent_state.norm().item():.2f}" if USE_SCLM else "Base"
        return "", history, info

    send.click(respond, [msg, chatbot, temp, tokens], [msg, chatbot, state])
    msg.submit(respond, [msg, chatbot, temp, tokens], [msg, chatbot, state])
    clear.click(reset, outputs=[chatbot, state])

demo.launch()