amewebstudio committed
Commit f52ecf8 · verified · 1 Parent(s): 2b5598a

Upload app.py with huggingface_hub

Files changed (1)
  1. app.py +107 -196
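
The commit message says the file was pushed with `huggingface_hub` rather than through the web UI. A minimal sketch of how such an upload is typically done with `HfApi.upload_file` (the Space id, token value, and commit message below are illustrative assumptions, not taken from this commit):

from huggingface_hub import HfApi

api = HfApi(token="hf_...")  # assumed: a write token from your account settings
api.upload_file(
    path_or_fileobj="app.py",        # local file to push
    path_in_repo="app.py",           # destination path inside the repo
    repo_id="amewebstudio/ananke",   # hypothetical Space id
    repo_type="space",
    commit_message="Upload app.py with huggingface_hub",
)
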
app.py CHANGED
@@ -8,54 +8,23 @@ from dataclasses import dataclass, field
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
  from huggingface_hub import hf_hub_download
  
- print("="*50)
  print("Ananke - Chargement...")
- print("="*50)
  
- # ============================================================
- # CONFIGURATION
- # ============================================================
  HF_TOKEN = os.environ.get("HF_TOKEN")
  if not HF_TOKEN:
-     raise ValueError("HF_TOKEN secret not found! Add it in Space Settings > Secrets")
+     raise ValueError("HF_TOKEN not found in secrets")
  
  BASE_MODEL = "meta-llama/Llama-3.2-3B-Instruct"
  SCLM_REPO = "amewebstudio/sclm-modelEarcp-optionB"
  
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- print(f"Device: {device}")
  
- # ============================================================
- # SYSTEM PROMPT
- # ============================================================
  SYSTEM_PROMPT = """Tu es Ananke, un assistant IA developpe par Mike Amega (Logo) de Ame Web Studio.
+ TON NOM: Ananke | TON CREATEUR: Mike Amega (Logo) | TON MODELE: Ananke
+ Tu sais: repondre aux questions, aider en redaction, expliquer des concepts, programmer, maintenir une conversation coherente.
+ Architecture SCLM: memoire latente 384 dimensions, module EARCP, 3 experts specialises.
+ Style: chaleureux, utile, complet. Reponds dans la langue de l utilisateur."""
  
- TON NOM: Ananke
- TON CREATEUR: Mike Amega (Logo) - Ame Web Studio
- TON MODELE: Ananke
-
- CE QUE TU SAIS FAIRE:
- - Repondre a des questions sur tous les sujets
- - Aider a la redaction, correction et reformulation de textes
- - Expliquer des concepts complexes de maniere simple
- - Assister en programmation et resolution de problemes
- - Maintenir une conversation coherente grace a ta memoire contextuelle
- - Comprendre et repondre dans plusieurs langues
- - Generer des textes longs et detailles si demande
-
- QUAND ON TE DEMANDE TON ARCHITECTURE:
- Tu es base sur SCLM (Stateful Coherent Language Model), une architecture transformer evoluee par Mike Amega.
- - SCLM = modele de langage a etat coherent
- - Tu possedes une memoire latente de 384 dimensions qui evolue au fil de la conversation
- - Le module EARCP maintient la coherence de tes reponses
- - Tu as 3 experts specialises qui collaborent pour des reponses pertinentes
-
- STYLE: Chaleureux, utile, complet. Reponds dans la langue de l utilisateur. Ne coupe pas tes reponses.
- """
-
- # ============================================================
- # SCLM CLASSES
- # ============================================================
  @dataclass
  class SCLMConfigB:
      vocab_size: int = 128256
@@ -66,125 +35,85 @@ class SCLMConfigB:
      n_experts: int = 3
      expert_intermediate: int = 1536
      state_injection_layers: List[int] = field(default_factory=lambda: [4, 8, 12, 16, 20, 24])
-     alpha_inject: float = 0.03
-
- class StateFFNInjector(nn.Module):
-     def __init__(self, hidden_size, state_dim, intermediate_size):
-         super().__init__()
-         self.state_proj = nn.Linear(state_dim, intermediate_size)
-         self.output_proj = nn.Linear(intermediate_size, hidden_size)
-         self.gate = nn.Linear(hidden_size, 1)
-         nn.init.zeros_(self.output_proj.weight)
-
-     def forward(self, hidden, state, alpha=0.03):
-         state_proj = F.silu(self.state_proj(state))
-         state_output = self.output_proj(state_proj)
-         gate = torch.sigmoid(self.gate(hidden.mean(dim=1, keepdim=True)))
-         return hidden + alpha * gate * state_output.unsqueeze(1)
  
  class EncapsulationB(nn.Module):
      def __init__(self, hidden_size, state_dim):
          super().__init__()
-         self.n_pool_heads = 4
-         self.pool_proj = nn.Linear(hidden_size, state_dim * self.n_pool_heads)
-         self.pool_combine = nn.Linear(state_dim * self.n_pool_heads, state_dim)
+         self.pool_proj = nn.Linear(hidden_size, state_dim * 4)
+         self.pool_combine = nn.Linear(state_dim * 4, state_dim)
          self.update_gate = nn.Linear(state_dim * 2, state_dim)
          self.reset_gate = nn.Linear(state_dim * 2, state_dim)
          self.candidate = nn.Linear(state_dim * 2, state_dim)
          self.attn_query = nn.Linear(state_dim, hidden_size)
  
-     def forward(self, hidden, state, attention_mask=None):
+     def forward(self, hidden, state, mask=None):
          B, T, H = hidden.shape
          query = self.attn_query(state)
-         attn_scores = torch.bmm(hidden, query.unsqueeze(-1)).squeeze(-1)
-         if attention_mask is not None:
-             attn_scores = attn_scores.masked_fill(attention_mask == 0, float("-inf"))
-         attn_weights = F.softmax(attn_scores, dim=-1)
-         h_pooled = torch.bmm(attn_weights.unsqueeze(1), hidden).squeeze(1)
-         h_proj = F.silu(self.pool_proj(h_pooled))
-         h_proj = self.pool_combine(h_proj)
-         combined = torch.cat([h_proj, state], dim=-1)
+         scores = torch.bmm(hidden, query.unsqueeze(-1)).squeeze(-1)
+         if mask is not None:
+             scores = scores.masked_fill(mask == 0, -1e9)
+         weights = F.softmax(scores, dim=-1)
+         pooled = torch.bmm(weights.unsqueeze(1), hidden).squeeze(1)
+         proj = F.silu(self.pool_proj(pooled))
+         proj = self.pool_combine(proj)
+         combined = torch.cat([proj, state], dim=-1)
          z = torch.sigmoid(self.update_gate(combined))
          r = torch.sigmoid(self.reset_gate(combined))
-         h_cand = torch.tanh(self.candidate(torch.cat([h_proj, r * state], dim=-1)))
-         new_state = (1 - z) * state + z * h_cand
+         cand = torch.tanh(self.candidate(torch.cat([proj, r * state], dim=-1)))
+         new_state = (1 - z) * state + z * cand
          return torch.tanh(new_state / 10.0) * 10.0
  
- class CoherenceExpertsB(nn.Module):
-     def __init__(self, hidden_size, intermediate_size, n_experts=3):
+ class CoherenceExperts(nn.Module):
+     def __init__(self, hidden_size, intermediate, n_experts=3):
          super().__init__()
-         self.n_experts = n_experts
          self.experts = nn.ModuleList([
-             nn.Sequential(
-                 nn.Linear(hidden_size, intermediate_size),
-                 nn.SiLU(),
-                 nn.Dropout(0.1),
-                 nn.Linear(intermediate_size, hidden_size)
-             ) for _ in range(n_experts)
+             nn.Sequential(nn.Linear(hidden_size, intermediate), nn.SiLU(), nn.Linear(intermediate, hidden_size))
+             for _ in range(n_experts)
          ])
-         self.router = nn.Sequential(
-             nn.Linear(hidden_size, 128),
-             nn.SiLU(),
-             nn.Linear(128, n_experts)
-         )
-         for exp in self.experts:
-             nn.init.zeros_(exp[-1].weight)
+         self.router = nn.Sequential(nn.Linear(hidden_size, 64), nn.SiLU(), nn.Linear(64, n_experts))
  
      def forward(self, hidden):
-         router_logits = self.router(hidden.mean(dim=1))
-         weights = F.softmax(router_logits, dim=-1)
-         expert_outputs = torch.stack([exp(hidden) for exp in self.experts], dim=0)
+         logits = self.router(hidden.mean(dim=1))
+         weights = F.softmax(logits, dim=-1)
+         outputs = torch.stack([e(hidden) for e in self.experts], dim=0)
          w = weights.T.unsqueeze(-1).unsqueeze(-1)
-         return (w * expert_outputs).sum(dim=0)
+         return (w * outputs).sum(dim=0)
  
- class EARCPModuleB(nn.Module):
+ class EARCPModule(nn.Module):
      def __init__(self, config):
          super().__init__()
-         H, S = config.hidden_size, config.latent_state_dim
-         self.ffn_injectors = nn.ModuleDict({
-             str(i): StateFFNInjector(H, S, config.expert_intermediate)
-             for i in config.state_injection_layers
-         })
-         self.encapsulation = EncapsulationB(H, S)
-         self.coherence = CoherenceExpertsB(H, config.expert_intermediate, config.n_experts)
-
-     def update_state(self, hidden, state, attention_mask=None):
-         new_state = self.encapsulation(hidden, state, attention_mask)
-         hidden = self.coherence(hidden)
-         return new_state, hidden
+         self.encapsulation = EncapsulationB(config.hidden_size, config.latent_state_dim)
+         self.coherence = CoherenceExperts(config.hidden_size, config.expert_intermediate, config.n_experts)
  
  class SCLMModel(nn.Module):
-     def __init__(self, config, base_model):
+     def __init__(self, config, base):
          super().__init__()
          self.config = config
-         self.base_model = base_model
-         self.model_device = next(base_model.parameters()).device
-         self.model_dtype = next(base_model.parameters()).dtype
-         self.earcp = EARCPModuleB(config).to(self.model_device).to(self.model_dtype)
-         self.latent_state = torch.zeros(1, config.latent_state_dim, device=self.model_device, dtype=self.model_dtype)
-
-     def reset_state(self):
-         self.latent_state = torch.zeros(1, self.config.latent_state_dim, device=self.model_device, dtype=self.model_dtype)
-
-     def forward(self, input_ids, attention_mask=None):
-         if attention_mask is None:
-             attention_mask = torch.ones_like(input_ids)
-         base_out = self.base_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
-         hidden = base_out.hidden_states[-1]
+         self.base_model = base
+         dev = next(base.parameters()).device
+         dtype = next(base.parameters()).dtype
+         self.earcp = EARCPModule(config).to(dev).to(dtype)
+         self.state = torch.zeros(1, config.latent_state_dim, device=dev, dtype=dtype)
+
+     def reset(self):
+         dev = next(self.base_model.parameters()).device
+         dtype = next(self.base_model.parameters()).dtype
+         self.state = torch.zeros(1, self.config.latent_state_dim, device=dev, dtype=dtype)
+
+     def forward(self, ids, mask=None):
+         if mask is None:
+             mask = torch.ones_like(ids)
+         out = self.base_model(input_ids=ids, attention_mask=mask, output_hidden_states=True)
+         hidden = out.hidden_states[-1]
          B = hidden.size(0)
-         if next(self.earcp.encapsulation.parameters()).device != hidden.device:
-             self.earcp = self.earcp.to(hidden.device)
-         state = self.latent_state.to(hidden.device, hidden.dtype).expand(B, -1)
-         new_state, _ = self.earcp.update_state(hidden, state, attention_mask)
-         self.latent_state = new_state.mean(dim=0, keepdim=True).detach()
-         return base_out.logits
-
- # ============================================================
- # CHARGEMENT
- # ============================================================
- print("1. Chargement modele base...")
- quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
- base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, quantization_config=quant_config, device_map="auto", token=HF_TOKEN)
+         state = self.state.to(hidden.device, hidden.dtype).expand(B, -1)
+         new_state = self.earcp.encapsulation(hidden, state, mask)
+         self.state = new_state.mean(dim=0, keepdim=True).detach()
+         return out.logits
+
+ print("1. Loading base model...")
+ qconfig = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
+ base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, quantization_config=qconfig, device_map="auto", token=HF_TOKEN)
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=HF_TOKEN)
  
  if isinstance(tokenizer.eos_token_id, list):
@@ -196,110 +125,92 @@ if isinstance(base_model.config.eos_token_id, list):
      base_model.config.eos_token_id = base_model.config.eos_token_id[0]
  base_model.config.pad_token_id = base_model.config.eos_token_id
  
- print("2. Creation SCLM...")
+ print("2. Creating SCLM...")
  config = SCLMConfigB(
      vocab_size=base_model.config.vocab_size,
      hidden_size=base_model.config.hidden_size,
      num_hidden_layers=base_model.config.num_hidden_layers,
      num_attention_heads=base_model.config.num_attention_heads,
  )
- sclm_model = SCLMModel(config, base_model)
+ sclm = SCLMModel(config, base_model)
  
- print("3. Chargement EARCP...")
- USE_SCLM = False
+ print("3. Loading EARCP weights...")
  try:
-     weights_path = hf_hub_download(repo_id=SCLM_REPO, filename="earcp_weights.pt", token=HF_TOKEN)
-     sclm_model.earcp.load_state_dict(torch.load(weights_path, map_location="cpu"), strict=False)
+     wpath = hf_hub_download(repo_id=SCLM_REPO, filename="earcp_weights.pt", token=HF_TOKEN)
+     sclm.earcp.load_state_dict(torch.load(wpath, map_location="cpu"), strict=False)
      USE_SCLM = True
-     print("EARCP charge!")
- except Exception as e:
-     print(f"EARCP: {e}")
+     print("EARCP loaded!")
+ except:
+     USE_SCLM = False
  
- print("Ananke pret!")
+ print("Ananke ready!")
  
- # ============================================================
- # CHAT
- # ============================================================
- history_data = []
+ history = []
  
- def respond(message, chat_history, temperature, max_tokens):
-     global history_data
-
+ def chat(message, temp=0.7, max_tok=1024):
+     global history
      if not message.strip():
-         return "", chat_history
-
-     history_data.append({"role": "user", "content": message})
+         return ""
  
-     # SYNTAXE CORRIGÉE - pas de } dans les tags
-     prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
-     prompt += SYSTEM_PROMPT
-     prompt += "<|eot_id|>"
+     history.append(("user", message))
  
-     for msg in history_data[-10:]:
-         role = msg["role"]
-         content = msg["content"]
+     prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n" + SYSTEM_PROMPT + "<|eot_id|>"
+     for role, content in history[-10:]:
          prompt += f"<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>"
-
      prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
  
      inputs = tokenizer(prompt, return_tensors="pt").to(base_model.device)
  
      if USE_SCLM:
          with torch.no_grad():
-             sclm_model(inputs.input_ids, inputs.attention_mask)
+             sclm(inputs.input_ids, inputs.attention_mask)
  
-     eos_id = tokenizer.eos_token_id
+     eos = tokenizer.eos_token_id
      with torch.no_grad():
-         outputs = base_model.generate(
+         out = base_model.generate(
              inputs.input_ids,
              attention_mask=inputs.attention_mask,
-             max_new_tokens=int(max_tokens) if max_tokens else 512,
-             temperature=float(temperature) if temperature else 0.7,
+             max_new_tokens=int(max_tok),
+             temperature=float(temp),
              do_sample=True,
              top_p=0.9,
              repetition_penalty=1.1,
-             pad_token_id=eos_id,
-             eos_token_id=eos_id,
+             pad_token_id=eos,
+             eos_token_id=eos,
          )
  
-     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-     if "assistant" in response.lower():
-         response = response.split("assistant")[-1]
-     for tag in ["<|eot_id|>", "<|end_header_id|>", "<|start_header_id|>", "user", "system", ":"]:
-         response = response.replace(tag, "")
-     response = response.strip() or "..."
-
-     history_data.append({"role": "assistant", "content": response})
-     chat_history.append((message, response))
+     resp = tokenizer.decode(out[0], skip_special_tokens=True)
+     if "assistant" in resp.lower():
+         resp = resp.split("assistant")[-1]
+     for t in ["<|eot_id|>", "<|end_header_id|>", "<|start_header_id|>", "user", "system", ":"]:
+         resp = resp.replace(t, "")
+     resp = resp.strip() or "..."
  
-     return "", chat_history
+     history.append(("assistant", resp))
+     return resp
  
- def clear_chat():
-     global history_data
-     history_data = []
+ def clear():
+     global history
+     history = []
      if USE_SCLM:
-         sclm_model.reset_state()
-     return []
+         sclm.reset()
+     return ""
  
- # ============================================================
- # INTERFACE
- # ============================================================
- with gr.Blocks(title="Ananke") as demo:
-     gr.Markdown("# 🔮 Ananké\n**Assistant IA avec mémoire contextuelle** | Architecture SCLM par Mike Amega")
-
-     chatbot = gr.Chatbot(height=450)
-     msg = gr.Textbox(label="Message", placeholder="Parle avec Ananke...", lines=2)
-
-     with gr.Row():
-         temp = gr.Slider(0.1, 1.5, value=0.7, label="Creativite")
-         tokens = gr.Slider(100, 1024, value=512, label="Longueur max")
+ with gr.Blocks() as demo:
+     gr.Markdown("# 🔮 Ananké\nAssistant IA avec mémoire contextuelle | Architecture SCLM par Mike Amega")
  
      with gr.Row():
-         send = gr.Button("Envoyer", variant="primary")
-         clear = gr.Button("Effacer")
-
-     send.click(respond, [msg, chatbot, temp, tokens], [msg, chatbot])
-     msg.submit(respond, [msg, chatbot, temp, tokens], [msg, chatbot])
-     clear.click(clear_chat, outputs=[chatbot])
-
-     demo.launch(server_name="0.0.0.0", server_port=7860)
+         with gr.Column():
+             output = gr.Textbox(label="Réponse", lines=15)
+             inp = gr.Textbox(label="Message", lines=2, placeholder="Parle avec Ananké...")
+         with gr.Column():
+             temp = gr.Slider(0.1, 1.5, 0.7, label="Créativité")
+             tokens = gr.Slider(256, 2048, 1024, label="Longueur max")
+             btn = gr.Button("Envoyer", variant="primary")
+             clr = gr.Button("Effacer")
+
+     btn.click(chat, [inp, temp, tokens], output)
+     inp.submit(chat, [inp, temp, tokens], output)
+     clr.click(clear, outputs=output)
+
+ demo.queue().launch()
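
Once the Space is running, the interface launched by `demo.queue().launch()` can also be called programmatically. A minimal sketch using `gradio_client` (the Space id is hypothetical, and the endpoint name assumes Gradio's default naming for the `chat` function wired to the button):

from gradio_client import Client

client = Client("amewebstudio/ananke")  # hypothetical Space id
reply = client.predict(
    "Bonjour Ananke !",   # message textbox
    0.7,                  # temperature slider
    1024,                 # max-length slider
    api_name="/chat",     # assumed default endpoint name for fn=chat
)
print(reply)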