"""Gradio chat demo for the Prometheus language model (JAX / Flax).

At import time this script downloads a pickled checkpoint from the Hugging
Face Hub, rebuilds the decoder-only transformer defined below, and launches
a Gradio chat UI that continues the user's message with sampled text.
Tokenisation is per character: every character of the input is looked up
in the vocabulary stored inside the checkpoint.
"""

import os
import pickle

import flax.linen as nn
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from huggingface_hub import hf_hub_download

HF_REPO = "MoEprometheus/Prometheus-base"

print("📥 Загружаем Prometheus...")
path = hf_hub_download(HF_REPO, "expert1.pkl")
# SECURITY: pickle.load() can execute arbitrary code embedded in the file.
# Only keep this if HF_REPO is fully trusted; a data-only format would be
# safer for distributing weights.
with open(path, "rb") as f:
    ckpt = pickle.load(f)

itos = ckpt["vocab"]  # token id -> string
stoi = {v: k for k, v in itos.items()}  # string -> token id
CONFIG = ckpt["config"]


def encode(s):
    """Map each character of *s* to its token id (unknown characters -> 0)."""
    return [stoi.get(c, 0) for c in s]


def decode(l):
    """Join the strings for the token ids in *l*, skipping unknown ids."""
    return "".join(itos.get(i, "") for i in l)


print(f"✅ Загружено — шаг {ckpt['step']}")


class PrometheusAttention(nn.Module):
    """Multi-head causal self-attention.

    NOTE: submodule creation order determines Flax's auto-generated
    parameter names (Dense_0, Dense_1, ...). Do not reorder the layers, or
    the checkpoint parameters will no longer match.
    """

    n_heads: int
    n_embed: int
    block_size: int  # not read by this module; kept for constructor compatibility
    dropout: float

    @nn.compact
    def __call__(self, x, deterministic=True):
        """Apply attention to *x* of shape (B, T, C) and return the same shape.

        Dropout is disabled when *deterministic* is true.
        """
        B, T, C = x.shape
        head_size = self.n_embed // self.n_heads

        # One fused projection produces queries, keys and values.
        qkv = nn.Dense(3 * self.n_embed, use_bias=False)(x)
        q, k, v = jnp.split(qkv, 3, axis=-1)

        # (B, T, n_embed) -> (B, n_heads, T, head_size)
        q = q.reshape(B, T, self.n_heads, head_size).transpose(0, 2, 1, 3)
        k = k.reshape(B, T, self.n_heads, head_size).transpose(0, 2, 1, 3)
        v = v.reshape(B, T, self.n_heads, head_size).transpose(0, 2, 1, 3)

        # Scaled dot-product scores, shape (B, n_heads, T, T).
        att = (q @ k.transpose(0, 1, 3, 2)) * (head_size ** -0.5)

        # Causal mask: position i may only attend to positions <= i.
        mask = jnp.tril(jnp.ones((T, T)))
        att = jnp.where(mask == 0, -1e9, att)
        att = jax.nn.softmax(att, axis=-1)
        att = nn.Dropout(self.dropout)(att, deterministic=deterministic)

        # Merge the heads back into a single (B, T, C) tensor.
        out = (att @ v).transpose(0, 2, 1, 3).reshape(B, T, C)
        out = nn.Dense(self.n_embed)(out)
        return nn.Dropout(self.dropout)(out, deterministic=deterministic)


class PrometheusMLP(nn.Module):
    """Position-wise feed-forward network: Dense(4x) -> GELU -> Dense -> Dropout."""

    n_embed: int
    dropout: float

    @nn.compact
    def __call__(self, x, deterministic=True):
        """Transform *x*, preserving its last dimension (n_embed)."""
        x = nn.Dense(4 * self.n_embed)(x)
        x = nn.gelu(x)
        x = nn.Dense(self.n_embed)(x)
        return nn.Dropout(self.dropout)(x, deterministic=deterministic)


class PrometheusBlock(nn.Module):
    """Transformer block: attention and MLP, each behind a LayerNorm and
    wrapped in a residual connection."""

    n_embed: int
    n_heads: int
    block_size: int
    dropout: float

    @nn.compact
    def __call__(self, x, deterministic=True):
        """Run one block over *x*; *deterministic* is passed to both sub-layers."""
        x = x + PrometheusAttention(
            self.n_heads, self.n_embed, self.block_size, self.dropout
        )(nn.LayerNorm()(x), deterministic)
        x = x + PrometheusMLP(
            self.n_embed, self.dropout
        )(nn.LayerNorm()(x), deterministic)
        return x


