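"""Ananke - Gradio chat app for the SCLM (EARCP) architecture.

Wraps meta-llama/Llama-3.2-3B-Instruct, loaded in 4-bit, with a small
latent-memory add-on whose weights are fetched from
amewebstudio/sclm-modelEarcp-optionB on the Hugging Face Hub.
"""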
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from typing import List
from dataclasses import dataclass, field
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from huggingface_hub import hf_hub_download
print("Ananke - Chargement...")
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN not found in secrets")
BASE_MODEL = "meta-llama/Llama-3.2-3B-Instruct"
SCLM_REPO = "amewebstudio/sclm-modelEarcp-optionB"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
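# Note: `device` is informational only; the base model is placed by
# device_map="auto" below, and tensors follow base_model.device in chat().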
SYSTEM_PROMPT = """You are Ananke, an AI assistant developed by Mike Amega (Logo) of Ame Web Studio.
YOUR NAME: Ananke | YOUR CREATOR: Mike Amega (Logo) | YOUR MODEL: Ananke
You can: answer questions, help with writing, explain concepts, write code, and hold a coherent conversation.
SCLM architecture: 384-dimensional latent memory, EARCP module, 3 specialized experts.
Style: warm, helpful, thorough. Reply in the user's language."""
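# Configuration for the SCLM add-on. The first four defaults mirror the
# Llama-3.2-3B dimensions (128256 vocab, 3072 hidden, 28 layers, 24 heads);
# latent_state_dim and the expert sizes are the EARCP-specific knobs.
# state_injection_layers is defined but not read anywhere in this app.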
@dataclass
class SCLMConfigB:
    vocab_size: int = 128256
    hidden_size: int = 3072
    num_hidden_layers: int = 28
    num_attention_heads: int = 24
    latent_state_dim: int = 384
    n_experts: int = 3
    expert_intermediate: int = 1536
    state_injection_layers: List[int] = field(default_factory=lambda: [4, 8, 12, 16, 20, 24])
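# EncapsulationB refreshes the latent memory from the decoder's last hidden
# states: the current state forms a single attention query over the sequence,
# the pooled vector is projected down to state_dim, and a GRU-style gate
# (update/reset/candidate) merges it into the state. The final
# tanh(x / 10) * 10 soft-clamps every component to roughly [-10, 10].
# Shapes: hidden (B, T, H), state (B, state_dim) -> (B, state_dim).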
class EncapsulationB(nn.Module):
    def __init__(self, hidden_size, state_dim):
        super().__init__()
        self.pool_proj = nn.Linear(hidden_size, state_dim * 4)
        self.pool_combine = nn.Linear(state_dim * 4, state_dim)
        self.update_gate = nn.Linear(state_dim * 2, state_dim)
        self.reset_gate = nn.Linear(state_dim * 2, state_dim)
        self.candidate = nn.Linear(state_dim * 2, state_dim)
        self.attn_query = nn.Linear(state_dim, hidden_size)

    def forward(self, hidden, state, mask=None):
        B, T, H = hidden.shape
        query = self.attn_query(state)
        scores = torch.bmm(hidden, query.unsqueeze(-1)).squeeze(-1)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = F.softmax(scores, dim=-1)
        pooled = torch.bmm(weights.unsqueeze(1), hidden).squeeze(1)
        proj = F.silu(self.pool_proj(pooled))
        proj = self.pool_combine(proj)
        combined = torch.cat([proj, state], dim=-1)
        z = torch.sigmoid(self.update_gate(combined))
        r = torch.sigmoid(self.reset_gate(combined))
        cand = torch.tanh(self.candidate(torch.cat([proj, r * state], dim=-1)))
        new_state = (1 - z) * state + z * cand
        return torch.tanh(new_state / 10.0) * 10.0
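# CoherenceExperts is a dense (soft) mixture of experts: a small router
# scores the mean-pooled sequence, all n_experts MLPs run on the full
# hidden states, and their outputs are blended by the softmax router
# weights. There is no top-k sparsity, so every expert always executes.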
class CoherenceExperts(nn.Module):
    def __init__(self, hidden_size, intermediate, n_experts=3):
        super().__init__()
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(hidden_size, intermediate), nn.SiLU(), nn.Linear(intermediate, hidden_size))
            for _ in range(n_experts)
        ])
        self.router = nn.Sequential(nn.Linear(hidden_size, 64), nn.SiLU(), nn.Linear(64, n_experts))

    def forward(self, hidden):
        logits = self.router(hidden.mean(dim=1))
        weights = F.softmax(logits, dim=-1)
        outputs = torch.stack([e(hidden) for e in self.experts], dim=0)
        w = weights.T.unsqueeze(-1).unsqueeze(-1)
        return (w * outputs).sum(dim=0)
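# EARCP bundles both submodules so the trained checkpoint loads cleanly.
# Note that SCLMModel.forward below only calls `encapsulation`; the
# coherence experts are instantiated but never applied at inference time.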
class EARCPModule(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.encapsulation = EncapsulationB(config.hidden_size, config.latent_state_dim)
        self.coherence = CoherenceExperts(config.hidden_size, config.expert_intermediate, config.n_experts)
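# SCLMModel wraps the frozen quantized base model and carries a single
# (1, latent_state_dim) memory vector across calls: forward() runs the base
# model with output_hidden_states, updates the state from the last layer,
# and detaches it so no graph accumulates between turns. Text generation
# itself still goes through base_model.generate() in chat() below.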
class SCLMModel(nn.Module):
    def __init__(self, config, base):
        super().__init__()
        self.config = config
        self.base_model = base
        dev = next(base.parameters()).device
        dtype = next(base.parameters()).dtype
        self.earcp = EARCPModule(config).to(dev).to(dtype)
        self.state = torch.zeros(1, config.latent_state_dim, device=dev, dtype=dtype)

    def reset(self):
        dev = next(self.base_model.parameters()).device
        dtype = next(self.base_model.parameters()).dtype
        self.state = torch.zeros(1, self.config.latent_state_dim, device=dev, dtype=dtype)

    def forward(self, ids, mask=None):
        if mask is None:
            mask = torch.ones_like(ids)
        out = self.base_model(input_ids=ids, attention_mask=mask, output_hidden_states=True)
        hidden = out.hidden_states[-1]
        B = hidden.size(0)
        state = self.state.to(hidden.device, hidden.dtype).expand(B, -1)
        new_state = self.earcp.encapsulation(hidden, state, mask)
        self.state = new_state.mean(dim=0, keepdim=True).detach()
        return out.logits
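# Load the base model in 4-bit via bitsandbytes with fp16 compute so the
# 3B-parameter model fits in modest GPU memory; device_map="auto" handles
# placement. The Llama 3 tokenizer can report a list of eos token ids, so
# it is normalized to a single id for generate().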
print("1. Loading base model...")
qconfig = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, quantization_config=qconfig, device_map="auto", token=HF_TOKEN)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=HF_TOKEN)
if isinstance(tokenizer.eos_token_id, list):
    tokenizer.eos_token_id = tokenizer.eos_token_id[0]
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
if isinstance(base_model.config.eos_token_id, list):
    base_model.config.eos_token_id = base_model.config.eos_token_id[0]
base_model.config.pad_token_id = base_model.config.eos_token_id
print("2. Creating SCLM...")
config = SCLMConfigB(
    vocab_size=base_model.config.vocab_size,
    hidden_size=base_model.config.hidden_size,
    num_hidden_layers=base_model.config.num_hidden_layers,
    num_attention_heads=base_model.config.num_attention_heads,
)
sclm = SCLMModel(config, base_model)
print("3. Loading EARCP weights...")
try:
    wpath = hf_hub_download(repo_id=SCLM_REPO, filename="earcp_weights.pt", token=HF_TOKEN)
    sclm.earcp.load_state_dict(torch.load(wpath, map_location="cpu"), strict=False)
    USE_SCLM = True
    print("EARCP loaded!")
except Exception as e:
    USE_SCLM = False
    print(f"EARCP weights unavailable ({e}); falling back to the base model only.")
print("Ananke ready!")
history = []
def chat(message, temp=0.7, max_tok=1024):
    global history
    if not message.strip():
        return ""
    history.append(("user", message))
    prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n" + SYSTEM_PROMPT + "<|eot_id|>"
    for role, content in history[-10:]:
        prompt += f"<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>"
    prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    inputs = tokenizer(prompt, return_tensors="pt").to(base_model.device)
    if USE_SCLM:
        # Refresh the latent memory from the full prompt before generating.
        with torch.no_grad():
            sclm(inputs.input_ids, inputs.attention_mask)
    eos = tokenizer.eos_token_id
    with torch.no_grad():
        out = base_model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=int(max_tok),
            temperature=float(temp),
            do_sample=True,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=eos,
            eos_token_id=eos,
        )
    # Decode only the newly generated tokens (slice off the prompt) so role
    # headers and special tokens never leak into the reply; this avoids
    # fragile string surgery on the full decoded output.
    gen_tokens = out[0][inputs.input_ids.shape[1]:]
    resp = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip() or "..."
    history.append(("assistant", resp))
    return resp
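# Clearing wipes both memories: the prompt history and, when the EARCP
# weights loaded, the SCLM latent state.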
def clear():
    global history
    history = []
    if USE_SCLM:
        sclm.reset()
    return ""
with gr.Blocks() as demo:
    gr.Markdown("# 🔮 Ananké\nAI assistant with contextual memory | SCLM architecture by Mike Amega")
    with gr.Row():
        with gr.Column():
            output = gr.Textbox(label="Response", lines=15)
            inp = gr.Textbox(label="Message", lines=2, placeholder="Chat with Ananké...")
        with gr.Column():
            temp = gr.Slider(0.1, 1.5, 0.7, label="Creativity (temperature)")
            tokens = gr.Slider(256, 2048, 1024, label="Max length (tokens)")
            btn = gr.Button("Send", variant="primary")
            clr = gr.Button("Clear")
    btn.click(chat, [inp, temp, tokens], output)
    inp.submit(chat, [inp, temp, tokens], output)
    clr.click(clear, outputs=output)
demo.queue().launch()