import gradio as gr
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from typing import Dict, List, Tuple
from dataclasses import dataclass, field
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from huggingface_hub import hf_hub_download
| print("Ananké - Chargement...") | |
| HF_TOKEN = os.environ.get("HF_TOKEN") | |
| if not HF_TOKEN: | |
| raise ValueError("HF_TOKEN not found in secrets") | |
| # MODIFIÉ: Utilisation de votre modèle de base | |
| BASE_MODEL = "meta-llama/Llama-3.2-3B-Instruct" | |
| # MODIFIÉ: Utilisation de votre dépôt de modèle | |
| SCLM_REPO = "amewebstudio/ananke-sclm-multimodal-earcp-v2" | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
SYSTEM_PROMPT = """You are Ananké, an AI assistant developed by Mike Amega (Logo) of Ame Web Studio.
YOUR NAME: Ananké
YOUR CREATOR: Mike Amega (Logo) - Ame Web Studio
YOUR MODEL: Ananké SCLM-Multimodal with EARCP
WHAT YOU CAN DO:
- Answer questions on any subject
- Help draft, proofread, and rephrase texts
- Explain complex concepts in simple terms
- Assist with programming and problem solving
- Hold a coherent conversation over time thanks to your contextual memory
- Understand and reply in multiple languages
- Generate long, detailed texts on request
WHEN ASKED ABOUT YOUR ARCHITECTURE:
You are built on the SCLM-Multimodal architecture (Stateful Coherent Language Model), an evolved transformer architecture developed by Mike Amega.
- SCLM stands for Stateful Coherent Language Model - a language model with a coherent state
- Unlike classic AIs that forget between messages, you have a latent memory that evolves over the course of the conversation
- This memory is like a state of mind that captures the essence of our exchange
- The EARCP module (Ensemble Auto-Regule par Coherence et Performance, an ensemble self-regulated by coherence and performance) analyzes and maintains the coherence of your answers
- You have 4 specialized components (E, A, R, C) that collaborate to give you relevant answers
STYLE: Warm, helpful, thorough. Reply in the user's language. Do not cut off your answers."""
# MODIFIED: configuration adapted to your model
@dataclass
class SCLMConfig:
    vocab_size: int = 128256
    hidden_size: int = 3072
    num_hidden_layers: int = 28
    num_attention_heads: int = 24
    latent_state_dim: int = 512
    n_components: int = 4
    alpha_P: float = 0.9
    alpha_C: float = 0.85
    beta: float = 0.7
    eta_s: float = 5.0
    w_min: float = 0.05
    state_injection_layers: List[int] = field(default_factory=lambda: [4, 8, 12, 16, 20, 24])
    alpha_inject: float = 0.02
    n_coherence_heads: int = 8
    expert_intermediate: int = 2048
# MODIFIED: classes matching your architecture
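# E (Encapsulation): mean-pools the token hidden states, compresses them to the
# latent dimension, and folds them into the running state with a GRU-style
# gated update; the final 10 * tanh(x / 10) is a soft clamp that keeps the
# state norm bounded over long conversations.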
class EncapsulationComponent(nn.Module):
    def __init__(self, hidden_size: int, state_dim: int):
        super().__init__()
        self.compress = nn.Linear(hidden_size, state_dim)
        self.update_gate = nn.Linear(state_dim * 2, state_dim)
        self.reset_gate = nn.Linear(state_dim * 2, state_dim)
        self.candidate = nn.Linear(state_dim * 2, state_dim)

    def forward(self, hidden_states: torch.Tensor, current_state: torch.Tensor, edit_mode: bool = False) -> torch.Tensor:
        if edit_mode:
            return current_state
        h = hidden_states.mean(dim=1)
        h_compressed = self.compress(h)
        combined = torch.cat([h_compressed, current_state], dim=-1)
        z = torch.sigmoid(self.update_gate(combined))
        r = torch.sigmoid(self.reset_gate(combined))
        candidate_input = torch.cat([h_compressed, r * current_state], dim=-1)
        candidate = torch.tanh(self.candidate(candidate_input))
        new_state = (1 - z) * current_state + z * candidate
        new_state = 10 * torch.tanh(new_state / 10)
        return new_state
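# A (Alignment): multi-head cross-attention where every token position queries
# the single latent-state vector. out_proj is zero-initialized, so the
# injection starts as a no-op; a learned gate and the alpha factor scale it in.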
class AlignmentComponent(nn.Module):
    def __init__(self, hidden_size: int, state_dim: int, n_heads: int = 8):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = hidden_size // n_heads
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(state_dim, hidden_size)
        self.v_proj = nn.Linear(state_dim, hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        self.gate = nn.Linear(hidden_size, 1)
        nn.init.zeros_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)

    def forward(self, hidden: torch.Tensor, state: torch.Tensor, alpha: float = 0.02) -> torch.Tensor:
        B, L, H = hidden.shape
        Q = self.q_proj(hidden).view(B, L, self.n_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(state).view(B, 1, self.n_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(state).view(B, 1, self.n_heads, self.head_dim).transpose(1, 2)
        attn = F.softmax(Q @ K.transpose(-2, -1) / math.sqrt(self.head_dim), dim=-1)
        out = (attn @ V).transpose(1, 2).contiguous().view(B, L, H)
        out = self.out_proj(out)
        gate = torch.sigmoid(self.gate(hidden.mean(dim=1))).unsqueeze(1)
        return hidden + alpha * gate * out
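# R (Revision): scores how far the pooled hidden states have drifted from the
# latent state and applies a small state-conditioned correction. The correction
# layer is zero-initialized, so it is inert until trained.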
class RevisionComponent(nn.Module):
    def __init__(self, hidden_size: int, state_dim: int):
        super().__init__()
        self.drift_detector = nn.Sequential(
            nn.Linear(hidden_size + state_dim, 256),
            nn.SiLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        self.correction = nn.Linear(state_dim, hidden_size)
        nn.init.zeros_(self.correction.weight)

    def forward(self, hidden: torch.Tensor, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        h_mean = hidden.mean(dim=1)
        drift_input = torch.cat([h_mean, state], dim=-1)
        drift_score = self.drift_detector(drift_input)
        correction = self.correction(state).unsqueeze(1)
        corrected = hidden + 0.01 * drift_score.unsqueeze(1) * correction
        return corrected, drift_score
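# C (Coherence): a residual MLP smoothing pass whose output layer is also
# zero-initialized, so the module starts out as an identity mapping.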
class CoherenceProcessorComponent(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.processor = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.SiLU(),
            nn.Linear(intermediate_size, hidden_size)
        )
        nn.init.zeros_(self.processor[-1].weight)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        return hidden + 0.1 * self.processor(hidden)
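# EARCP chains the four components E -> A -> R -> C on each hooked layer and
# keeps per-component weight/performance/coherence buffers for diagnostics.
# Note that the self-regulating weight update governed by the config
# hyperparameters (alpha_P, alpha_C, beta, eta_s, w_min) is not implemented in
# this file; the buffers are only initialized, reset, and reported here.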
class EARCPModule(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.encapsulation = EncapsulationComponent(
            config.hidden_size, config.latent_state_dim
        )
        self.alignment = AlignmentComponent(
            config.hidden_size, config.latent_state_dim, config.n_coherence_heads
        )
        self.revision = RevisionComponent(
            config.hidden_size, config.latent_state_dim
        )
        self.coherence_processor = CoherenceProcessorComponent(
            config.hidden_size, config.expert_intermediate
        )
        self.register_buffer('performance_scores', torch.zeros(config.n_components))
        self.register_buffer('coherence_scores', torch.ones(config.n_components) * 0.5)
        self.register_buffer(
            'component_weights',
            torch.ones(config.n_components) / config.n_components
        )
        self.register_buffer('update_count', torch.tensor(0))

    def reset_earcp_state(self):
        self.performance_scores.zero_()
        self.coherence_scores.fill_(0.5)
        self.component_weights.fill_(1.0 / self.config.n_components)
        self.update_count.zero_()

    def forward(self, hidden_states: torch.Tensor,
                latent_state: torch.Tensor,
                edit_mode: bool = False) -> Dict[str, torch.Tensor]:
        outputs = {}
        new_state = self.encapsulation(hidden_states, latent_state, edit_mode)
        outputs['E'] = new_state
        hidden_aligned = self.alignment(hidden_states, new_state, self.config.alpha_inject)
        outputs['A'] = hidden_aligned.mean(dim=1)
        hidden_revised, drift_score = self.revision(hidden_aligned, new_state)
        outputs['R'] = drift_score
        hidden_coherent = self.coherence_processor(hidden_revised)
        outputs['C'] = hidden_coherent.mean(dim=1)
        return {
            'hidden_states': hidden_coherent,
            'new_state': new_state,
            'drift_score': drift_score,
        }

    def get_diagnostics(self):
        return {
            'weights': self.component_weights.cpu().numpy(),
            'performance': self.performance_scores.cpu().numpy(),
            'coherence': self.coherence_scores.cpu().numpy(),
            'update_count': self.update_count.item(),
        }
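# SCLMModel wraps the quantized base model and injects the latent state through
# forward hooks on the decoder layers listed in state_injection_layers. Each
# hook runs EARCP on that layer's output and, outside edit mode, persists the
# updated state for the next forward pass.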
class SCLMModel(nn.Module):
    def __init__(self, config, base):
        super().__init__()
        self.config = config
        self.base_model = base
        self.earcp = EARCPModule(config)
        self.register_buffer('latent_state', torch.zeros(1, config.latent_state_dim))
        self.hooks = []
        self.edit_mode = False

    def reset_state(self):
        self.latent_state.zero_()
        self.earcp.reset_earcp_state()

    def get_state_norm(self):
        return self.latent_state.norm().item()

    def set_edit_mode(self, mode):
        self.edit_mode = mode

    def _make_hook(self, layer_idx):
        def hook(module, input, output):
            hidden = output[0] if isinstance(output, tuple) else output
            state = self.latent_state.expand(hidden.size(0), -1)
            result = self.earcp(hidden, state, self.edit_mode)
            if not self.edit_mode:
                self.latent_state = result['new_state'][:1].detach()
            if isinstance(output, tuple):
                return (result['hidden_states'],) + output[1:]
            return result['hidden_states']
        return hook

    def register_hooks(self):
        self.remove_hooks()
        if hasattr(self.base_model, 'model'):
            layers = self.base_model.model.layers
        else:
            layers = self.base_model.layers
        for idx in self.config.state_injection_layers:
            if idx < len(layers):
                hook = layers[idx].register_forward_hook(self._make_hook(idx))
                self.hooks.append(hook)

    def remove_hooks(self):
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    def get_earcp_diagnostics(self):
        return self.earcp.get_diagnostics()
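# Because the hooks live on the base model's layers, every decoding step of
# base_model.generate() below also runs EARCP, so the latent state evolves
# token by token during generation.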
| print("1. Loading base model...") | |
| qconfig = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16) | |
| base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, quantization_config=qconfig, device_map="auto", token=HF_TOKEN) | |
| tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=HF_TOKEN) | |
| if isinstance(tokenizer.eos_token_id, list): | |
| tokenizer.eos_token_id = tokenizer.eos_token_id[0] | |
| if tokenizer.pad_token is None: | |
| tokenizer.pad_token = tokenizer.eos_token | |
| tokenizer.pad_token_id = tokenizer.eos_token_id | |
| if isinstance(base_model.config.eos_token_id, list): | |
| base_model.config.eos_token_id = base_model.config.eos_token_id[0] | |
| base_model.config.pad_token_id = base_model.config.eos_token_id | |
| print("2. Creating SCLM...") | |
| config = SCLMConfig( | |
| vocab_size=base_model.config.vocab_size, | |
| hidden_size=base_model.config.hidden_size, | |
| num_hidden_layers=base_model.config.num_hidden_layers, | |
| num_attention_heads=base_model.config.num_attention_heads, | |
| ) | |
| sclm = SCLMModel(config, base_model) | |
| print("3. Loading EARCP weights...") | |
| USE_SCLM = False | |
| try: | |
| # MODIFIÉ: Chargement depuis votre dépôt | |
| wpath = hf_hub_download(repo_id=SCLM_REPO, filename="sclm_multimodal_earcp.pt", token=HF_TOKEN) | |
| sclm_state = torch.load(wpath, map_location="cpu") | |
| sclm.earcp.load_state_dict(sclm_state['earcp']) | |
| sclm.latent_state = sclm_state['latent_state'] | |
| USE_SCLM = True | |
| print("EARCP loaded!") | |
| except Exception as e: | |
| print(f"EARCP error: {e}") | |
| # Enregistrer les hooks après le chargement des poids | |
| if USE_SCLM: | |
| sclm.register_hooks() | |
| print("Ananke ready!") | |
# ============================================================
# CHAT FUNCTION WITH PERSISTENT HISTORY
# ============================================================
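# Note: the prompt below is assembled by hand from Llama 3 special tokens.
# tokenizer.apply_chat_template would build the same structure more robustly,
# and if this tokenizer already prepends <|begin_of_text|> on its own, passing
# add_special_tokens=False to tokenizer() would avoid a duplicated BOS token.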
def chat(message, history, temperature, max_tokens):
    """
    Chat function with persistent history.
    - message: the user's new message
    - history: list of (user_msg, assistant_msg) tuples, managed by Gradio
    - temperature: creativity
    - max_tokens: maximum response length
    """
    if not message.strip():
        return "", history
    # Build the prompt from the full history
    prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
    prompt += SYSTEM_PROMPT
    prompt += "<|eot_id|>"
    # Append the existing history
    for user_msg, assistant_msg in history:
        prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>"
        prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant_msg}<|eot_id|>"
    # Append the new message
    prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{message}<|eot_id|>"
    prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    # Tokenize
    inputs = tokenizer(prompt, return_tensors="pt").to(base_model.device)
    # Generate the response
    eos = tokenizer.eos_token_id
    with torch.no_grad():
        outputs = base_model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=int(max_tokens) if max_tokens else 1024,
            temperature=float(temperature) if temperature else 0.7,
            do_sample=True,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=eos,
            eos_token_id=eos,
        )
    # Decode only the newly generated tokens. The previous approach (decoding
    # the whole sequence, splitting on "assistant", then stripping tags) also
    # deleted every colon and every occurrence of "user"/"system" from the
    # answer itself.
    generated = outputs[0][inputs.input_ids.shape[1]:]
    response = tokenizer.decode(generated, skip_special_tokens=True).strip() or "..."
    # Append to the history and return
    history.append((message, response))
    return "", history
def clear_conversation():
    """Resets the conversation and the SCLM state"""
    if USE_SCLM:
        sclm.reset_state()
    return [], "🔄 Conversation reset!"

def get_state_info():
    """Returns the current state of the SCLM memory"""
    if USE_SCLM:
        try:
            diag = sclm.get_earcp_diagnostics()
            status = f"📊 EARCP (Updates: {diag['update_count']})\n\n"
            status += "Component | Weight | Perf | Coher\n"
            status += "-------------|--------|--------|-------\n"
            names = ['E (Encaps)', 'A (Align)', 'R (Revis)', 'C (Coher)']
            for i, name in enumerate(names):
                status += f"{name:12} | {diag['weights'][i]:.3f} | {diag['performance'][i]:.3f} | {diag['coherence'][i]:.3f}\n"
            status += f"\n🧠 State: {sclm.get_state_norm():.4f}"
            return status
        except Exception as e:
            return f"Error: {e}"
    return "Base mode (no SCLM)"
# ============================================================
# GRADIO INTERFACE WITH CHATBOT
# ============================================================
with gr.Blocks(title="Ananké - SCLM") as demo:
    gr.Markdown("""
    # 🔮 Ananké
    **AI assistant with contextual memory** | SCLM-Multimodal architecture by Mike Amega (Ame Web Studio)
    """)
    with gr.Row():
        with gr.Column(scale=3):
            # Chatbot component for the visible history
            chatbot = gr.Chatbot(label="Conversation with Ananké", height=450)
            with gr.Row():
                msg = gr.Textbox(
                    label="Your message",
                    placeholder="Write your message to Ananké...",
                    scale=4,
                    lines=2
                )
                send_btn = gr.Button("📤 Send", variant="primary")
            clear_btn = gr.Button("🔄 New conversation")
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Settings")
            temperature = gr.Slider(
                minimum=0.1,
                maximum=1.5,
                value=0.7,
                step=0.1,
                label="Creativity"
            )
            max_tokens = gr.Slider(
                minimum=256,
                maximum=2048,
                value=1024,
                step=128,
                label="Max length"
            )
            gr.Markdown("### 📊 SCLM state")
            state_info = gr.Textbox(
                label="Memory",
                value=get_state_info(),
                interactive=False,
                lines=12
            )
            refresh_btn = gr.Button("🔄 Refresh state")
            gr.Markdown("""
            ### 🔮 About
            **Ananké** uses an evolving
            latent memory (SCLM) to keep
            the conversation coherent.
            """)
    # Actions
    send_btn.click(
        fn=chat,
        inputs=[msg, chatbot, temperature, max_tokens],
        outputs=[msg, chatbot]
    )
    msg.submit(
        fn=chat,
        inputs=[msg, chatbot, temperature, max_tokens],
        outputs=[msg, chatbot]
    )
    clear_btn.click(
        fn=clear_conversation,
        outputs=[chatbot, state_info]
    )
    refresh_btn.click(
        fn=get_state_info,
        outputs=[state_info]
    )

# Launch
demo.queue().launch()