"""Ananke: Gradio chat front-end for Llama-3.2-3B-Instruct with an SCLM/EARCP side module.

The script loads a 4-bit quantized base model, attaches an EARCP module whose
latent state is updated from the prompt's last hidden layer, and serves a
single-conversation chat UI through Gradio.
"""

import os
from dataclasses import dataclass, field
from typing import List

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

print("Ananke - Chargement...")

HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN not found in secrets")

BASE_MODEL = "meta-llama/Llama-3.2-3B-Instruct"
SCLM_REPO = "amewebstudio/sclm-modelEarcp-optionB"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

SYSTEM_PROMPT = (
    "Tu es Ananke, un assistant IA developpe par Mike Amega (Logo) de Ame Web Studio. "
    "TON NOM: Ananke | TON CREATEUR: Mike Amega (Logo) | TON MODELE: Ananke "
    "Tu sais: repondre aux questions, aider en redaction, expliquer des concepts, "
    "programmer, maintenir une conversation coherente. "
    "Architecture SCLM: memoire latente 384 dimensions, module EARCP, 3 experts specialises. "
    "Style: chaleureux, utile, complet. \n"
    "Reponds dans la langue de l utilisateur."
)


@dataclass
class SCLMConfigB:
    """Hyper-parameters of the SCLM "option B" side module.

    The transformer-shape fields are overwritten from the base model's config
    when the model is built below; the defaults are presumably those of
    BASE_MODEL (TODO confirm).
    """

    vocab_size: int = 128256
    hidden_size: int = 3072
    num_hidden_layers: int = 28
    num_attention_heads: int = 24
    latent_state_dim: int = 384
    n_experts: int = 3
    expert_intermediate: int = 1536
    # NOTE(review): not read anywhere in this file; kept for checkpoint/config compatibility.
    state_injection_layers: List[int] = field(default_factory=lambda: [4, 8, 12, 16, 20, 24])


class EncapsulationB(nn.Module):
    """GRU-style update of a latent state from attention-pooled hidden states."""

    def __init__(self, hidden_size, state_dim):
        super().__init__()
        self.pool_proj = nn.Linear(hidden_size, state_dim * 4)
        self.pool_combine = nn.Linear(state_dim * 4, state_dim)
        self.update_gate = nn.Linear(state_dim * 2, state_dim)
        self.reset_gate = nn.Linear(state_dim * 2, state_dim)
        self.candidate = nn.Linear(state_dim * 2, state_dim)
        self.attn_query = nn.Linear(state_dim, hidden_size)

    def forward(self, hidden, state, mask=None):
        """Return the updated latent state.

        Args:
            hidden: (B, T, H) hidden states (``torch.bmm`` requires rank 3).
            state: (B, state_dim) current latent state.
            mask: optional (B, T) attention mask; positions equal to 0 are
                excluded from the attention pooling.

        Returns:
            (B, state_dim) new state, soft-clipped to (-10, 10) via tanh.
        """
        # The state produces a query vector that scores every time step.
        query = self.attn_query(state)
        scores = torch.bmm(hidden, query.unsqueeze(-1)).squeeze(-1)
        if mask is not None:
            # Use the dtype's own minimum: a literal like -1e9 does not fit in
            # float16 and makes masked_fill raise an overflow error.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        weights = F.softmax(scores, dim=-1)
        pooled = torch.bmm(weights.unsqueeze(1), hidden).squeeze(1)

        proj = self.pool_combine(F.silu(self.pool_proj(pooled)))

        # GRU-like gating between the previous state and a candidate state.
        combined = torch.cat([proj, state], dim=-1)
        z = torch.sigmoid(self.update_gate(combined))
        r = torch.sigmoid(self.reset_gate(combined))
        cand = torch.tanh(self.candidate(torch.cat([proj, r * state], dim=-1)))
        new_state = (1 - z) * state + z * cand
        # Soft clip keeps the state bounded in (-10, 10).
        return torch.tanh(new_state / 10.0) * 10.0


