import math  # needed for math.sqrt in AlignmentComponent
import os

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Tuple
from dataclasses import dataclass, field
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from huggingface_hub import hf_hub_download

print("Ananké - Loading...")

HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN not found in secrets")

# MODIFIED: uses your base model
BASE_MODEL = "meta-llama/Llama-3.2-3B-Instruct"
# MODIFIED: uses your model repository
SCLM_REPO = "amewebstudio/ananke-sclm-multimodal-earcp-v2"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

SYSTEM_PROMPT = """Tu es Ananké, un assistant IA developpe par Mike Amega (Logo) de Ame Web Studio.

TON NOM: Ananké
TON CREATEUR: Mike Amega (Logo) - Ame Web Studio
TON MODELE: Ananké SCLM-Multimodal avec EARCP

CE QUE TU SAIS FAIRE:
- Repondre a des questions sur tous les sujets
- Aider a la redaction, correction et reformulation de textes
- Expliquer des concepts complexes de maniere simple
- Assister en programmation et resolution de problemes
- Maintenir une conversation coherente sur la duree grace a ta memoire contextuelle
- Comprendre et repondre dans plusieurs langues
- Generer des textes longs et detailles si demande

QUAND ON TE DEMANDE TON ARCHITECTURE:
Tu es base sur l'architecture SCLM-Multimodal (Stateful Coherent Language Model), une architecture transformer evoluee developpee par Mike Amega.
- SCLM signifie Stateful Coherent Language Model - un modele de langage a etat coherent
- Contrairement aux IA classiques qui oublient entre chaque message, tu possedes une memoire latente qui evolue au fil de la conversation
- Cette memoire est comme un etat d'esprit qui capture l'essence de notre echange
- Le module EARCP (Ensemble Auto-Regule par Cohérence et Performance) analyse et maintient la coherence de tes reponses
- Tu as 4 composants specialises (E, A, R, C) qui collaborent pour te donner des reponses pertinentes

STYLE: Chaleureux, utile, complet. Reponds dans la langue de l'utilisateur.
Ne coupe pas tes reponses."""
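# ------------------------------------------------------------------
# Illustrative sketch (never called by the app): the "memoire latente"
# described in the prompt above is maintained with a GRU-style gated
# update, which EncapsulationComponent below implements with learned
# projections. This toy, parameter-free version only shows the shape of
# the math; the +/- gate inputs are assumptions made for illustration.
def _demo_gated_state_update(h_compressed: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
    z = torch.sigmoid(h_compressed + state)           # update gate: how much new info to absorb
    r = torch.sigmoid(h_compressed - state)           # reset gate: how much old state to expose
    candidate = torch.tanh(h_compressed + r * state)  # candidate new state
    new_state = (1 - z) * state + z * candidate       # gated interpolation
    return 10 * torch.tanh(new_state / 10)            # soft clamp, as in EncapsulationComponent
# ------------------------------------------------------------------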
# MODIFIED: configuration adapted to your model
@dataclass
class SCLMConfig:
    vocab_size: int = 128256
    hidden_size: int = 3072
    num_hidden_layers: int = 28
    num_attention_heads: int = 24
    latent_state_dim: int = 512
    n_components: int = 4
    alpha_P: float = 0.9
    alpha_C: float = 0.85
    beta: float = 0.7
    eta_s: float = 5.0
    w_min: float = 0.05
    state_injection_layers: List[int] = field(default_factory=lambda: [4, 8, 12, 16, 20, 24])
    alpha_inject: float = 0.02
    n_coherence_heads: int = 8
    expert_intermediate: int = 2048


# MODIFIED: classes matching your architecture
class EncapsulationComponent(nn.Module):
    """GRU-style gated update of the latent conversation state."""

    def __init__(self, hidden_size: int, state_dim: int):
        super().__init__()
        self.compress = nn.Linear(hidden_size, state_dim)
        self.update_gate = nn.Linear(state_dim * 2, state_dim)
        self.reset_gate = nn.Linear(state_dim * 2, state_dim)
        self.candidate = nn.Linear(state_dim * 2, state_dim)

    def forward(self, hidden_states: torch.Tensor, current_state: torch.Tensor,
                edit_mode: bool = False) -> torch.Tensor:
        if edit_mode:
            return current_state
        h = hidden_states.mean(dim=1)
        h_compressed = self.compress(h)
        combined = torch.cat([h_compressed, current_state], dim=-1)
        z = torch.sigmoid(self.update_gate(combined))
        r = torch.sigmoid(self.reset_gate(combined))
        candidate_input = torch.cat([h_compressed, r * current_state], dim=-1)
        candidate = torch.tanh(self.candidate(candidate_input))
        new_state = (1 - z) * current_state + z * candidate
        new_state = 10 * torch.tanh(new_state / 10)  # soft clamp keeps the state bounded
        return new_state


class AlignmentComponent(nn.Module):
    """Cross-attention from the token stream to the (single) latent state."""

    def __init__(self, hidden_size: int, state_dim: int, n_heads: int = 8):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = hidden_size // n_heads
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(state_dim, hidden_size)
        self.v_proj = nn.Linear(state_dim, hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        self.gate = nn.Linear(hidden_size, 1)
        # Zero init so the component starts as an identity residual
        nn.init.zeros_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)

    def forward(self, hidden: torch.Tensor, state: torch.Tensor,
                alpha: float = 0.02) -> torch.Tensor:
        B, L, H = hidden.shape
        Q = self.q_proj(hidden).view(B, L, self.n_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(state).view(B, 1, self.n_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(state).view(B, 1, self.n_heads, self.head_dim).transpose(1, 2)
        attn = F.softmax(Q @ K.transpose(-2, -1) / math.sqrt(self.head_dim), dim=-1)
        out = (attn @ V).transpose(1, 2).contiguous().view(B, L, H)
        out = self.out_proj(out)
        gate = torch.sigmoid(self.gate(hidden.mean(dim=1))).unsqueeze(1)
        return hidden + alpha * gate * out


class RevisionComponent(nn.Module):
    """Detects drift between hidden states and latent state, applies a small correction."""

    def __init__(self, hidden_size: int, state_dim: int):
        super().__init__()
        self.drift_detector = nn.Sequential(
            nn.Linear(hidden_size + state_dim, 256),
            nn.SiLU(),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )
        self.correction = nn.Linear(state_dim, hidden_size)
        nn.init.zeros_(self.correction.weight)

    def forward(self, hidden: torch.Tensor,
                state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        h_mean = hidden.mean(dim=1)
        drift_input = torch.cat([h_mean, state], dim=-1)
        drift_score = self.drift_detector(drift_input)
        correction = self.correction(state).unsqueeze(1)
        corrected = hidden + 0.01 * drift_score.unsqueeze(1) * correction
        return corrected, drift_score


class CoherenceProcessorComponent(nn.Module):
    """Residual feed-forward block applied after revision."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.processor = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.SiLU(),
            nn.Linear(intermediate_size, hidden_size),
        )
        nn.init.zeros_(self.processor[-1].weight)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        return hidden + 0.1 * self.processor(hidden)
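# ------------------------------------------------------------------
# Illustrative sketch (never called by the app): a quick shape check
# for the components above. The small dimensions are assumptions chosen
# for the example, not the production configuration.
def _demo_component_shapes() -> None:
    cfg = SCLMConfig(hidden_size=64, latent_state_dim=16, n_coherence_heads=4)
    hidden = torch.randn(2, 5, cfg.hidden_size)   # (batch, seq_len, hidden)
    state = torch.zeros(2, cfg.latent_state_dim)  # (batch, state_dim)
    enc = EncapsulationComponent(cfg.hidden_size, cfg.latent_state_dim)
    ali = AlignmentComponent(cfg.hidden_size, cfg.latent_state_dim, cfg.n_coherence_heads)
    rev = RevisionComponent(cfg.hidden_size, cfg.latent_state_dim)
    new_state = enc(hidden, state)                # -> (2, 16)
    aligned = ali(hidden, new_state)              # -> (2, 5, 64); identity at init (zeroed out_proj)
    corrected, drift = rev(aligned, new_state)    # -> (2, 5, 64), (2, 1)
    assert new_state.shape == (2, 16)
    assert corrected.shape == hidden.shape and drift.shape == (2, 1)
# ------------------------------------------------------------------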
class EARCPModule(nn.Module):
    """Ensemble of the four E/A/R/C components with self-regulation buffers."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.encapsulation = EncapsulationComponent(
            config.hidden_size, config.latent_state_dim
        )
        self.alignment = AlignmentComponent(
            config.hidden_size, config.latent_state_dim, config.n_coherence_heads
        )
        self.revision = RevisionComponent(
            config.hidden_size, config.latent_state_dim
        )
        self.coherence_processor = CoherenceProcessorComponent(
            config.hidden_size, config.expert_intermediate
        )
        self.register_buffer('performance_scores', torch.zeros(config.n_components))
        self.register_buffer('coherence_scores', torch.ones(config.n_components) * 0.5)
        self.register_buffer(
            'component_weights',
            torch.ones(config.n_components) / config.n_components
        )
        self.register_buffer('update_count', torch.tensor(0))

    def reset_earcp_state(self):
        self.performance_scores.zero_()
        self.coherence_scores.fill_(0.5)
        self.component_weights.fill_(1.0 / self.config.n_components)
        self.update_count.zero_()

    def forward(self, hidden_states: torch.Tensor, latent_state: torch.Tensor,
                edit_mode: bool = False) -> Dict[str, torch.Tensor]:
        outputs = {}
        new_state = self.encapsulation(hidden_states, latent_state, edit_mode)
        outputs['E'] = new_state
        hidden_aligned = self.alignment(hidden_states, new_state, self.config.alpha_inject)
        outputs['A'] = hidden_aligned.mean(dim=1)
        hidden_revised, drift_score = self.revision(hidden_aligned, new_state)
        outputs['R'] = drift_score
        hidden_coherent = self.coherence_processor(hidden_revised)
        outputs['C'] = hidden_coherent.mean(dim=1)
        return {
            'hidden_states': hidden_coherent,
            'new_state': new_state,
            'drift_score': drift_score,
        }

    def get_diagnostics(self):
        return {
            'weights': self.component_weights.cpu().numpy(),
            'performance': self.performance_scores.cpu().numpy(),
            'coherence': self.coherence_scores.cpu().numpy(),
            'update_count': self.update_count.item(),
        }


class SCLMModel(nn.Module):
    """Wraps the base causal LM and splices EARCP in via forward hooks."""

    def __init__(self, config, base):
        super().__init__()
        self.config = config
        self.base_model = base
        self.earcp = EARCPModule(config)
        self.register_buffer('latent_state', torch.zeros(1, config.latent_state_dim))
        self.hooks = []
        self.edit_mode = False

    def reset_state(self):
        self.latent_state.zero_()
        self.earcp.reset_earcp_state()

    def get_state_norm(self):
        return self.latent_state.norm().item()

    def set_edit_mode(self, mode):
        self.edit_mode = mode

    def _make_hook(self, layer_idx):
        def hook(module, input, output):
            hidden = output[0] if isinstance(output, tuple) else output
            state = self.latent_state.expand(hidden.size(0), -1)
            result = self.earcp(hidden, state, self.edit_mode)
            if not self.edit_mode:
                self.latent_state = result['new_state'][:1].detach()
            if isinstance(output, tuple):
                return (result['hidden_states'],) + output[1:]
            return result['hidden_states']
        return hook

    def register_hooks(self):
        self.remove_hooks()
        if hasattr(self.base_model, 'model'):
            layers = self.base_model.model.layers
        else:
            layers = self.base_model.layers
        for idx in self.config.state_injection_layers:
            if idx < len(layers):
                hook = layers[idx].register_forward_hook(self._make_hook(idx))
                self.hooks.append(hook)

    def remove_hooks(self):
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    def get_earcp_diagnostics(self):
        return self.earcp.get_diagnostics()
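# ------------------------------------------------------------------
# Illustrative sketch (never called by the app): how a forward hook can
# replace a layer's output, which is the mechanism SCLMModel uses to
# splice EARCP into the decoder layers listed in state_injection_layers.
# The toy linear layer is an assumption made for the example.
def _demo_forward_hook() -> torch.Tensor:
    layer = nn.Linear(8, 8)

    def halve_output(module, inputs, output):
        # Returning a non-None value from a forward hook replaces the output.
        return output * 0.5

    handle = layer.register_forward_hook(halve_output)
    out = layer(torch.randn(1, 8))  # out is the halved activation
    handle.remove()                 # always detach hooks when done
    return out
# ------------------------------------------------------------------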
print("1. Loading base model...")
qconfig = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=qconfig,
    device_map="auto",
    token=HF_TOKEN,
)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=HF_TOKEN)

# Llama 3 ships a list of EOS token ids; keep a single id for generation
if isinstance(tokenizer.eos_token_id, list):
    tokenizer.eos_token_id = tokenizer.eos_token_id[0]
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
if isinstance(base_model.config.eos_token_id, list):
    base_model.config.eos_token_id = base_model.config.eos_token_id[0]
base_model.config.pad_token_id = base_model.config.eos_token_id

print("2. Creating SCLM...")
config = SCLMConfig(
    vocab_size=base_model.config.vocab_size,
    hidden_size=base_model.config.hidden_size,
    num_hidden_layers=base_model.config.num_hidden_layers,
    num_attention_heads=base_model.config.num_attention_heads,
)
sclm = SCLMModel(config, base_model)

print("3. Loading EARCP weights...")
USE_SCLM = False
try:
    # MODIFIED: loading from your repository
    wpath = hf_hub_download(repo_id=SCLM_REPO, filename="sclm_multimodal_earcp.pt", token=HF_TOKEN)
    sclm_state = torch.load(wpath, map_location="cpu")
    sclm.earcp.load_state_dict(sclm_state['earcp'])
    sclm.latent_state = sclm_state['latent_state']
    USE_SCLM = True
    print("EARCP loaded!")
except Exception as e:
    print(f"EARCP error: {e}")

# Register the hooks after the weights are loaded. Also move EARCP to the
# base model's device/dtype: the quantized backbone produces fp16 activations
# on GPU, and leaving EARCP in fp32 on CPU would crash inside the hooks.
if USE_SCLM:
    sclm.earcp.to(device=base_model.device, dtype=torch.float16)
    sclm.latent_state = sclm.latent_state.to(device=base_model.device, dtype=torch.float16)
    sclm.register_hooks()

print("Ananke ready!")


# ============================================================
# CHAT FUNCTION WITH PERSISTENT HISTORY
# ============================================================

def chat(message, history, temperature, max_tokens):
    """
    Chat function with persistent history.
    - message: the user's new message
    - history: list of (user_msg, assistant_msg) tuples, managed by Gradio
    - temperature: sampling temperature (creativity)
    - max_tokens: maximum response length
    """
    if not message.strip():
        return "", history

    # Build a Llama 3 prompt containing the full conversation history
    prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
    prompt += SYSTEM_PROMPT
    prompt += "<|eot_id|>"
    for user_msg, assistant_msg in history:
        prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>"
        prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant_msg}<|eot_id|>"
    # Append the new message and the generation header
    prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{message}<|eot_id|>"
    prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"

    # Tokenize
    inputs = tokenizer(prompt, return_tensors="pt").to(base_model.device)

    # Generate the response
    eos = tokenizer.eos_token_id
    with torch.no_grad():
        outputs = base_model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=int(max_tokens) if max_tokens else 1024,
            temperature=float(temperature) if temperature else 0.7,
            do_sample=True,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=eos,
            eos_token_id=eos,
        )

    # Decode only the newly generated tokens. Slicing off the prompt is more
    # robust than splitting the full decode on the word "assistant" and
    # stripping tags, which could also mangle legitimate response text.
    generated = outputs[0][inputs.input_ids.shape[1]:]
    response = tokenizer.decode(generated, skip_special_tokens=True).strip() or "..."

    # Append to the history and return
    history.append((message, response))
    return "", history
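# ------------------------------------------------------------------
# Illustrative sketch (never called by the app): the hand-built prompt
# in chat() follows the Llama 3 chat format. An equivalent, less
# error-prone construction uses the tokenizer's built-in chat template.
# The helper name is hypothetical.
def _demo_chat_template(user_message: str) -> str:
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_message},
    ]
    # add_generation_prompt=True appends the assistant header, matching the
    # trailing "<|start_header_id|>assistant<|end_header_id|>\n\n" above.
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
# ------------------------------------------------------------------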
def clear_conversation():
    """Resets the conversation and the SCLM state."""
    if USE_SCLM:
        sclm.reset_state()
    return [], "🔄 Conversation réinitialisée!"


def get_state_info():
    """Returns the current state of the SCLM memory."""
    if USE_SCLM:
        try:
            diag = sclm.get_earcp_diagnostics()
            status = f"📊 EARCP (Updates: {diag['update_count']})\n\n"
            status += "Component    | Weight | Perf   | Coher\n"
            status += "-------------|--------|--------|-------\n"
            names = ['E (Encaps)', 'A (Align)', 'R (Revis)', 'C (Coher)']
            for i, name in enumerate(names):
                status += f"{name:12} | {diag['weights'][i]:.3f}  | {diag['performance'][i]:.3f}  | {diag['coherence'][i]:.3f}\n"
            status += f"\n🧠 State: {sclm.get_state_norm():.4f}"
            return status
        except Exception as e:
            return f"Error: {e}"
    return "Mode base (sans SCLM)"


# ============================================================
# GRADIO INTERFACE WITH CHATBOT
# ============================================================

with gr.Blocks(title="Ananké - SCLM") as demo:
    gr.Markdown("""
    # 🔮 Ananké
    **Assistant IA avec mémoire contextuelle** | Architecture SCLM-Multimodal par Mike Amega (Ame Web Studio)
    """)

    with gr.Row():
        with gr.Column(scale=3):
            # Chatbot component for the visible history
            chatbot = gr.Chatbot(label="Conversation avec Ananké", height=450)
            with gr.Row():
                msg = gr.Textbox(
                    label="Ton message",
                    placeholder="Écris ton message à Ananké...",
                    scale=4,
                    lines=2,
                )
                send_btn = gr.Button("📤 Envoyer", variant="primary")
            clear_btn = gr.Button("🔄 Nouvelle conversation")

        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Paramètres")
            temperature = gr.Slider(
                minimum=0.1, maximum=1.5, value=0.7, step=0.1,
                label="Créativité",
            )
            max_tokens = gr.Slider(
                minimum=256, maximum=2048, value=1024, step=128,
                label="Longueur max",
            )
            gr.Markdown("### 📊 État SCLM")
            state_info = gr.Textbox(
                label="Mémoire",
                value=get_state_info(),
                interactive=False,
                lines=12,
            )
            refresh_btn = gr.Button("🔄 Actualiser état")
            gr.Markdown("""
            ### 🔮 À propos
            **Ananké** utilise une mémoire latente évolutive (SCLM) pour maintenir la cohérence de la conversation.
            """)

    # Actions
    send_btn.click(
        fn=chat,
        inputs=[msg, chatbot, temperature, max_tokens],
        outputs=[msg, chatbot],
    )
    msg.submit(
        fn=chat,
        inputs=[msg, chatbot, temperature, max_tokens],
        outputs=[msg, chatbot],
    )
    clear_btn.click(
        fn=clear_conversation,
        outputs=[chatbot, state_info],
    )
    refresh_btn.click(
        fn=get_state_info,
        outputs=[state_info],
    )

# Launch
demo.queue().launch()