klasser commited on
Commit
39c856c
·
1 Parent(s): 3e9ea74

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -89
app.py CHANGED
@@ -2,128 +2,146 @@ import streamlit as st
2
  import torch
3
  import transformers
4
  from trl import AutoModelForCausalLMWithValueHead
 
 
5
 
6
- # Настройки страницы
7
- st.set_page_config(page_title="RLHF: IMDB Movie Reviews", layout="wide")
8
- st.title("🎬 Генерация отзывов на фильмы с помощью RLHF")
 
 
 
 
 
 
 
 
 
 
 
9
  st.markdown("""
10
- Это приложение сравнивает два варианта модели:
11
- - **Original GPT-2**: базовая модель, обученная на отзывах IMDB.
12
- - **RLHF Model (PPO)**: та же модель, но дообученная с помощью RLHF писать **только позитивные** отзывы.
13
- """)
 
 
14
 
15
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
  # ============================================================
18
- # ЗАГРУЗКА МОДЕЛЕЙ (кешируем, чтобы не грузить при каждом нажатии)
19
  # ============================================================
20
  @st.cache_resource
21
  def load_models():
22
- with st.spinner("Загрузка моделей в память... Пожалуйста, подождите (это делается 1 раз)."):
23
- # 1. Загрузка Reward Model
24
- reward_path = "reward_model_trained"
25
- reward_tokenizer = transformers.AutoTokenizer.from_pretrained(reward_path)
26
- reward_model = transformers.AutoModelForSequenceClassification.from_pretrained(reward_path).to(DEVICE).eval()
27
-
28
- # 2. Загрузка Original Model (Базовая до RLHF)
29
- orig_model_name = "lvwerra/gpt2-imdb"
30
- orig_tokenizer = transformers.AutoTokenizer.from_pretrained(orig_model_name)
31
- if orig_tokenizer.pad_token is None:
32
- orig_tokenizer.pad_token = orig_tokenizer.eos_token
33
-
34
- orig_model = transformers.AutoModelForCausalLM.from_pretrained(orig_model_name).to(DEVICE).eval()
35
-
36
- # 3. Загрузка RLHF Model (Обученная через PPO)
37
- ppo_path = "ppo_model_trained"
38
- # Для генерации нам нужен только CausalLM, но чтобы загрузить веса корректно, используем ValueHead класс
39
- rlhf_model_full = AutoModelForCausalLMWithValueHead.from_pretrained(ppo_path).to(DEVICE).eval()
40
- rlhf_model = rlhf_model_full.pretrained_model # вытаскиваем саму языковую модель
41
-
42
- return reward_model, reward_tokenizer, orig_model, orig_tokenizer, rlhf_model
43
 
44
- try:
45
  reward_model, reward_tokenizer, orig_model, orig_tokenizer, rlhf_model = load_models()
46
- except Exception as e:
47
- st.error(f"Ошибка загрузки моделей! Убедитесь, что па��ки `reward_model_trained` и `ppo_model_trained` находятся рядом с app.py.\nДетали: {e}")
48
- st.stop()
49
 
50
  # ============================================================
51
- # ФУНКЦИИ ГЕНЕРАЦИИ И ОЦЕНКИ
52
  # ============================================================
53
  def compute_reward(text):
54
- """Вычисляет 'позитивность' текста с помощью Reward модели"""
55
- inputs = reward_tokenizer(
56
- text, truncation=True, max_length=512,
57
- padding=True, return_tensors="pt"
58
- ).to(DEVICE)
59
  with torch.no_grad():
60
  score = reward_model(**inputs).logits[0, 0].item()
61
  return score
62
 
 
 
 
 
63
  def generate_text(model, tokenizer, prompt, max_new_tokens, temperature, top_p):
64
- """Генерирует продолжение текста"""
65
  inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
66
  with torch.no_grad():
67
- outputs = model.generate(
68
- **inputs,
69
- max_new_tokens=max_new_tokens,
70
- do_sample=True,
71
- temperature=temperature,
72
- top_p=top_p,
73
- pad_token_id=tokenizer.eos_token_id
74
- )
75
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
76
 
 
 
 
 
 
 
77
  # ============================================================
78
- # ИНТЕРФЕЙС ПРИЛОЖЕНИЯ
79
  # ============================================================
80
- st.sidebar.header("Параметры генерации")
81
- max_tokens = st.sidebar.slider("Max New Tokens", 10, 150, 80)
82
- temperature = st.sidebar.slider("Temperature", 0.1, 1.5, 0.8)
83
- top_p = st.sidebar.slider("Top-p", 0.1, 1.0, 0.95)
 
84
 
85
- st.write("---")
86
- st.subheader("📝 Введите начало отзыва")
 
 
87
 
