# Prometheus-chat / app.py
# Author: MoEprometheus
# Last commit: "Update app.py" (4b42203, verified)
import os
import jax
import jax.numpy as jnp
import flax.linen as nn
import pickle
import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download
HF_REPO = "MoEprometheus/Prometheus-base"
print("📥 Загружаем Prometheus...")
path = hf_hub_download(HF_REPO, "expert1.pkl")
# SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in the
# file. This is only acceptable because HF_REPO is the author's own repository;
# pin a revision or switch to a non-executable format (e.g. msgpack /
# safetensors) if the checkpoint source is ever not fully trusted.
with open(path, "rb") as f:
    ckpt = pickle.load(f)

# Vocabulary: itos maps token id -> symbol; stoi is its inverse.
# NOTE(review): encode() looks up single characters, so itos values are
# presumably single characters (char-level tokenizer) — confirm in training code.
itos = ckpt["vocab"]
stoi = {v: k for k, v in itos.items()}
CONFIG = ckpt["config"]


def encode(s):
    """Encode text *s* character by character into token ids.

    Characters that are not in the vocabulary map to id 0.
    """
    return [stoi.get(c, 0) for c in s]


def decode(l):
    """Decode an iterable of token ids *l* back into text.

    Ids that are not in the vocabulary are dropped (decoded as "").
    """
    return "".join(itos.get(i, "") for i in l)


print(f"✅ Загружено — шаг {ckpt['step']}")
class PrometheusAttention(nn.Module):
    """Causal multi-head self-attention with one fused QKV projection.

    Field order matters: PrometheusBlock builds this module positionally as
    ``PrometheusAttention(n_heads, n_embed, block_size, dropout)``.
    """

    n_heads: int
    n_embed: int
    block_size: int  # not read in __call__; the mask is sized from the input
    dropout: float

    @nn.compact
    def __call__(self, x, deterministic=True):
        """Attend over the sequence axis of ``x``.

        Args:
            x: Array unpacked as (batch, seq_len, channels). The head reshape
                assumes channels == n_embed and n_embed % n_heads == 0.
            deterministic: When True, both Dropout layers are no-ops.

        Returns:
            Array of shape (batch, seq_len, n_embed).
        """
        batch, seq_len, channels = x.shape
        n_heads = self.n_heads
        d_head = self.n_embed // n_heads

        def split_heads(t):
            # (batch, seq_len, n_embed) -> (batch, n_heads, seq_len, d_head)
            return t.reshape(batch, seq_len, n_heads, d_head).transpose(0, 2, 1, 3)

        # NOTE(review): Flax auto-names submodules per type in creation order
        # (Dense_0 = QKV, Dense_1 = output); keep this order so the names in
        # the loaded checkpoint still match.
        fused = nn.Dense(3 * self.n_embed, use_bias=False)(x)
        query, key, value = map(split_heads, jnp.split(fused, 3, axis=-1))

        scores = (query @ jnp.swapaxes(key, -1, -2)) * (d_head ** -0.5)
        # Lower-triangular mask: position i may attend only to positions <= i.
        keep = jnp.tril(jnp.ones((seq_len, seq_len))) != 0
        scores = jnp.where(keep, scores, -1e9)
        weights = nn.Dropout(self.dropout)(
            jax.nn.softmax(scores, axis=-1), deterministic=deterministic
        )

        merged = (weights @ value).transpose(0, 2, 1, 3).reshape(
            batch, seq_len, channels
        )
        projected = nn.Dense(self.n_embed)(merged)
        return nn.Dropout(self.dropout)(projected, deterministic=deterministic)
class PrometheusMLP(nn.Module):
    """Position-wise feed-forward network: expand 4x, GELU, project back.

    Field order matters: PrometheusBlock constructs this module positionally
    as ``PrometheusMLP(n_embed, dropout)``.
    """

    n_embed: int
    dropout: float

    @nn.compact
    def __call__(self, x, deterministic=True):
        """Apply the feed-forward stack; output width is ``n_embed``.

        ``deterministic=True`` turns the trailing Dropout into a no-op.
        """
        hidden_width = 4 * self.n_embed
        # Creation order (Dense_0 = expand, Dense_1 = project) is kept as-is so
        # auto-generated submodule names match the loaded params.
        expanded = nn.gelu(nn.Dense(hidden_width)(x))
        projected = nn.Dense(self.n_embed)(expanded)
        return nn.Dropout(self.dropout)(projected, deterministic=deterministic)
class PrometheusBlock(nn.Module):
    """Pre-LayerNorm transformer block with two residual branches.

    Computes ``x + attn(LN(x))`` followed by ``x + mlp(LN(x))``. Field order
    matters: Prometheus constructs this block positionally as
    ``(n_embed, n_heads, block_size, dropout)``.
    """

    n_embed: int
    n_heads: int
    block_size: int
    dropout: float

    @nn.compact
    def __call__(self, x, deterministic=True):
        """Run both residual sub-layers over ``x``.

        ``deterministic`` is forwarded positionally to both sub-layers; it is
        argument index 2 (after self and x), which Prometheus marks static for
        nn.remat.
        """
        attention = PrometheusAttention(
            self.n_heads, self.n_embed, self.block_size, self.dropout
        )
        x = x + attention(nn.LayerNorm()(x), deterministic)

        feed_forward = PrometheusMLP(self.n_embed, self.dropout)
        x = x + feed_forward(nn.LayerNorm()(x), deterministic)
        return x
