amewebstudio commited on
Commit
4ad140d
·
verified ·
1 Parent(s): e7af545

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +504 -0
app.py ADDED
@@ -0,0 +1,504 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import os
6
+ from typing import List
7
+ from dataclasses import dataclass, field
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
9
+ from huggingface_hub import hf_hub_download
10
+
11
+ print("Ananké - Chargement...")
12
+
13
+ HF_TOKEN = os.environ.get("HF_TOKEN")
14
+ if not HF_TOKEN:
15
+ raise ValueError("HF_TOKEN not found in secrets")
16
+
17
+ # MODIFIÉ: Utilisation de votre modèle de base
18
+ BASE_MODEL = "meta-llama/Llama-3.2-3B-Instruct"
19
+ # MODIFIÉ: Utilisation de votre dépôt de modèle
20
+ SCLM_REPO = "amewebstudio/ananke-sclm-multimodal-earcp-v2"
21
+
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+
24
+ SYSTEM_PROMPT = """Tu es Ananké, un assistant IA developpe par Mike Amega (Logo) de Ame Web Studio.
25
+
26
+ TON NOM: Ananké
27
+ TON CREATEUR: Mike Amega (Logo) - Ame Web Studio
28
+ TON MODELE: Ananké SCLM-Multimodal avec EARCP
29
+
30
+ CE QUE TU SAIS FAIRE:
31
+ - Repondre a des questions sur tous les sujets
32
+ - Aider a la redaction, correction et reformulation de textes
33
+ - Expliquer des concepts complexes de maniere simple
34
+ - Assister en programmation et resolution de problemes
35
+ - Maintenir une conversation coherente sur la duree grace a ta memoire contextuelle
36
+ - Comprendre et repondre dans plusieurs langues
37
+ - Generer des textes longs et detailles si demande
38
+
39
+ QUAND ON TE DEMANDE TON ARCHITECTURE:
40
+ Tu es base sur l'architecture SCLM-Multimodal (Stateful Coherent Language Model), une architecture transformer evoluee developpee par Mike Amega.
41
+ - SCLM signifie Stateful Coherent Language Model - un modele de langage a etat coherent
42
+ - Contrairement aux IA classiques qui oublient entre chaque message, tu possedes une memoire latente qui evolue au fil de la conversation
43
+ - Cette memoire est comme un etat d'esprit qui capture l'essence de notre echange
44
+ - Le module EARCP (Ensemble Auto-Regule par Cohérence et Performance) analyse et maintient la coherence de tes reponses
45
+ - Tu as 4 composants specialises (E, A, R, C) qui collaborent pour te donner des reponses pertinentes
46
+
47
+ STYLE: Chaleureux, utile, complet. Reponds dans la langue de l'utilisateur. Ne coupe pas tes reponses."""
48
+
49
+ # MODIFIÉ: Configuration adaptée à votre modèle
50
+ @dataclass
51
+ class SCLMConfig:
52
+ vocab_size: int = 128256
53
+ hidden_size: int = 3072
54
+ num_hidden_layers: int = 28
55
+ num_attention_heads: int = 24
56
+ latent_state_dim: int = 512
57
+ n_components: int = 4
58
+ alpha_P: float = 0.9
59
+ alpha_C: float = 0.85
60
+ beta: float = 0.7
61
+ eta_s: float = 5.0
62
+ w_min: float = 0.05
63
+ state_injection_layers: List[int] = field(default_factory=lambda: [4, 8, 12, 16, 20, 24])
64
+ alpha_inject: float = 0.02
65
+ n_coherence_heads: int = 8
66
+ expert_intermediate: int = 2048
67
+
68
+ # MODIFIÉ: Classes correspondant à votre architecture
69
+ class EncapsulationComponent(nn.Module):
70
+ def __init__(self, hidden_size: int, state_dim: int):
71
+ super().__init__()
72
+ self.compress = nn.Linear(hidden_size, state_dim)
73
+ self.update_gate = nn.Linear(state_dim * 2, state_dim)
74
+ self.reset_gate = nn.Linear(state_dim * 2, state_dim)
75
+ self.candidate = nn.Linear(state_dim * 2, state_dim)
76
+
77
+ def forward(self, hidden_states: torch.Tensor, current_state: torch.Tensor, edit_mode: bool = False) -> torch.Tensor:
78
+ if edit_mode:
79
+ return current_state
80
+
81
+ h = hidden_states.mean(dim=1)
82
+ h_compressed = self.compress(h)
83
+
84
+ combined = torch.cat([h_compressed, current_state], dim=-1)
85
+ z = torch.sigmoid(self.update_gate(combined))
86
+ r = torch.sigmoid(self.reset_gate(combined))
87
+
88
+ candidate_input = torch.cat([h_compressed, r * current_state], dim=-1)
89
+ candidate = torch.tanh(self.candidate(candidate_input))
90
+
91
+ new_state = (1 - z) * current_state + z * candidate
92
+ new_state = 10 * torch.tanh(new_state / 10)
93
+
94
+ return new_state
95
+
96
+ class AlignmentComponent(nn.Module):
97
+ def __init__(self, hidden_size: int, state_dim: int, n_heads: int = 8):
98
+ super().__init__()
99
+ self.n_heads = n_heads
100
+ self.head_dim = hidden_size // n_heads
101
+
102
+ self.q_proj = nn.Linear(hidden_size, hidden_size)
103
+ self.k_proj = nn.Linear(state_dim, hidden_size)
104
+ self.v_proj = nn.Linear(state_dim, hidden_size)
105
+ self.out_proj = nn.Linear(hidden_size, hidden_size)
106
+
107
+ self.gate = nn.Linear(hidden_size, 1)
108
+
109
+ nn.init.zeros_(self.out_proj.weight)
110
+ nn.init.zeros_(self.out_proj.bias)
111
+
112
+ def forward(self, hidden: torch.Tensor, state: torch.Tensor, alpha: float = 0.02) -> torch.Tensor:
113
+ B, L, H = hidden.shape
114
+
115
+ Q = self.q_proj(hidden).view(B, L, self.n_heads, self.head_dim).transpose(1, 2)
116
+ K = self.k_proj(state).view(B, 1, self.n_heads, self.head_dim).transpose(1, 2)
117
+ V = self.v_proj(state).view(B, 1, self.n_heads, self.head_dim).transpose(1, 2)
118
+
119
+ attn = F.softmax(Q @ K.transpose(-2, -1) / math.sqrt(self.head_dim), dim=-1)
120
+ out = (attn @ V).transpose(1, 2).contiguous().view(B, L, H)
121
+ out = self.out_proj(out)
122
+
123
+ gate = torch.sigmoid(self.gate(hidden.mean(dim=1))).unsqueeze(1)
124
+
125
+ return hidden + alpha * gate * out
126
+
127
+ class RevisionComponent(nn.Module):
128
+ def __init__(self, hidden_size: int, state_dim: int):
129
+ super().__init__()
130
+
131
+ self.drift_detector = nn.Sequential(
132
+ nn.Linear(hidden_size + state_dim, 256),
133
+ nn.SiLU(),
134
+ nn.Linear(256, 1),
135
+ nn.Sigmoid()
136
+ )
137
+
138
+ self.correction = nn.Linear(state_dim, hidden_size)
139
+ nn.init.zeros_(self.correction.weight)
140
+
141
+ def forward(self, hidden: torch.Tensor, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
142
+ h_mean = hidden.mean(dim=1)
143
+ drift_input = torch.cat([h_mean, state], dim=-1)
144
+ drift_score = self.drift_detector(drift_input)
145
+
146
+ correction = self.correction(state).unsqueeze(1)
147
+ corrected = hidden + 0.01 * drift_score.unsqueeze(1) * correction
148
+
149
+ return corrected, drift_score
150
+
151
+ class CoherenceProcessorComponent(nn.Module):
152
+ def __init__(self, hidden_size: int, intermediate_size: int):
153
+ super().__init__()
154
+
155
+ self.processor = nn.Sequential(
156
+ nn.Linear(hidden_size, intermediate_size),
157
+ nn.SiLU(),
158
+ nn.Linear(intermediate_size, hidden_size)
159
+ )
160
+
161
+ nn.init.zeros_(self.processor[-1].weight)
162
+
163
+ def forward(self, hidden: torch.Tensor) -> torch.Tensor:
164
+ return hidden + 0.1 * self.processor(hidden)
165
+
166
+ class EARCPModule(nn.Module):
167
+ def __init__(self, config):
168
+ super().__init__()
169
+ self.config = config
170
+
171
+ self.encapsulation = EncapsulationComponent(
172
+ config.hidden_size, config.latent_state_dim
173
+ )
174
+
175
+ self.alignment = AlignmentComponent(
176
+ config.hidden_size, config.latent_state_dim, config.n_coherence_heads
177
+ )
178
+
179
+ self.revision = RevisionComponent(
180
+ config.hidden_size, config.latent_state_dim
181
+ )
182
+
183
+ self.coherence_processor = CoherenceProcessorComponent(
184
+ config.hidden_size, config.expert_intermediate
185
+ )
186
+
187
+ self.register_buffer('performance_scores', torch.zeros(config.n_components))
188
+ self.register_buffer('coherence_scores', torch.ones(config.n_components) * 0.5)
189
+ self.register_buffer(
190
+ 'component_weights',
191
+ torch.ones(config.n_components) / config.n_components
192
+ )
193
+
194
+ self.register_buffer('update_count', torch.tensor(0))
195
+
196
+ def reset_earcp_state(self):
197
+ self.performance_scores.zero_()
198
+ self.coherence_scores.fill_(0.5)
199
+ self.component_weights.fill_(1.0 / self.config.n_components)
200
+ self.update_count.zero_()
201
+
202
+ def forward(self, hidden_states: torch.Tensor,
203
+ latent_state: torch.Tensor,
204
+ edit_mode: bool = False) -> Dict[str, torch.Tensor]:
205
+ outputs = {}
206
+
207
+ new_state = self.encapsulation(hidden_states, latent_state, edit_mode)
208
+ outputs['E'] = new_state
209
+
210
+ hidden_aligned = self.alignment(hidden_states, new_state, self.config.alpha_inject)
211
+ outputs['A'] = hidden_aligned.mean(dim=1)
212
+
213
+ hidden_revised, drift_score = self.revision(hidden_aligned, new_state)
214
+ outputs['R'] = drift_score
215
+
216
+ hidden_coherent = self.coherence_processor(hidden_revised)
217
+ outputs['C'] = hidden_coherent.mean(dim=1)
218
+
219
+ return {
220
+ 'hidden_states': hidden_coherent,
221
+ 'new_state': new_state,
222
+ 'drift_score': drift_score,
223
+ }
224
+
225
+ def get_diagnostics(self):
226
+ return {
227
+ 'weights': self.component_weights.cpu().numpy(),
228
+ 'performance': self.performance_scores.cpu().numpy(),
229
+ 'coherence': self.coherence_scores.cpu().numpy(),
230
+ 'update_count': self.update_count.item(),
231
+ }
232
+
233
+ class SCLMModel(nn.Module):
234
+ def __init__(self, config, base):
235
+ super().__init__()
236
+ self.config = config
237
+ self.base_model = base
238
+ self.earcp = EARCPModule(config)
239
+ self.register_buffer('latent_state', torch.zeros(1, config.latent_state_dim))
240
+ self.hooks = []
241
+ self.edit_mode = False
242
+
243
+ def reset_state(self):
244
+ self.latent_state.zero_()
245
+ self.earcp.reset_earcp_state()
246
+
247
+ def get_state_norm(self):
248
+ return self.latent_state.norm().item()
249
+
250
+ def set_edit_mode(self, mode):
251
+ self.edit_mode = mode
252
+
253
+ def _make_hook(self, layer_idx):
254
+ def hook(module, input, output):
255
+ hidden = output[0] if isinstance(output, tuple) else output
256
+
257
+ state = self.latent_state.expand(hidden.size(0), -1)
258
+ result = self.earcp(hidden, state, self.edit_mode)
259
+
260
+ if not self.edit_mode:
261
+ self.latent_state = result['new_state'][:1].detach()
262
+
263
+ if isinstance(output, tuple):
264
+ return (result['hidden_states'],) + output[1:]
265
+ return result['hidden_states']
266
+
267
+ return hook
268
+
269
+ def register_hooks(self):
270
+ self.remove_hooks()
271
+
272
+ if hasattr(self.base_model, 'model'):
273
+ layers = self.base_model.model.layers
274
+ else:
275
+ layers = self.base_model.layers
276
+
277
+ for idx in self.config.state_injection_layers:
278
+ if idx < len(layers):
279
+ hook = layers[idx].register_forward_hook(self._make_hook(idx))
280
+ self.hooks.append(hook)
281
+
282
+ def remove_hooks(self):
283
+ for hook in self.hooks:
284
+ hook.remove()
285
+ self.hooks = []
286
+
287
+ def get_earcp_diagnostics(self):
288
+ return self.earcp.get_diagnostics()
289
+
290
+ print("1. Loading base model...")
291
+ qconfig = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
292
+ base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, quantization_config=qconfig, device_map="auto", token=HF_TOKEN)
293
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=HF_TOKEN)
294
+
295
+ if isinstance(tokenizer.eos_token_id, list):
296
+ tokenizer.eos_token_id = tokenizer.eos_token_id[0]
297
+ if tokenizer.pad_token is None:
298
+ tokenizer.pad_token = tokenizer.eos_token
299
+ tokenizer.pad_token_id = tokenizer.eos_token_id
300
+ if isinstance(base_model.config.eos_token_id, list):
301
+ base_model.config.eos_token_id = base_model.config.eos_token_id[0]
302
+ base_model.config.pad_token_id = base_model.config.eos_token_id
303
+
304
+ print("2. Creating SCLM...")
305
+ config = SCLMConfig(
306
+ vocab_size=base_model.config.vocab_size,
307
+ hidden_size=base_model.config.hidden_size,
308
+ num_hidden_layers=base_model.config.num_hidden_layers,
309
+ num_attention_heads=base_model.config.num_attention_heads,
310
+ )
311
+ sclm = SCLMModel(config, base_model)
312
+
313
+ print("3. Loading EARCP weights...")
314
+ USE_SCLM = False
315
+ try:
316
+ # MODIFIÉ: Chargement depuis votre dépôt
317
+ wpath = hf_hub_download(repo_id=SCLM_REPO, filename="sclm_multimodal_earcp.pt", token=HF_TOKEN)
318
+ sclm_state = torch.load(wpath, map_location="cpu")
319
+ sclm.earcp.load_state_dict(sclm_state['earcp'])
320
+ sclm.latent_state = sclm_state['latent_state']
321
+ USE_SCLM = True
322
+ print("EARCP loaded!")
323
+ except Exception as e:
324
+ print(f"EARCP error: {e}")
325
+
326
+ # Enregistrer les hooks après le chargement des poids
327
+ if USE_SCLM:
328
+ sclm.register_hooks()
329
+
330
+ print("Ananke ready!")
331
+
332
+ # ============================================================
333
+ # FONCTION CHAT AVEC HISTORIQUE PERSISTANT
334
+ # ============================================================
335
+ def chat(message, history, temperature, max_tokens):
336
+ """
337
+ Fonction de chat avec historique persistant.
338
+ - message: le nouveau message de l'utilisateur
339
+ - history: liste de tuples (user_msg, assistant_msg) - géré par Gradio
340
+ - temperature: créativité
341
+ - max_tokens: longueur max de la réponse
342
+ """
343
+ if not message.strip():
344
+ return "", history
345
+
346
+ # Construire le prompt avec tout l'historique
347
+ prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
348
+ prompt += SYSTEM_PROMPT
349
+ prompt += "<|eot_id|>"
350
+
351
+ # Ajouter l'historique existant au prompt
352
+ for user_msg, assistant_msg in history:
353
+ prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>"
354
+ prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant_msg}<|eot_id|>"
355
+
356
+ # Ajouter le nouveau message
357
+ prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{message}<|eot_id|>"
358
+ prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
359
+
360
+ # Tokenizer
361
+ inputs = tokenizer(prompt, return_tensors="pt").to(base_model.device)
362
+
363
+ # Générer la réponse
364
+ eos = tokenizer.eos_token_id
365
+ with torch.no_grad():
366
+ outputs = base_model.generate(
367
+ inputs.input_ids,
368
+ attention_mask=inputs.attention_mask,
369
+ max_new_tokens=int(max_tokens) if max_tokens else 1024,
370
+ temperature=float(temperature) if temperature else 0.7,
371
+ do_sample=True,
372
+ top_p=0.9,
373
+ repetition_penalty=1.1,
374
+ pad_token_id=eos,
375
+ eos_token_id=eos,
376
+ )
377
+
378
+ # Décoder la réponse
379
+ full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
380
+
381
+ # Extraire la dernière réponse assistant
382
+ if "assistant" in full_response.lower():
383
+ response = full_response.split("assistant")[-1]
384
+ else:
385
+ response = full_response
386
+
387
+ # Nettoyer les tags
388
+ for tag in ["<|eot_id|>", "<|end_header_id|>", "<|start_header_id|>", "user", "system", ":"]:
389
+ response = response.replace(tag, "")
390
+ response = response.strip() or "..."
391
+
392
+ # Ajouter à l'historique et retourner
393
+ history.append((message, response))
394
+
395
+ return "", history
396
+
397
+ def clear_conversation():
398
+ """Réinitialise la conversation et l'état SCLM"""
399
+ if USE_SCLM:
400
+ sclm.reset_state()
401
+ return [], "🔄 Conversation réinitialisée!"
402
+
403
+ def get_state_info():
404
+ """Retourne l'état actuel de la mémoire SCLM"""
405
+ if USE_SCLM:
406
+ try:
407
+ diag = sclm.get_earcp_diagnostics()
408
+ status = f"📊 EARCP (Updates: {diag['update_count']})\n\n"
409
+ status += "Component | Weight | Perf | Coher\n"
410
+ status += "-------------|--------|--------|-------\n"
411
+ names = ['E (Encaps)', 'A (Align)', 'R (Revis)', 'C (Coher)']
412
+ for i, name in enumerate(names):
413
+ status += f"{name:12} | {diag['weights'][i]:.3f} | {diag['performance'][i]:.3f} | {diag['coherence'][i]:.3f}\n"
414
+ status += f"\n🧠 State: {sclm.get_state_norm():.4f}"
415
+ return status
416
+ except Exception as e:
417
+ return f"Error: {e}"
418
+ return "Mode base (sans SCLM)"
419
+
420
+ # ============================================================
421
+ # INTERFACE GRADIO AVEC CHATBOT
422
+ # ============================================================
423
+ with gr.Blocks(title="Ananké - SCLM") as demo:
424
+ gr.Markdown("""
425
+ # 🔮 Ananké
426
+ **Assistant IA avec mémoire contextuelle** | Architecture SCLM-Multimodal par Mike Amega (Ame Web Studio)
427
+ """)
428
+
429
+ with gr.Row():
430
+ with gr.Column(scale=3):
431
+ # Composant Chatbot pour l'historique visuel
432
+ chatbot = gr.Chatbot(label="Conversation avec Ananké", height=450)
433
+
434
+ with gr.Row():
435
+ msg = gr.Textbox(
436
+ label="Ton message",
437
+ placeholder="Écris ton message à Ananké...",
438
+ scale=4,
439
+ lines=2
440
+ )
441
+ send_btn = gr.Button("📤 Envoyer", variant="primary")
442
+
443
+ clear_btn = gr.Button("🔄 Nouvelle conversation")
444
+
445
+ with gr.Column(scale=1):
446
+ gr.Markdown("### ⚙️ Paramètres")
447
+ temperature = gr.Slider(
448
+ minimum=0.1,
449
+ maximum=1.5,
450
+ value=0.7,
451
+ step=0.1,
452
+ label="Créativité"
453
+ )
454
+ max_tokens = gr.Slider(
455
+ minimum=256,
456
+ maximum=2048,
457
+ value=1024,
458
+ step=128,
459
+ label="Longueur max"
460
+ )
461
+
462
+ gr.Markdown("### 📊 État SCLM")
463
+ state_info = gr.Textbox(
464
+ label="Mémoire",
465
+ value=get_state_info(),
466
+ interactive=False,
467
+ lines=12
468
+ )
469
+
470
+ refresh_btn = gr.Button("🔄 Actualiser état")
471
+
472
+ gr.Markdown("""
473
+ ### 🔮 À propos
474
+ **Ananké** utilise une mémoire
475
+ latente évolutive (SCLM) pour
476
+ maintenir la cohérence de
477
+ la conversation.
478
+ """)
479
+
480
+ # Actions
481
+ send_btn.click(
482
+ fn=chat,
483
+ inputs=[msg, chatbot, temperature, max_tokens],
484
+ outputs=[msg, chatbot]
485
+ )
486
+
487
+ msg.submit(
488
+ fn=chat,
489
+ inputs=[msg, chatbot, temperature, max_tokens],
490
+ outputs=[msg, chatbot]
491
+ )
492
+
493
+ clear_btn.click(
494
+ fn=clear_conversation,
495
+ outputs=[chatbot, state_info]
496
+ )
497
+
498
+ refresh_btn.click(
499
+ fn=get_state_info,
500
+ outputs=[state_info]
501
+ )
502
+
503
+ # Lancement
504
+ demo.queue().launch()