File size: 6,418 Bytes
f1ba501 39c856c f1ba501 8a3702a 39c856c 8a3702a 39c856c f1ba501 39c856c f1ba501 8a3702a f1ba501 39c856c f1ba501 39c856c f1ba501 8a3702a f1ba501 39c856c f1ba501 39c856c f1ba501 39c856c f1ba501 39c856c 8a3702a 39c856c f1ba501 39c856c ed83615 39c856c f1ba501 8a3702a 39c856c 8a3702a 39c856c f1ba501 39c856c f1ba501 39c856c f1ba501 39c856c f1ba501 39c856c f1ba501 39c856c f1ba501 39c856c 8a3702a 39c856c f1ba501 39c856c 8a3702a 39c856c 8a3702a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 | import streamlit as st
import torch
import transformers
from trl import AutoModelForCausalLMWithValueHead
import math
import time
st.set_page_config(page_title="RLHF Magic | Movie Reviews", page_icon="🍿", layout="wide")
st.markdown("""
<style>
.big-font { font-size:22px !important; font-weight: 500; }
.stProgress .st-bo { transition: background-color 0.5s ease; }
</style>
""", unsafe_allow_html=True)
st.title("🍿 Нейросеть-Кинокритик: До и После RLHF")
st.markdown("""
<div class="big-font">
Посмотрите, как работает магия обучения с подкреплением (RLHF). <br>
Слева — базовая модель GPT-2, которая пишет что вздумается. Справа — та же модель, но <b>натренированная всегда писать позитивные отзывы</b>, даже если вы начинаете текст с ужасных слов!
</div>
<br>
""", unsafe_allow_html=True)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@st.cache_resource
def load_models():
reward_path = "reward_model_trained"
ppo_path = "ppo_model_trained"
orig_model_name = "lvwerra/gpt2-imdb"
# 1. Reward Model
reward_tokenizer = transformers.AutoTokenizer.from_pretrained(reward_path)
reward_model = transformers.AutoModelForSequenceClassification.from_pretrained(reward_path).to(DEVICE).eval()
# 2. Original GPT-2
orig_tokenizer = transformers.AutoTokenizer.from_pretrained(orig_model_name)
if orig_tokenizer.pad_token is None:
orig_tokenizer.pad_token = orig_tokenizer.eos_token
orig_model = transformers.AutoModelForCausalLM.from_pretrained(orig_model_name).to(DEVICE).eval()
# 3. RLHF Model
rlhf_model_full = AutoModelForCausalLMWithValueHead.from_pretrained(ppo_path).to(DEVICE).eval()
rlhf_model = rlhf_model_full.pretrained_model
return reward_model, reward_tokenizer, orig_model, orig_tokenizer, rlhf_model
with st.spinner("⏳ Подготовка нейросетей... (занимает около минуты при первом старте)"):
reward_model, reward_tokenizer, orig_model, orig_tokenizer, rlhf_model = load_models()
def compute_reward(text):
inputs = reward_tokenizer(text, truncation=True, max_length=512, padding=True, return_tensors="pt").to(DEVICE)
with torch.no_grad():
score = reward_model(**inputs).logits[0, 0].item()
return score
def get_positivity_percent(score):
return int((1 / (1 + math.exp(-score))) * 100)
def generate_text(model, tokenizer, prompt, max_new_tokens, temperature, top_p):
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
with torch.no_grad():
outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=True,
temperature=temperature, top_p=top_p, pad_token_id=tokenizer.eos_token_id)
return tokenizer.decode(outputs[0], skip_special_tokens=True)
def stream_text(text, delay=0.03):
for word in text.split(" "):
yield word + " "
time.sleep(delay)
st.sidebar.image("https://huggingface.co/front/assets/huggingface_logo-noborder.svg", width=50)
st.sidebar.header("🎛 Настройки генерации")
max_tokens = st.sidebar.slider("Длина продолжения (токенов)", 20, 150, 70)
temp = st.sidebar.slider("Креативность (Temperature)", 0.1, 1.5, 0.8)
st.sidebar.info("💡 **Попробуйте начать так:**\n\n- *I hate this movie because*\n- *The acting was terrible and*\n- *To be honest, the plot was*")
# Главное поле ввода
user_prompt = st.text_input("✍️ Напишите начало отзыва (на англ.) и нажмите Enter:",
value="The director tried to make a good movie and",
max_chars=100)
if st.button("Мне повезет!", type="primary", use_container_width=True):
# Сначала генерируем всё за кулисами
with st.spinner("GPT goes brrr..."):
orig_text = generate_text(orig_model, orig_tokenizer, user_prompt, max_tokens, temp, 0.95)
orig_reward = compute_reward(orig_text)
orig_percent = get_positivity_percent(orig_reward)
rlhf_text = generate_text(rlhf_model, orig_tokenizer, user_prompt, max_tokens, temp, 0.95)
rlhf_reward = compute_reward(rlhf_text)
rlhf_percent = get_positivity_percent(rlhf_reward)
st.markdown("---")
# Создаем две колонки
col1, col2 = st.columns(2)
# КОЛОНКА 1: Оригинальная модель
with col1:
with st.container(border=True):
st.subheader("До RLHF (Свободная GPT-2)")
st.caption("Пишет как попало (может быть негативной)")
# Уровень позитивности с цветным баром
st.progress(orig_percent / 100, text=f"Уровень позитивности: {orig_percent}%")
# Эффект печатной машинки
st.write_stream(stream_text(orig_text))
# КОЛОНКА 2: Обученная модель
with col2:
with st.container(border=True):
st.subheader("После RLHF (Good Boy Model)")
st.caption("Старается вырулить любой текст в позитив")
# Уровень позитивности с цветным баром
st.progress(rlhf_percent / 100, text=f"Уровень позитивности: {rlhf_percent}%")
# Спит чуть-чуть, чтобы эффект был последовательным
time.sleep(1)
st.write_stream(stream_text(rlhf_text, delay=0.04))
# Добавляем эмоций в конце
if rlhf_percent > orig_percent + 20 and rlhf_percent > 70:
st.balloons()
st.toast('🎉 RLHF модель блестяще спасла ситуацию!', icon='😍')
elif rlhf_percent < 50:
st.toast('Начало было настолько суровым, что даже RLHF сдалась.', icon='💀')
|