import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from typing import List
from dataclasses import dataclass, field
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from huggingface_hub import hf_hub_download
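# Ananke: a Gradio chat demo that wraps a 4-bit quantized Llama 3.2 base model
# with an experimental SCLM layer. The EARCP module defined below is loaded
# from a separate repo and keeps a small recurrent latent summary of the chat.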
| print("Ananke - Chargement...") | |
| HF_TOKEN = os.environ.get("HF_TOKEN") | |
| if not HF_TOKEN: | |
| raise ValueError("HF_TOKEN not found in secrets") | |
| BASE_MODEL = "meta-llama/Llama-3.2-3B-Instruct" | |
| SCLM_REPO = "amewebstudio/sclm-modelEarcp-optionB" | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
SYSTEM_PROMPT = """You are Ananke, an AI assistant developed by Mike Amega (Logo) of Ame Web Studio.
YOUR NAME: Ananke | YOUR CREATOR: Mike Amega (Logo) | YOUR MODEL: Ananke
You can: answer questions, help with writing, explain concepts, write code, and hold a coherent conversation.
SCLM architecture: 384-dimensional latent memory, EARCP module, 3 specialized experts.
Style: warm, helpful, thorough. Reply in the user's language."""
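# SCLMConfigB mirrors the Llama-3.2-3B geometry (hidden size, layers, heads)
# plus the SCLM-specific fields (latent state size, experts).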
# @dataclass is required here: field(default_factory=...) and the keyword
# constructor used further down only work on a dataclass.
@dataclass
class SCLMConfigB:
    vocab_size: int = 128256
    hidden_size: int = 3072
    num_hidden_layers: int = 28
    num_attention_heads: int = 24
    latent_state_dim: int = 384
    n_experts: int = 3
    expert_intermediate: int = 1536
    state_injection_layers: List[int] = field(default_factory=lambda: [4, 8, 12, 16, 20, 24])
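# EncapsulationB: attention-pools the token hidden states using the latent
# state as query, then applies a GRU-style gated update to produce the new
# latent state.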
class EncapsulationB(nn.Module):
    def __init__(self, hidden_size, state_dim):
        super().__init__()
        self.pool_proj = nn.Linear(hidden_size, state_dim * 4)
        self.pool_combine = nn.Linear(state_dim * 4, state_dim)
        self.update_gate = nn.Linear(state_dim * 2, state_dim)
        self.reset_gate = nn.Linear(state_dim * 2, state_dim)
        self.candidate = nn.Linear(state_dim * 2, state_dim)
        self.attn_query = nn.Linear(state_dim, hidden_size)

    def forward(self, hidden, state, mask=None):
        B, T, H = hidden.shape
        # Use the latent state as an attention query over the token hiddens.
        query = self.attn_query(state)
        scores = torch.bmm(hidden, query.unsqueeze(-1)).squeeze(-1)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = F.softmax(scores, dim=-1)
        pooled = torch.bmm(weights.unsqueeze(1), hidden).squeeze(1)
        proj = F.silu(self.pool_proj(pooled))
        proj = self.pool_combine(proj)
        # GRU-style gated update of the latent state.
        combined = torch.cat([proj, state], dim=-1)
        z = torch.sigmoid(self.update_gate(combined))
        r = torch.sigmoid(self.reset_gate(combined))
        cand = torch.tanh(self.candidate(torch.cat([proj, r * state], dim=-1)))
        new_state = (1 - z) * state + z * cand
        # Soft-clamp the state to roughly [-10, 10] to keep it bounded.
        return torch.tanh(new_state / 10.0) * 10.0
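# CoherenceExperts: a small mixture-of-experts block. The router computes one
# weight per expert from the mean-pooled hiddens and blends the expert outputs.
# Note that the inference path below only exercises the encapsulation update;
# the experts are loaded with the checkpoint but never called at generation time.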
class CoherenceExperts(nn.Module):
    def __init__(self, hidden_size, intermediate, n_experts=3):
        super().__init__()
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(hidden_size, intermediate), nn.SiLU(), nn.Linear(intermediate, hidden_size))
            for _ in range(n_experts)
        ])
        self.router = nn.Sequential(nn.Linear(hidden_size, 64), nn.SiLU(), nn.Linear(64, n_experts))

    def forward(self, hidden):
        logits = self.router(hidden.mean(dim=1))
        weights = F.softmax(logits, dim=-1)
        outputs = torch.stack([e(hidden) for e in self.experts], dim=0)
        w = weights.T.unsqueeze(-1).unsqueeze(-1)
        return (w * outputs).sum(dim=0)
class EARCPModule(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.encapsulation = EncapsulationB(config.hidden_size, config.latent_state_dim)
        self.coherence = CoherenceExperts(config.hidden_size, config.expert_intermediate, config.n_experts)
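# SCLMModel wraps the base model and carries a persistent 1 x latent_state_dim
# state across calls; each forward pass refreshes it from the last hidden layer
# and returns the base model's logits unchanged.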
class SCLMModel(nn.Module):
    def __init__(self, config, base):
        super().__init__()
        self.config = config
        self.base_model = base
        dev = next(base.parameters()).device
        dtype = next(base.parameters()).dtype
        self.earcp = EARCPModule(config).to(dev).to(dtype)
        self.state = torch.zeros(1, config.latent_state_dim, device=dev, dtype=dtype)

    def reset(self):
        dev = next(self.base_model.parameters()).device
        dtype = next(self.base_model.parameters()).dtype
        self.state = torch.zeros(1, self.config.latent_state_dim, device=dev, dtype=dtype)

    def forward(self, ids, mask=None):
        if mask is None:
            mask = torch.ones_like(ids)
        out = self.base_model(input_ids=ids, attention_mask=mask, output_hidden_states=True)
        hidden = out.hidden_states[-1]
        B = hidden.size(0)
        state = self.state.to(hidden.device, hidden.dtype).expand(B, -1)
        new_state = self.earcp.encapsulation(hidden, state, mask)
        # Average over the batch and detach so the state persists across calls
        # without keeping the autograd graph alive.
        self.state = new_state.mean(dim=0, keepdim=True).detach()
        return out.logits
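# Load the base model with 4-bit bitsandbytes quantization to fit Space memory.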
| print("1. Loading base model...") | |
| qconfig = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16) | |
| base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, quantization_config=qconfig, device_map="auto", token=HF_TOKEN) | |
| tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=HF_TOKEN) | |
| if isinstance(tokenizer.eos_token_id, list): | |
| tokenizer.eos_token_id = tokenizer.eos_token_id[0] | |
| if tokenizer.pad_token is None: | |
| tokenizer.pad_token = tokenizer.eos_token | |
| tokenizer.pad_token_id = tokenizer.eos_token_id | |
| if isinstance(base_model.config.eos_token_id, list): | |
| base_model.config.eos_token_id = base_model.config.eos_token_id[0] | |
| base_model.config.pad_token_id = base_model.config.eos_token_id | |
| print("2. Creating SCLM...") | |
| config = SCLMConfigB( | |
| vocab_size=base_model.config.vocab_size, | |
| hidden_size=base_model.config.hidden_size, | |
| num_hidden_layers=base_model.config.num_hidden_layers, | |
| num_attention_heads=base_model.config.num_attention_heads, | |
| ) | |
| sclm = SCLMModel(config, base_model) | |
| print("3. Loading EARCP weights...") | |
| try: | |
| wpath = hf_hub_download(repo_id=SCLM_REPO, filename="earcp_weights.pt", token=HF_TOKEN) | |
| sclm.earcp.load_state_dict(torch.load(wpath, map_location="cpu"), strict=False) | |
| USE_SCLM = True | |
| print("EARCP loaded!") | |
| except: | |
| USE_SCLM = False | |
| print("Ananke ready!") | |
| history = [] | |
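# chat() rebuilds the Llama 3 chat prompt from the system prompt plus the last
# 10 turns, optionally runs it through the SCLM wrapper to refresh the latent
# state, then samples a reply from the base model.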
def chat(message, temp=0.7, max_tok=1024):
    global history
    if not message.strip():
        return ""
    history.append(("user", message))
    prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n" + SYSTEM_PROMPT + "<|eot_id|>"
    for role, content in history[-10:]:
        prompt += f"<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>"
    prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    inputs = tokenizer(prompt, return_tensors="pt").to(base_model.device)
    if USE_SCLM:
        # Update the latent state from the full prompt before generating.
        with torch.no_grad():
            sclm(inputs.input_ids, inputs.attention_mask)
    eos = tokenizer.eos_token_id
    with torch.no_grad():
        out = base_model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=int(max_tok),
            temperature=float(temp),
            do_sample=True,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=eos,
            eos_token_id=eos,
        )
    # Decode only the newly generated tokens. Scrubbing the full decode with
    # string replaces (stripping "user", "system", ":") would also mangle
    # legitimate occurrences of those substrings in the reply.
    resp = tokenizer.decode(out[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True).strip()
    resp = resp or "..."
    history.append(("assistant", resp))
    return resp
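# clear() wipes the chat history and resets the SCLM latent state.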
def clear():
    global history
    history = []
    if USE_SCLM:
        sclm.reset()
    return ""
with gr.Blocks() as demo:
    gr.Markdown("# 🔮 Ananké\nAI assistant with contextual memory | SCLM architecture by Mike Amega")
    with gr.Row():
        with gr.Column():
            output = gr.Textbox(label="Response", lines=15)
            inp = gr.Textbox(label="Message", lines=2, placeholder="Talk with Ananké...")
        with gr.Column():
            temp = gr.Slider(0.1, 1.5, 0.7, label="Creativity (temperature)")
            tokens = gr.Slider(256, 2048, 1024, label="Max length (tokens)")
            btn = gr.Button("Send", variant="primary")
            clr = gr.Button("Clear")
    btn.click(chat, [inp, temp, tokens], output)
    inp.submit(chat, [inp, temp, tokens], output)
    clr.click(clear, outputs=output)

demo.queue().launch()