File size: 6,418 Bytes
f1ba501
 
 
 
39c856c
 
f1ba501
8a3702a
39c856c
 
8a3702a
39c856c
 
 
 
 
 
 
 
f1ba501
39c856c
 
 
 
 
 
f1ba501
 
 
8a3702a
f1ba501
 
39c856c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1ba501
39c856c
f1ba501
 
8a3702a
f1ba501
39c856c
f1ba501
 
 
 
39c856c
 
 
f1ba501
 
 
39c856c
 
f1ba501
 
39c856c
 
 
 
 
8a3702a
39c856c
 
 
 
 
f1ba501
39c856c
 
ed83615
39c856c
f1ba501
8a3702a
39c856c
 
8a3702a
39c856c
f1ba501
39c856c
f1ba501
39c856c
f1ba501
39c856c
f1ba501
39c856c
 
 
f1ba501
 
39c856c
f1ba501
39c856c
8a3702a
39c856c
 
 
 
 
 
 
 
 
f1ba501
39c856c
8a3702a
39c856c
 
 
 
 
 
 
 
 
 
 
 
 
 
8a3702a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import streamlit as st
import torch
import transformers
from trl import AutoModelForCausalLMWithValueHead
import math
import time


st.set_page_config(page_title="RLHF Magic | Movie Reviews", page_icon="🍿", layout="wide")


st.markdown("""
<style>
    .big-font { font-size:22px !important; font-weight: 500; }
    .stProgress .st-bo { transition: background-color 0.5s ease; }
</style>
""", unsafe_allow_html=True)

st.title("🍿 Нейросеть-Кинокритик: До и После RLHF")
st.markdown("""
<div class="big-font">
Посмотрите, как работает магия обучения с подкреплением (RLHF). <br>
Слева — базовая модель GPT-2, которая пишет что вздумается. Справа — та же модель, но <b>натренированная всегда писать позитивные отзывы</b>, даже если вы начинаете текст с ужасных слов!
</div>
<br>
""", unsafe_allow_html=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def load_models():
    reward_path = "reward_model_trained"
    ppo_path = "ppo_model_trained"
    orig_model_name = "lvwerra/gpt2-imdb"
    
    # 1. Reward Model
    reward_tokenizer = transformers.AutoTokenizer.from_pretrained(reward_path)
    reward_model = transformers.AutoModelForSequenceClassification.from_pretrained(reward_path).to(DEVICE).eval()
    
    # 2. Original GPT-2
    orig_tokenizer = transformers.AutoTokenizer.from_pretrained(orig_model_name)
    if orig_tokenizer.pad_token is None:
        orig_tokenizer.pad_token = orig_tokenizer.eos_token
    orig_model = transformers.AutoModelForCausalLM.from_pretrained(orig_model_name).to(DEVICE).eval()
    
    # 3. RLHF Model
    rlhf_model_full = AutoModelForCausalLMWithValueHead.from_pretrained(ppo_path).to(DEVICE).eval()
    rlhf_model = rlhf_model_full.pretrained_model 
    
    return reward_model, reward_tokenizer, orig_model, orig_tokenizer, rlhf_model

with st.spinner("⏳ Подготовка нейросетей... (занимает около минуты при первом старте)"):
    reward_model, reward_tokenizer, orig_model, orig_tokenizer, rlhf_model = load_models()


def compute_reward(text):
    inputs = reward_tokenizer(text, truncation=True, max_length=512, padding=True, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        score = reward_model(**inputs).logits[0, 0].item()
    return score

def get_positivity_percent(score):
    return int((1 / (1 + math.exp(-score))) * 100)

def generate_text(model, tokenizer, prompt, max_new_tokens, temperature, top_p):
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=True, 
                                 temperature=temperature, top_p=top_p, pad_token_id=tokenizer.eos_token_id)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def stream_text(text, delay=0.03):
    for word in text.split(" "):
        yield word + " "
        time.sleep(delay)


st.sidebar.image("https://huggingface.co/front/assets/huggingface_logo-noborder.svg", width=50)
st.sidebar.header("🎛 Настройки генерации")
max_tokens = st.sidebar.slider("Длина продолжения (токенов)", 20, 150, 70)
temp = st.sidebar.slider("Креативность (Temperature)", 0.1, 1.5, 0.8)
st.sidebar.info("💡 **Попробуйте начать так:**\n\n- *I hate this movie because*\n- *The acting was terrible and*\n- *To be honest, the plot was*")

# Главное поле ввода
user_prompt = st.text_input("✍️ Напишите начало отзыва (на англ.) и нажмите Enter:", 
                            value="The director tried to make a good movie and", 
                            max_chars=100)

if st.button("Мне повезет!", type="primary", use_container_width=True):
    
    # Сначала генерируем всё за кулисами
    with st.spinner("GPT goes brrr..."):
        orig_text = generate_text(orig_model, orig_tokenizer, user_prompt, max_tokens, temp, 0.95)
        orig_reward = compute_reward(orig_text)
        orig_percent = get_positivity_percent(orig_reward)
        
        rlhf_text = generate_text(rlhf_model, orig_tokenizer, user_prompt, max_tokens, temp, 0.95)
        rlhf_reward = compute_reward(rlhf_text)
        rlhf_percent = get_positivity_percent(rlhf_reward)

    st.markdown("---")
    
    # Создаем две колонки
    col1, col2 = st.columns(2)
    
    # КОЛОНКА 1: Оригинальная модель
    with col1:
        with st.container(border=True):
            st.subheader("До RLHF (Свободная GPT-2)")
            st.caption("Пишет как попало (может быть негативной)")
            
            # Уровень позитивности с цветным баром
            st.progress(orig_percent / 100, text=f"Уровень позитивности: {orig_percent}%")
            
            # Эффект печатной машинки
            st.write_stream(stream_text(orig_text))

    # КОЛОНКА 2: Обученная модель
    with col2:
        with st.container(border=True):
            st.subheader("После RLHF (Good Boy Model)")
            st.caption("Старается вырулить любой текст в позитив")
            
            # Уровень позитивности с цветным баром
            st.progress(rlhf_percent / 100, text=f"Уровень позитивности: {rlhf_percent}%")
            
            # Спит чуть-чуть, чтобы эффект был последовательным
            time.sleep(1) 
            st.write_stream(stream_text(rlhf_text, delay=0.04))
            
    # Добавляем эмоций в конце
    if rlhf_percent > orig_percent + 20 and rlhf_percent > 70:
        st.balloons()
        st.toast('🎉 RLHF модель блестяще спасла ситуацию!', icon='😍')
    elif rlhf_percent < 50:
        st.toast('Начало было настолько суровым, что даже RLHF сдалась.', icon='💀')