class CoherenceExperts(nn.Module):
    """Mixture of MLP experts mixed by a router on the sequence-mean hidden state."""

    def __init__(self, hidden_size, intermediate, n_experts=3):
        super().__init__()
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, intermediate),
                nn.SiLU(),
                nn.Linear(intermediate, hidden_size),
            )
            for _ in range(n_experts)
        ])
        self.router = nn.Sequential(nn.Linear(hidden_size, 64), nn.SiLU(), nn.Linear(64, n_experts))

    def forward(self, hidden):
        """Return the router-weighted sum of expert outputs, same shape as ``hidden``.

        Assumes ``hidden`` is (B, T, H) — TODO confirm; the broadcasting below
        relies on that rank.
        """
        logits = self.router(hidden.mean(dim=1))          # (B, n_experts)
        weights = F.softmax(logits, dim=-1)
        outputs = torch.stack([e(hidden) for e in self.experts], dim=0)  # (E, B, T, H)
        w = weights.T.unsqueeze(-1).unsqueeze(-1)          # (E, B, 1, 1)
        return (w * outputs).sum(dim=0)


class EARCPModule(nn.Module):
    """Container for the EARCP sub-modules; its state_dict matches the saved weights."""

    def __init__(self, config):
        super().__init__()
        self.encapsulation = EncapsulationB(config.hidden_size, config.latent_state_dim)
        self.coherence = CoherenceExperts(config.hidden_size, config.expert_intermediate, config.n_experts)


class SCLMModel(nn.Module):
    """Base causal LM plus an EARCP latent state updated on every forward pass."""

    def __init__(self, config, base):
        super().__init__()
        self.config = config
        self.base_model = base
        dev, dtype = self._param_device_dtype()
        self.earcp = EARCPModule(config).to(dev).to(dtype)
        self.state = torch.zeros(1, config.latent_state_dim, device=dev, dtype=dtype)

    def _param_device_dtype(self):
        """Device and dtype of the base model's first parameter."""
        first = next(self.base_model.parameters())
        return first.device, first.dtype

    def reset(self):
        """Reset the latent state to zeros."""
        dev, dtype = self._param_device_dtype()
        self.state = torch.zeros(1, self.config.latent_state_dim, device=dev, dtype=dtype)

    def forward(self, ids, mask=None):
        """Run the base model, update ``self.state`` from its last hidden layer, return its logits."""
        if mask is None:
            mask = torch.ones_like(ids)
        out = self.base_model(input_ids=ids, attention_mask=mask, output_hidden_states=True)
        hidden = out.hidden_states[-1]
        batch = hidden.size(0)
        state = self.state.to(hidden.device, hidden.dtype).expand(batch, -1)
        new_state = self.earcp.encapsulation(hidden, state, mask)
        # Keep a single (1, state_dim) state shared across the batch, detached from autograd.
        self.state = new_state.mean(dim=0, keepdim=True).detach()
        return out.logits


print("1. Loading base model...")
qconfig = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL, quantization_config=qconfig, device_map="auto", token=HF_TOKEN
)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=HF_TOKEN)

if isinstance(tokenizer.eos_token_id, list):
    tokenizer.eos_token_id = tokenizer.eos_token_id[0]
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
if isinstance(base_model.config.eos_token_id, list):
    base_model.config.eos_token_id = base_model.config.eos_token_id[0]
base_model.config.pad_token_id = base_model.config.eos_token_id

# Llama-3 chat turns end with <|eot_id|>; stop on it as well as on the tokenizer EOS.
# convert_tokens_to_ids may return None for an unknown token, hence the int filter.
_EOT_ID = tokenizer.convert_tokens_to_ids("<|eot_id|>")
STOP_TOKEN_IDS = sorted({
    tid for tid in (tokenizer.eos_token_id, _EOT_ID) if isinstance(tid, int) and tid >= 0
})

print("2. Creating SCLM...")
config = SCLMConfigB(
    vocab_size=base_model.config.vocab_size,
    hidden_size=base_model.config.hidden_size,
    num_hidden_layers=base_model.config.num_hidden_layers,
    num_attention_heads=base_model.config.num_attention_heads,
)
sclm = SCLMModel(config, base_model)

print("3. Loading EARCP weights...")
try:
    wpath = hf_hub_download(repo_id=SCLM_REPO, filename="earcp_weights.pt", token=HF_TOKEN)
    # The file comes from a remote repo: weights_only=True refuses arbitrary
    # pickled objects (code execution risk); a state_dict only needs tensors.
    earcp_state = torch.load(wpath, map_location="cpu", weights_only=True)
    load_result = sclm.earcp.load_state_dict(earcp_state, strict=False)