88
- predefined_prompts = [
89
- "This movie was",
90
- "I went to the cinema and",
91
- "The acting in this film",
92
- "I absolutely",
93
- "What a terrible",
94
- "Свой вариант..."
95
- ]
96
-
97
- selected_prompt = st.selectbox("Выберите шаблон или напишите свой:", predefined_prompts)
98
-
99
- if selected_prompt == "Свой вариант...":
100
- user_prompt = st.text_input("Ваш текст:", "The director tried to")
101
- else:
102
- user_prompt = selected_prompt
103
-
104
- if st.button("🚀 Сгенерировать отзыв", type="primary"):
105
- with st.spinner("Модели думают..."):
106
- # Генерация оригинальной моделью
107
- orig_text = generate_text(orig_model, orig_tokenizer, user_prompt, max_tokens, temperature, top_p)
108
  orig_reward = compute_reward(orig_text)
 
109
 
110
- # Генерация RLHF моделью
111
- rlhf_text = generate_text(rlhf_model, orig_tokenizer, user_prompt, max_tokens, temperature, top_p)
112
  rlhf_reward = compute_reward(rlhf_text)
 
113
 
114
- # Визуализация результатов в две колонки
 
 
115
  col1, col2 = st.columns(2)
116
 
 
117
  with col1:
118
- st.markdown("### 🤖 Original GPT-2")
119
- st.metric(label="Reward Score (чем больше, тем позитивнее)", value=f"{orig_reward:+.3f}")
120
- st.info(orig_text)
121
-
 
 
 
 
 
 
 
122
  with col2:
123
- st.markdown("### ✨ RLHF Model (PPO)")
124
- delta = rlhf_reward - orig_reward
125
- st.metric(label="Reward Score (чем больше, тем позитивнее)", value=f"{rlhf_reward:+.3f}", delta=f"{delta:+.3f} vs Orig")
126
- st.success(rlhf_text)
127
-
128
- st.markdown("---")
129
- st.caption("💡 *Подсказка: RLHF модель (справа) должна стараться уводить текст в позитивное русло, даже если вы начинаете отзыв со слов 'What a terrible'.*")
 
 
 
 
 
 
 
 
 
 
 
2
  import torch
3
  import transformers
4
  from trl import AutoModelForCausalLMWithValueHead
5
+ import math
6
+ import time
7
 
8
+ # ============================================================
9
+ # НАСТРОЙКИ СТРАНИЦЫ И СТИЛИ (Вау-эффект)
10
+ # ============================================================
11
+ st.set_page_config(page_title="RLHF Magic | Movie Reviews", page_icon="🍿", layout="wide")
12
+
13
+ # Делаем кастомный CSS для красоты
14
+ st.markdown("""
15
+ <style>
16
+ .big-font { font-size:22px !important; font-weight: 500; }
17
+ .stProgress .st-bo { transition: background-color 0.5s ease; }
18
+ </style>
19
+ """, unsafe_allow_html=True)
20
+
21
+ st.title("🍿 Нейросеть-Кинокритик: До и После RLHF")
22
  st.markdown("""
23
+ <div class="big-font">
24
+ Посмотрите, как работает магия обучения с подкреплением (RLHF). <br>
25
+ Слева — базовая модель GPT-2, которая пишет что вздумается. Справа — та же модель, но <b>натренированная всегда писать позитивные отзывы</b>, даже если вы начинаете текст с ужасных слов!
26
+ </div>
27
+ <br>
28
+ """, unsafe_allow_html=True)
29
 
30
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
 
32
  # ============================================================
33
+ # ЗАГРУЗКА МОДЕЛЕЙ
34
  # ============================================================
35
  @st.cache_resource
36
  def load_models():
37
+ reward_path = "reward_model_trained"
38
+ ppo_path = "ppo_model_trained"
39
+ orig_model_name = "lvwerra/gpt2-imdb"
40
+
41
+ # 1. Reward Model
42
+ reward_tokenizer = transformers.AutoTokenizer.from_pretrained(reward_path)
43
+ reward_model = transformers.AutoModelForSequenceClassification.from_pretrained(reward_path).to(DEVICE).eval()
44
+
45
+ # 2. Original GPT-2
46
+ orig_tokenizer = transformers.AutoTokenizer.from_pretrained(orig_model_name)
47
+ if orig_tokenizer.pad_token is None:
48
+ orig_tokenizer.pad_token = orig_tokenizer.eos_token
49
+ orig_model = transformers.AutoModelForCausalLM.from_pretrained(orig_model_name).to(DEVICE).eval()
50
+
51
+ # 3. RLHF Model
52
+ rlhf_model_full = AutoModelForCausalLMWithValueHead.from_pretrained(ppo_path).to(DEVICE).eval()
53
+ rlhf_model = rlhf_model_full.pretrained_model
54
+
55
+ return reward_model, reward_tokenizer, orig_model, orig_tokenizer, rlhf_model
 
 
56
 
