# Web-page header captured together with the file (not program code):
# Spaces: Sleeping
# File size: 4,772 Bytes
7bf134f b9c6b3d 7bf134f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 | import os
import jax
import jax.numpy as jnp
import flax.linen as nn
import pickle
import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download
HF_REPO = "MoEprometheus/Prometheus-base"

print("📥 Загружаем Prometheus...")
path = hf_hub_download(HF_REPO, "expert1.pkl")
# SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in
# the file. This is only acceptable while the HF_REPO checkpoint is trusted.
with open(path, "rb") as f:
    ckpt = pickle.load(f)

itos = ckpt["vocab"]  # token id -> character, as consumed by decode()
stoi = {v: k for k, v in itos.items()}  # inverse table used by encode()
CONFIG = ckpt["config"]


def encode(s):
    """Return one token id per character of *s*; unknown characters map to id 0."""
    return [stoi.get(c, 0) for c in s]


def decode(l):
    """Concatenate the text for each id in *l*; ids absent from the vocabulary are skipped."""
    return "".join(itos.get(i, "") for i in l)


print(f"✅ Загружено — шаг {ckpt['step']}")
class PrometheusAttention(nn.Module):
    """Multi-head self-attention with a lower-triangular (causal) mask.

    The input is unpacked as ``(B, T, C)``. The final reshape back to
    ``(B, T, C)`` only works when ``C == n_heads * (n_embed // n_heads)``.
    ``block_size`` is declared as a field but is not read by this layer.

    Sub-modules are created in a fixed order (fused QKV projection first,
    output projection second); the order determines the automatic parameter
    names, so keep it when editing.
    """

    n_heads: int
    n_embed: int
    block_size: int
    dropout: float

    @nn.compact
    def __call__(self, x, deterministic=True):
        B, T, C = x.shape
        head_size = self.n_embed // self.n_heads

        def to_heads(t):
            # (B, T, n_heads * head_size) -> (B, n_heads, T, head_size)
            return t.reshape(B, T, self.n_heads, head_size).transpose(0, 2, 1, 3)

        # A single bias-free projection produces queries, keys and values.
        fused = nn.Dense(3 * self.n_embed, use_bias=False)(x)
        q, k, v = (to_heads(part) for part in jnp.split(fused, 3, axis=-1))

        # Scaled dot-product scores, shape (B, n_heads, T, T).
        scores = (q @ jnp.swapaxes(k, -1, -2)) * (head_size ** -0.5)

        # Position i may only attend to positions <= i; everything else is
        # pushed to -1e9 so it vanishes after the softmax.
        visible = jnp.tril(jnp.ones((T, T), dtype=bool))
        scores = jnp.where(visible, scores, -1e9)

        weights = jax.nn.softmax(scores, axis=-1)
        weights = nn.Dropout(self.dropout)(weights, deterministic=deterministic)

        # Merge heads back into the channel dimension and project.
        merged = (weights @ v).transpose(0, 2, 1, 3).reshape(B, T, C)
        projected = nn.Dense(self.n_embed)(merged)
        return nn.Dropout(self.dropout)(projected, deterministic=deterministic)
class PrometheusMLP(nn.Module):
    """Two-layer feed-forward network: expand to 4 * n_embed, GELU, project back.

    Dropout is applied to the output and is disabled when ``deterministic``
    is true. The two Dense layers are created in expand-then-project order,
    which fixes their automatic parameter names.
    """

    n_embed: int
    dropout: float

    @nn.compact
    def __call__(self, x, deterministic=True):
        expanded = nn.Dense(4 * self.n_embed)(x)
        activated = nn.gelu(expanded)
        projected = nn.Dense(self.n_embed)(activated)
        drop = nn.Dropout(self.dropout)
        return drop(projected, deterministic=deterministic)
class PrometheusBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual add.

    Each sub-layer receives a freshly layer-normalised copy of ``x`` and its
    output is added back onto ``x``. The first LayerNorm feeds the attention
    sub-layer and the second feeds the MLP.
    """

    n_embed: int
    n_heads: int
    block_size: int
    dropout: float

    @nn.compact
    def __call__(self, x, deterministic=True):
        # Attention sub-layer.
        attention = PrometheusAttention(
            self.n_heads, self.n_embed, self.block_size, self.dropout
        )
        attn_out = attention(nn.LayerNorm()(x), deterministic)
        x = x + attn_out

        # Feed-forward sub-layer.
        feed_forward = PrometheusMLP(self.n_embed, self.dropout)
        mlp_out = feed_forward(nn.LayerNorm()(x), deterministic)
        x = x + mlp_out
        return x
class Prometheus(nn.Module):
    """Decoder-only transformer producing per-position vocabulary logits.

    ``idx`` must be two-dimensional (it is unpacked as ``(B, T)``). Token and
    learned positional embeddings are summed, passed through ``n_layers``
    rematerialised ``PrometheusBlock`` instances, layer-normalised and
    projected to ``vocab_size`` logits.

    NOTE: ``training`` is accepted but never read; every dropout layer is run
    with ``deterministic=True``, so this module always behaves as in
    inference. The positional table has ``block_size`` rows, so ``T`` should
    not exceed ``block_size``.
    """

    vocab_size: int
    n_embed: int
    n_heads: int
    n_layers: int
    block_size: int
    dropout: float

    @nn.compact
    def __call__(self, idx, training=False):
        _, seq_len = idx.shape

        token_emb = nn.Embed(self.vocab_size, self.n_embed)(idx)
        position_emb = nn.Embed(self.block_size, self.n_embed)(jnp.arange(seq_len))
        x = nn.Dropout(self.dropout)(token_emb + position_emb, deterministic=True)

        # Argument 2 of the block call (the deterministic flag) is marked
        # static for the rematerialisation wrapper.
        BlockRemat = nn.remat(PrometheusBlock, static_argnums=(2,))
        block_args = (self.n_embed, self.n_heads, self.block_size, self.dropout)
        for _ in range(self.n_layers):
            layer = BlockRemat(*block_args)
            x = layer(x, True)

        normed = nn.LayerNorm()(x)
        return nn.Dense(self.vocab_size)(normed)
# Instantiate the network from the hyperparameters stored in the checkpoint.
# Keys are looked up in the same order as the Prometheus field declarations.
model = Prometheus(
    **{
        field: CONFIG[field]
        for field in (
            "vocab_size",
            "n_embed",
            "n_heads",
            "n_layers",
            "block_size",
            "dropout",
        )
    }
)

# Weights are taken from the checkpoint as-is and passed to model.apply.
params = ckpt["params"]
def generate(prompt, max_new_tokens=80, temperature=1.1, top_k=40):
    """Autoregressively extend *prompt* and return the decoded text.

    Args:
        prompt: Text to condition on. It is encoded character by character
            and only its last ``block_size - 1`` tokens are kept.
        max_new_tokens: Number of tokens to append.
        temperature: Logit divisor applied before sampling. Values ``<= 0``
            select the most likely token (greedy decoding) instead of
            dividing by zero.
        top_k: Sample only among this many highest-scoring tokens. Clamped
            to the range ``[1, vocabulary size]``.

    Returns:
        The decoded (possibly truncated) prompt followed by the generated
        continuation. An empty prompt yields an empty string because the
        model needs at least one input position.
    """
    block_size = CONFIG["block_size"]
    tokens = encode(prompt)
    tokens = tokens[-(block_size - 1):]
    if not tokens:
        # A zero-length sequence has no "last position" to read logits from.
        return decode(tokens)

    # Seed once, then split per step, instead of re-seeding on every token.
    key = jax.random.PRNGKey(np.random.randint(0, 2**31 - 1))

    for _ in range(max_new_tokens):
        # Feed at most block_size tokens: the positional embedding table has
        # only block_size rows, and the token list grows on every iteration.
        context = jnp.array([tokens[-block_size:]])
        logits = model.apply(params, context, training=False)[0, -1, :]

        if temperature <= 0:
            next_token = int(jnp.argmax(logits))
        else:
            k = max(1, min(top_k, logits.shape[-1]))
            top_logits, top_indices = jax.lax.top_k(logits / temperature, k)
            key, step_key = jax.random.split(key)
            # categorical() takes unnormalised log-probabilities, so the
            # top-k logits can be passed directly (no softmax + log needed).
            chosen = int(jax.random.categorical(step_key, top_logits))
            next_token = int(top_indices[chosen])

        tokens.append(next_token)

    return decode(tokens)
def chat(message, history):
    """Gradio chat callback: return the model's continuation of *message*.

    Args:
        message: The user's latest message, used as the generation prompt.
        history: Previous conversation turns supplied by Gradio; not used.

    Returns:
        The generated continuation with the echoed prompt removed and
        surrounding whitespace stripped.
    """
    result = generate(message, max_new_tokens=80, temperature=1.1)

    # Remove the prompt from the answer. generate() echoes the prompt as the
    # model saw it (re-encoded and cut to the last block_size - 1 tokens), so
    # strip exactly that prefix rather than len(message) characters, which is
    # wrong whenever the prompt was truncated.
    echoed_prompt = decode(encode(message)[-(CONFIG["block_size"] - 1):])
    if result.startswith(echoed_prompt):
        answer = result[len(echoed_prompt):]
    else:
        answer = result
    return answer.strip()
# Chat UI: every user message is passed to chat() and the returned text is
# shown as the reply. The examples are pre-filled prompt suggestions.
demo = gr.ChatInterface(
    fn=chat,
    title="🔥 Prometheus AI",
    description="Языковая модель 1.2B параметров. Создана с нуля одним человеком.",
    examples=[
        "Москва —",
        "Россия — это",
        "Нейронная сеть — это",
        "Python — язык",
    ],
)

# Start the web server (the stray trailing "|" after this call was removed;
# it made the statement a syntax error).
demo.launch()