class Prometheus(nn.Module):
    """Decoder-only transformer language model."""

    vocab_size: int
    n_embed: int
    n_heads: int
    n_layers: int
    block_size: int  # size of the position-embedding table (max sequence length)
    dropout: float

    @nn.compact
    def __call__(self, idx, training=False):
        """Return next-token logits of shape (B, T, vocab_size).

        Args:
            idx: integer token ids of shape (B, T); T must not exceed
                ``block_size`` because positions index a table of that size.
            training: accepted for API compatibility but currently ignored;
                every dropout layer below runs in deterministic mode.
        """
        _, T = idx.shape
        tok = nn.Embed(self.vocab_size, self.n_embed)(idx)
        pos = nn.Embed(self.block_size, self.n_embed)(jnp.arange(T))
        x = nn.Dropout(self.dropout)(tok + pos, deterministic=True)

        # Rematerialised blocks; argument 2 (`deterministic`, counting the
        # module itself as argument 0) is marked static.
        BlockRemat = nn.remat(PrometheusBlock, static_argnums=(2,))
        for _ in range(self.n_layers):
            x = BlockRemat(
                self.n_embed, self.n_heads, self.block_size, self.dropout
            )(x, True)

        return nn.Dense(self.vocab_size)(nn.LayerNorm()(x))


model = Prometheus(
    vocab_size=CONFIG["vocab_size"],
    n_embed=CONFIG["n_embed"],
    n_heads=CONFIG["n_heads"],
    n_layers=CONFIG["n_layers"],
    block_size=CONFIG["block_size"],
    dropout=CONFIG["dropout"],
)
params = ckpt["params"]


def _sample_continuation(prompt_tokens, max_new_tokens, temperature, top_k, seed):
    """Autoregressively sample token ids that continue *prompt_tokens*.

    Returns only the newly generated ids (a list of ints).
    """
    block_size = CONFIG["block_size"]

    if seed is None:
        seed = np.random.randint(0, 2**31 - 1)
    key = jax.random.PRNGKey(seed)

    # The model cannot run on an empty sequence. Seed the context with token
    # 0, the id `encode` already uses for unknown characters; the seed is
    # not part of the returned continuation.
    tokens = list(prompt_tokens) or [0]
    first_new = len(tokens)

    for _ in range(max_new_tokens):
        # Feed at most `block_size` tokens: positions index an embedding
        # table of that size, and JAX does not raise on out-of-range
        # indices, so a longer context would silently corrupt the output.
        context = tokens[-block_size:]
        x = jnp.array([context], dtype=jnp.int32)
        logits = model.apply(params, x, training=False)[0, -1, :]

        if temperature <= 0:
            # Temperature -> 0 is the greedy limit; avoids dividing by zero.
            next_token = int(jnp.argmax(logits))
        else:
            k = max(1, min(top_k, logits.shape[-1]))
            top_logits, top_indices = jax.lax.top_k(logits / temperature, k)
            key, subkey = jax.random.split(key)
            # `categorical` takes unnormalised log-probabilities directly.
            chosen = int(jax.random.categorical(subkey, top_logits))
            next_token = int(top_indices[chosen])

        tokens.append(next_token)

    return tokens[first_new:]


def generate(prompt, max_new_tokens=80, temperature=1.1, *, top_k=40, seed=None):
    """Return *prompt* (as decoded by the vocabulary) plus a sampled continuation.

    Args:
        prompt: text to continue; may be empty.
        max_new_tokens: number of tokens to sample.
        temperature: logits are divided by this before sampling; a value
            <= 0 selects greedy decoding.
        top_k: sample only among the ``top_k`` most likely tokens (clamped
            to the vocabulary size).
        seed: optional integer seed for reproducible sampling; a random
            seed is drawn when ``None``.

    Returns:
        The decoded prompt followed by the decoded new tokens.
    """
    prompt_tokens = encode(prompt)
    new_tokens = _sample_continuation(
        prompt_tokens, max_new_tokens, temperature, top_k, seed
    )
    return decode(prompt_tokens + new_tokens)


def chat(message, history):
    """Gradio chat callback: reply with the model's continuation of *message*.

    *history* is supplied by ``gr.ChatInterface`` and is not used.
    """
    prompt_tokens = encode(message)
    new_tokens = _sample_continuation(
        prompt_tokens, max_new_tokens=80, temperature=1.1, top_k=40, seed=None
    )

    # Drop the prompt from the reply by decoding only the new tokens.
    # Slicing the decoded text by len(message) is unreliable because
    # `decode` skips ids that are missing from the vocabulary.
    answer = decode(new_tokens).strip()
    if answer:
        return answer
    # Fall back to the full text when the continuation decodes to nothing.
    return decode(prompt_tokens + new_tokens).strip()


demo = gr.ChatInterface(
    fn=chat,
    title="🔥 Prometheus AI",
    description="Языковая модель 1.2B параметров. Создана с нуля одним человеком.",
    examples=[
        "Москва —",
        "Россия — это",
        "Нейронная сеть — это",
        "Python — язык",
    ],
)

demo.launch()