57
+ with st.spinner("⏳ Подготовка нейросетей... (занимает около минуты при первом старте)"):
58
  reward_model, reward_tokenizer, orig_model, orig_tokenizer, rlhf_model = load_models()
 
 
 
59
 
60
  # ============================================================
61
+ # ФУНКЦИИ МАГИИ
62
  # ============================================================
63
  def compute_reward(text):
64
+ inputs = reward_tokenizer(text, truncation=True, max_length=512, padding=True, return_tensors="pt").to(DEVICE)
 
 
 
 
65
  with torch.no_grad():
66
  score = reward_model(**inputs).logits[0, 0].item()
67
  return score
68
 
69
+ # Функция перевода Reward score в проценты (Sigmoid)
70
+ def get_positivity_percent(score):
71
+ return int((1 / (1 + math.exp(-score))) * 100)
72
+
73
  def generate_text(model, tokenizer, prompt, max_new_tokens, temperature, top_p):
 
74
  inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
75
  with torch.no_grad():
76
+ outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=True,
77
+ temperature=temperature, top_p=top_p, pad_token_id=tokenizer.eos_token_id)
 
 
 
 
 
 
78
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
79
 
80
+ # Генератор для эффекта печатной машинки
81
+ def stream_text(text, delay=0.03):
82
+ for word in text.split(" "):
83
+ yield word + " "
84
+ time.sleep(delay)
85
+
86
  # ============================================================
87
+ # ИНТЕРФЕЙС И ВЗАИМОДЕЙСТВИЕ
88
  # ============================================================
89
+ st.sidebar.image("https://huggingface.co/front/assets/huggingface_logo-noborder.svg", width=50)
90
+ st.sidebar.header("🎛 Настройки генерации")
91
+ max_tokens = st.sidebar.slider("Длина продолжения (токенов)", 20, 150, 70)
92
+ temp = st.sidebar.slider("Креативность (Temperature)", 0.1, 1.5, 0.8)
93
+ st.sidebar.info("💡 **Попробуйте начать так:**\n\n- *I hate this movie because*\n- *The acting was terrible and*\n- *To be honest, the plot was*")
94
 
95
+ # Главное поле ввода
96
+ user_prompt = st.text_input("✍️ Напишите начало отзыва (на англ.) и нажмите Enter:",
97
+ value="The director tried to make a good movie, but",
98
+ max_chars=100)
99
 
100
+ if st.button("🚀 Оживить нейросети!", type="primary", use_container_width=True):
101
+
102
+ # Сначала генерируем всё за кулисами
103
+ with st.spinner("🧠 Нейросети сочиняют продолжение..."):
104
+ orig_text = generate_text(orig_model, orig_tokenizer, user_prompt, max_tokens, temp, 0.95)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  orig_reward = compute_reward(orig_text)
106
+ orig_percent = get_positivity_percent(orig_reward)
107
 
108
+ rlhf_text = generate_text(rlhf_model, orig_tokenizer, user_prompt, max_tokens, temp, 0.95)
 
109
  rlhf_reward = compute_reward(rlhf_text)
110
+ rlhf_percent = get_positivity_percent(rlhf_reward)
111
 
112
+ st.markdown("---")
113
+
114
+ # Создаем ��ве колонки
115
  col1, col2 = st.columns(2)
116
 
117
+ # КОЛОНКА 1: Оригинальная модель
118
  with col1:
119
+ with st.container(border=True):
120
+ st.subheader("🤖 До RLHF вободная GPT-2)")
121
+ st.caption("Пишет как попало (может быть негативной)")
122
+
123
+ # Уровень позитивности с цветным баром
124
+ st.progress(orig_percent / 100, text=f"Уровень позитивности: {orig_percent}%")
125
+
126
+ # Эффект печатной машинки
127
+ st.write_stream(stream_text(orig_text))
128
+
129
+ # КОЛОНКА 2: Обученная модель
130
  with col2:
131
+ with st.container(border=True):
132
+ st.subheader("✨ После RLHF (Good Boy Model)")
133
+ st.caption("Старается вырулить любой текст в позитив")
134
+
135
+ # Уровень позитивности с цветным баром
136
+ st.progress(rlhf_percent / 100, text=f"Уровень позитивности: {rlhf_percent}%")
137
+
138
+ # Спит чуть-чуть, чтобы эффект был последовательным
139
+ time.sleep(1)
140
+ st.write_stream(stream_text(rlhf_text, delay=0.04))
141
+
142
+ # Добавляем эмоций в конце
143
+ if rlhf_percent > orig_percent + 20 and rlhf_percent > 70:
144
+ st.balloons()
145
+ st.toast('🎉 RLHF модель блестяще спасла ситуацию!', icon='😍')
146
+ elif rlhf_percent < 50:
147
+ st.toast('😅 Начало было настолько суровым, что даже RLHF сдалась.', icon='💀')