Spaces:
Sleeping
Sleeping
| import os | |
| import jax | |
| import jax.numpy as jnp | |
| import flax.linen as nn | |
| import pickle | |
| import numpy as np | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
# Hugging Face Hub repository that hosts the pickled checkpoint.
HF_REPO = "MoEprometheus/Prometheus-base"

print("📥 Загружаем Prometheus...")
path = hf_hub_download(HF_REPO, "expert1.pkl")

# SECURITY NOTE: unpickling executes arbitrary code embedded in the file.
# Only keep this if HF_REPO is fully trusted; otherwise the checkpoint should
# be converted to a data-only format before loading.
with open(path, "rb") as f:
    ckpt = pickle.load(f)

# The checkpoint stores the vocabulary as a mapping id -> token; build the
# inverse mapping token -> id for encoding.
itos = ckpt["vocab"]
stoi = {v: k for k, v in itos.items()}
CONFIG = ckpt["config"]


def encode(s):
    """Map every character of ``s`` to its token id.

    Characters that are missing from the vocabulary are mapped to id 0.
    """
    return [stoi.get(c, 0) for c in s]


def decode(l):
    """Join the tokens for the ids in ``l`` into a single string.

    Ids that are missing from the vocabulary contribute an empty string.
    """
    return "".join(itos.get(i, "") for i in l)


print(f"✅ Загружено — шаг {ckpt['step']}")
class PrometheusAttention(nn.Module):
    """Causal multi-head self-attention layer.

    Submodules are created inline inside ``__call__``, so the method must be
    wrapped in ``@nn.compact``. The order in which the ``nn.Dense`` layers are
    created determines their auto-generated parameter names and must not be
    changed, otherwise stored parameters no longer line up.
    """

    n_heads: int  # number of attention heads
    n_embed: int  # feature width; divided evenly between the heads
    block_size: int  # not read by this module; kept for constructor compatibility
    dropout: float  # dropout rate for attention weights and for the output

    @nn.compact
    def __call__(self, x, deterministic=True):
        """Apply masked self-attention to ``x``.

        Args:
            x: array unpacked as ``(B, T, C)``; the final reshape requires
                ``C`` to equal ``n_heads * head_size``.
            deterministic: when True, both dropout layers are disabled.

        Returns:
            Array with the same ``(B, T, C)`` shape as ``x``.
        """
        B, T, C = x.shape
        head_size = self.n_embed // self.n_heads

        def split_heads(t):
            # (B, T, n_embed) -> (B, n_heads, T, head_size)
            return t.reshape(B, T, self.n_heads, head_size).transpose(0, 2, 1, 3)

        # One bias-free projection yields queries, keys and values at once.
        qkv = nn.Dense(3 * self.n_embed, use_bias=False)(x)
        q, k, v = (split_heads(part) for part in jnp.split(qkv, 3, axis=-1))

        # Scaled dot-product scores, shape (B, n_heads, T, T).
        att = (q @ k.transpose(0, 1, 3, 2)) * (head_size ** -0.5)

        # Lower-triangular mask: position i may only attend to positions <= i.
        mask = jnp.tril(jnp.ones((T, T)))
        att = jnp.where(mask == 0, -1e9, att)
        att = jax.nn.softmax(att, axis=-1)
        att = nn.Dropout(self.dropout)(att, deterministic=deterministic)

        # Merge the heads back into a single feature axis.
        out = (att @ v).transpose(0, 2, 1, 3).reshape(B, T, C)
        out = nn.Dense(self.n_embed)(out)
        return nn.Dropout(self.dropout)(out, deterministic=deterministic)
class PrometheusMLP(nn.Module):
    """Position-wise feed-forward network: Dense -> GELU -> Dense -> Dropout.

    Submodules are created inline, so ``__call__`` must be wrapped in
    ``@nn.compact``.
    """

    n_embed: int  # input and output feature width
    dropout: float  # dropout rate applied to the output

    @nn.compact
    def __call__(self, x, deterministic=True):
        """Transform ``x`` through the 4x-wide hidden layer and back.

        Args:
            x: input array; the dense layers act on its last axis.
            deterministic: when True, dropout is disabled.

        Returns:
            Array whose last axis has size ``n_embed``.
        """
        hidden = nn.Dense(4 * self.n_embed)(x)  # expand to 4 * n_embed
        hidden = nn.gelu(hidden)
        projected = nn.Dense(self.n_embed)(hidden)  # project back to n_embed
        return nn.Dropout(self.dropout)(projected, deterministic=deterministic)
class PrometheusBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual.

    Submodules are created inline, so ``__call__`` must be wrapped in
    ``@nn.compact``.
    """

    n_embed: int  # feature width of the residual stream
    n_heads: int  # number of attention heads
    block_size: int  # forwarded to the attention layer
    dropout: float  # dropout rate forwarded to both sub-layers

    @nn.compact
    def __call__(self, x, deterministic=True):
        """Run one block on ``x``.

        Args:
            x: input array, passed through LayerNorm before each sub-layer.
            deterministic: when True, dropout inside the sub-layers is
                disabled. Passed positionally to the sub-layers.

        Returns:
            ``x`` plus the attention output plus the MLP output.
        """
        attention = PrometheusAttention(
            self.n_heads, self.n_embed,
            self.block_size, self.dropout,
        )
        x = x + attention(nn.LayerNorm()(x), deterministic)

        mlp = PrometheusMLP(self.n_embed, self.dropout)
        x = x + mlp(nn.LayerNorm()(x), deterministic)
        return x
class Prometheus(nn.Module):
    """Decoder-only transformer language model.

    Submodules are created inline, so ``__call__`` must be wrapped in
    ``@nn.compact``. The creation order of the submodules determines their
    auto-generated parameter names and must stay as it is.
    """

    vocab_size: int  # number of token ids; also the width of the output logits
    n_embed: int  # feature width of the residual stream
    n_heads: int  # attention heads per block
    n_layers: int  # number of stacked transformer blocks
    block_size: int  # size of the positional-embedding table
    dropout: float  # dropout rate handed to every sub-layer

    @nn.compact
    def __call__(self, idx, training=False):
        """Compute next-token logits for a batch of token ids.

        Args:
            idx: integer array unpacked as ``(batch, T)``. ``T`` should not
                exceed ``block_size`` because positions ``0..T-1`` are looked
                up in a table with ``block_size`` rows.
            training: accepted for interface compatibility but not read; this
                implementation always runs with dropout disabled.

        Returns:
            Logits whose last axis has size ``vocab_size``.
        """
        _, T = idx.shape

        tok = nn.Embed(self.vocab_size, self.n_embed)(idx)
        pos = nn.Embed(self.block_size, self.n_embed)(jnp.arange(T))
        x = nn.Dropout(self.dropout)(tok + pos, deterministic=True)

        # Rematerialised blocks; argument 2 of __call__ (self=0, x=1) is the
        # `deterministic` flag and has to be static.
        BlockRemat = nn.remat(PrometheusBlock, static_argnums=(2,))
        for _ in range(self.n_layers):
            block = BlockRemat(
                self.n_embed, self.n_heads,
                self.block_size, self.dropout,
            )
            x = block(x, True)

        # Final normalisation followed by the projection onto the vocabulary.
        return nn.Dense(self.vocab_size)(nn.LayerNorm()(x))
# Instantiate the network from the hyper-parameters saved in the checkpoint;
# each listed key of CONFIG is passed as the constructor field of the same name.
model = Prometheus(
    **{
        field: CONFIG[field]
        for field in (
            "vocab_size",
            "n_embed",
            "n_heads",
            "n_layers",
            "block_size",
            "dropout",
        )
    }
)

# Trained weights, used by every `model.apply` call.
params = ckpt["params"]
def generate(prompt, max_new_tokens=80, temperature=1.1, top_k=40, seed=None):
    """Continue ``prompt`` by sampling tokens from the model.

    Args:
        prompt: text to continue. It is encoded with ``encode`` and only its
            last ``block_size - 1`` tokens are kept.
        max_new_tokens: number of tokens to sample.
        temperature: divisor applied to the logits before sampling. Values
            ``<= 0`` select the most likely token (greedy decoding) instead of
            dividing by zero.
        top_k: sample only among the ``top_k`` highest-scoring tokens. The
            value is clamped to the range ``1..vocabulary size``.
        seed: integer seed for the sampling PRNG. When ``None`` a seed is
            drawn from NumPy's global generator.

    Returns:
        The decoded (possibly truncated) prompt followed by the decoded
        sampled tokens.
    """
    block_size = CONFIG["block_size"]

    # Same prompt cropping as before, so the returned text still starts with
    # the tail of the prompt.
    prompt_tokens = encode(prompt)[-(block_size - 1):]

    # An empty prompt would yield a zero-length input array; feed the model the
    # fallback id 0 (the id `encode` uses for unknown characters) instead. It
    # is used as context only and is not part of the returned text.
    context = prompt_tokens or [0]

    if seed is None:
        seed = int(np.random.randint(0, 2**31 - 1))
    key = jax.random.PRNGKey(seed)

    new_tokens = []
    for _ in range(max_new_tokens):
        # Slide the window: the model has only `block_size` positional
        # embeddings, so never feed it a longer sequence.
        window = (context + new_tokens)[-block_size:]
        x = jnp.array([window], dtype=jnp.int32)
        logits = model.apply(params, x, training=False)[0, -1, :]

        k = max(1, min(top_k, logits.shape[-1]))
        top_k_logits, top_k_indices = jax.lax.top_k(logits, k)

        if temperature <= 0:
            # lax.top_k returns values in descending order: index 0 is the argmax.
            chosen = 0
        else:
            # Split the key so that every step uses fresh randomness.
            key, step_key = jax.random.split(key)
            # categorical() takes unnormalised log-probabilities, so the
            # temperature-scaled logits can be passed directly.
            chosen = int(jax.random.categorical(step_key, top_k_logits / temperature))

        new_tokens.append(int(top_k_indices[chosen]))

    return decode(prompt_tokens + new_tokens)
def chat(message, history):
    """Gradio chat callback: return the model's continuation of ``message``.

    Args:
        message: the text entered by the user.
        history: previous turns supplied by Gradio; not used.

    Returns:
        The generated continuation with the echoed prompt removed and
        surrounding whitespace stripped. An empty message yields an empty
        reply.
    """
    if not message:
        # Nothing to continue; avoid running the model on an empty sequence.
        return ""

    result = generate(message, max_new_tokens=80, temperature=1.1)

    # Strip the prompt from the reply. `generate` echoes the *encoded, cropped
    # and decoded* prompt, which can differ in length from `message` (long
    # prompts are truncated, unknown characters are remapped), so measure that
    # echoed form instead of `message` itself.
    echoed_prompt = decode(encode(message)[-(CONFIG["block_size"] - 1):])
    if result.startswith(echoed_prompt):
        answer = result[len(echoed_prompt):]
    else:
        answer = result
    return answer.strip()
# Chat UI: every user message is answered by `chat`; the example prompts are
# offered as ready-made inputs.
demo = gr.ChatInterface(
    title="🔥 Prometheus AI",
    description="Языковая модель 1.2B параметров. Создана с нуля одним человеком.",
    examples=[
        "Москва —",
        "Россия — это",
        "Нейронная сеть — это",
        "Python — язык",
    ],
    fn=chat,
)

# Start the Gradio server.
demo.launch()