class Prometheus(nn.Module):
    """Decoder-only transformer language model.

    Token embeddings plus learned positional embeddings, ``n_layers``
    PrometheusBlocks, a final LayerNorm and a projection to ``vocab_size``
    logits.
    """

    vocab_size: int
    n_embed: int
    n_heads: int
    n_layers: int
    block_size: int
    dropout: float

    @nn.compact
    def __call__(self, idx, training=False):
        """Return next-token logits for every position of ``idx``.

        Args:
            idx: Integer token ids, unpacked as (batch, seq_len). seq_len should
                not exceed ``block_size``: the positional table has only
                ``block_size`` rows, and out-of-range positions are not checked
                here.
            training: When True, dropout is active and ``apply`` must be given
                a ``"dropout"`` RNG. The default (False) disables every dropout
                layer, which is the inference behaviour used by this app.

        Returns:
            Logits of shape (batch, seq_len, vocab_size).
        """
        # Previously `training` had no effect (dropout was hard-coded off).
        deterministic = not training
        _, seq_len = idx.shape
        tok = nn.Embed(self.vocab_size, self.n_embed)(idx)
        pos = nn.Embed(self.block_size, self.n_embed)(jnp.arange(seq_len))
        x = nn.Dropout(self.dropout)(tok + pos, deterministic=deterministic)
        # static_argnums=(2,) marks `deterministic` (self=0, x=1) as a static
        # Python bool for the rematerialized call.
        # NOTE(review): remat only saves memory during backprop, but the lifted
        # class name likely feeds Flax's auto-generated submodule names, so
        # removing it could break loading the pickled params — confirm first.
        BlockRemat = nn.remat(PrometheusBlock, static_argnums=(2,))
        for _ in range(self.n_layers):
            x = BlockRemat(
                self.n_embed, self.n_heads, self.block_size, self.dropout
            )(x, deterministic)
        return nn.Dense(self.vocab_size)(nn.LayerNorm()(x))
# Build the model from the hyperparameters stored in the same checkpoint that
# supplies the weights; a missing key raises KeyError.
model = Prometheus(
    **{
        field: CONFIG[field]
        for field in (
            "vocab_size",
            "n_embed",
            "n_heads",
            "n_layers",
            "block_size",
            "dropout",
        )
    }
)
# NOTE(review): handed unchanged to model.apply(), so this is presumably the
# full Flax variables dict (e.g. {"params": ...}) — confirm against the checkpoint.
params = ckpt["params"]
def generate(prompt, max_new_tokens=80, temperature=1.1, top_k=40):
    """Continue *prompt* by sampling from the model with top-k sampling.

    Args:
        prompt: Text to continue, encoded with ``encode`` (characters missing
            from the vocabulary become id 0).
        max_new_tokens: Number of tokens to sample.
        temperature: Divisor applied to the logits before sampling; must be > 0.
        top_k: Sample only among the ``top_k`` highest-scoring tokens. It is
            capped at the vocabulary size so that ``jax.lax.top_k`` cannot fail
            on a small vocabulary.

    Returns:
        ``decode`` of the prompt tokens (left-truncated to ``block_size - 1``)
        followed by the sampled tokens.

    Raises:
        ValueError: If ``temperature`` is not positive or the prompt is empty.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature!r}")
    block_size = CONFIG["block_size"]
    tokens = encode(prompt)[-(block_size - 1):]
    if not tokens:
        raise ValueError("prompt must contain at least one character")
    k = min(top_k, CONFIG["vocab_size"])
    # One key per call, split every step (instead of reseeding per step).
    rng = jax.random.PRNGKey(np.random.randint(0, 2**31 - 1))
    for _ in range(max_new_tokens):
        # The positional embedding has only block_size rows, so feed the model
        # at most the last block_size tokens; `tokens` keeps everything for
        # the returned text.
        context = jnp.array([tokens[-block_size:]])
        logits = model.apply(params, context, training=False)
        last_logits = logits[0, -1, :] / temperature
        top_logits, top_ids = jax.lax.top_k(last_logits, k)
        rng, step_key = jax.random.split(rng)
        # categorical() normalizes raw logits itself; log(softmax(x)) is redundant.
        chosen = int(jax.random.categorical(step_key, top_logits))
        tokens.append(int(top_ids[chosen]))
    return decode(tokens)
def chat(message, history):
    """Chat callback used by gr.ChatInterface: return the model's continuation.

    *history* is part of the callback signature but is not used; each reply is
    generated from the current *message* alone. An empty message yields "".
    """
    if not message:
        return ""
    result = generate(message, max_new_tokens=80, temperature=1.1)
    # Remove the prompt from the answer. generate() echoes the prompt as it was
    # fed to the model (left-truncated to block_size - 1 tokens and
    # round-tripped through the vocabulary), so strip exactly that prefix
    # rather than len(message) characters.
    prompt_echo = decode(encode(message)[-(CONFIG["block_size"] - 1):])
    if result.startswith(prompt_echo):
        answer = result[len(prompt_echo):]
    else:
        answer = result
    return answer.strip()
# Clickable prompt starters shown under the chat box.
_EXAMPLE_PROMPTS = [
    "Москва —",
    "Россия — это",
    "Нейронная сеть — это",
    "Python — язык",
]

# Chat UI wired to chat(); launch() runs at import time, so importing this
# module starts the app.
demo = gr.ChatInterface(
    fn=chat,
    examples=_EXAMPLE_PROMPTS,
    title="🔥 Prometheus AI",
    description="Языковая модель 1.2B параметров. Создана с нуля одним человеком.",
)
demo.launch()