except Exception as exc:  # best effort: the chat still works without EARCP
    USE_SCLM = False
    print(f"EARCP unavailable, continuing without it: {exc!r}")
else:
    USE_SCLM = True
    print("EARCP loaded!")
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"  warning: {len(load_result.missing_keys)} missing / "
            f"{len(load_result.unexpected_keys)} unexpected EARCP keys"
        )
print("Ananke ready!")

# Only this many most-recent turns are sent to the model.
MAX_HISTORY = 10
# NOTE(review): module-global, so every browser session shares one conversation;
# per-session memory would need gr.State.
history = []

# Llama-3 control-token strings. They are stripped from turn text so a message
# cannot forge a system/assistant turn.
_CONTROL_TOKENS = (
    "<|begin_of_text|>",
    "<|end_of_text|>",
    "<|start_header_id|>",
    "<|end_header_id|>",
    "<|eot_id|>",
)


def _sanitize(text):
    """Remove Llama-3 control-token strings from free text."""
    for tok in _CONTROL_TOKENS:
        text = text.replace(tok, "")
    return text


def _build_prompt(turns):
    """Render the system prompt plus ``(role, content)`` turns in the Llama-3 chat format."""
    parts = ["<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n", SYSTEM_PROMPT, "<|eot_id|>"]
    for role, content in turns:
        parts.append(f"<|start_header_id|>{role}<|end_header_id|>\n\n{_sanitize(content)}<|eot_id|>")
    parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
    return "".join(parts)


def chat(message, temp=0.7, max_tok=1024):
    """Generate a reply to ``message`` using the recent conversation history.

    Args:
        message: user text; blank input returns "" without touching history.
        temp: sampling temperature.
        max_tok: maximum number of new tokens (may arrive as a float from the slider).

    Returns:
        The assistant reply, or "..." if the model produced only whitespace.
    """
    if not message.strip():
        return ""

    user_turn = ("user", message)
    prompt = _build_prompt((history + [user_turn])[-MAX_HISTORY:])
    # The prompt already begins with <|begin_of_text|>; don't let the tokenizer add a second BOS.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(base_model.device)

    if USE_SCLM:
        # NOTE(review): this extra forward pass only updates sclm.state; nothing
        # in this file feeds that state back into generate().
        with torch.no_grad():
            sclm(inputs.input_ids, inputs.attention_mask)

    with torch.no_grad():
        out = base_model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=int(max_tok),
            temperature=float(temp),
            do_sample=True,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=STOP_TOKEN_IDS,
        )

    # Decode only the newly generated tokens: the reply never contains the
    # prompt, so no string surgery is needed.
    new_tokens = out[0][inputs.input_ids.shape[1]:]
    resp = tokenizer.decode(new_tokens, skip_special_tokens=True).strip() or "..."

    # Commit both turns only after a successful generation.
    history.extend([user_turn, ("assistant", resp)])
    del history[:-MAX_HISTORY]  # older turns would never be sent again
    return resp


def clear():
    """Forget the conversation (and the SCLM latent state) and blank the output box."""
    global history
    history = []
    if USE_SCLM:
        sclm.reset()
    return ""


with gr.Blocks() as demo:
    gr.Markdown("# 🔮 Ananké\nAssistant IA avec mémoire contextuelle | Architecture SCLM par Mike Amega")
    with gr.Row():
        with gr.Column():
            output = gr.Textbox(label="Réponse", lines=15)
            inp = gr.Textbox(label="Message", lines=2, placeholder="Parle avec Ananké...")
        with gr.Column():
            temp = gr.Slider(0.1, 1.5, 0.7, label="Créativité")
            tokens = gr.Slider(256, 2048, 1024, label="Longueur max")
            btn = gr.Button("Envoyer", variant="primary")
            clr = gr.Button("Effacer")
    btn.click(chat, [inp, temp, tokens], output)
    inp.submit(chat, [inp, temp, tokens], output)
    clr.click(clear, outputs=output)

demo.queue().launch()