Update app.py
Browse files
app.py
CHANGED
|
@@ -8,8 +8,9 @@ from torch import tensor
|
|
| 8 |
|
| 9 |
import joblib
|
| 10 |
from dataclasses import dataclass
|
| 11 |
-
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
|
| 12 |
import json
|
|
|
|
| 13 |
|
| 14 |
from preprocessing import predict_review, data_preprocessing_hard
|
| 15 |
from model_lstm import LSTMClassifier
|
|
@@ -173,17 +174,39 @@ elif selected_model == "Оценка степени токсичности по
|
|
| 173 |
|
| 174 |
|
| 175 |
|
| 176 |
-
# Генерация текста GPT
|
| 177 |
elif selected_model == "Генерация текста GPT-моделью по пользовательскому prompt":
|
| 178 |
st.title("""
|
| 179 |
-
|
| 180 |
""")
|
| 181 |
|
| 182 |
st.write("""
|
| 183 |
-
Для генерации текста используется предобученная сеть
|
|
|
|
| 184 |
""")
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
import joblib
|
| 10 |
from dataclasses import dataclass
|
| 11 |
+
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, GPT2LMHeadModel, GPT2Tokenizer
|
| 12 |
import json
|
| 13 |
+
import os
|
| 14 |
|
| 15 |
from preprocessing import predict_review, data_preprocessing_hard
|
| 16 |
from model_lstm import LSTMClassifier
|
|
|
|
| 174 |
|
| 175 |
|
| 176 |
|
| 177 |
+
# Генерация текста GPT
|
| 178 |
elif selected_model == "Генерация текста GPT-моделью по пользовательскому prompt":
|
| 179 |
st.title("""
|
| 180 |
+
Нейросетевой гороскоп Якубовский-Дьяченко
|
| 181 |
""")
|
| 182 |
|
| 183 |
st.write("""
|
| 184 |
+
Для генерации текста используется предобученная сеть GPT2. Дообучение проходило на гороскопах.
|
| 185 |
+
Общая длина текста для обучения 37 001 887 слов.
|
| 186 |
""")
|
| 187 |
+
user_text_input = st.text_area('Введите информацию о себе для формиорования гороскопа:')
|
| 188 |
+
|
| 189 |
+
# GPT2
|
| 190 |
+
model_path = "model.safetensors"
|
| 191 |
+
huggingface_token = os.getenv("HF_TOKEN")
|
| 192 |
+
model = GPT2LMHeadModel.from_pretrained(model_path, token=huggingface_token)
|
| 193 |
+
tokenizer = GPT2Tokenizer.from_pretrained(model_path, token=huggingface_token)
|
| 194 |
+
|
| 195 |
+
if st.button('Сделать гороскоп'):
|
| 196 |
+
start_time = time.time()
|
| 197 |
+
input_ids = tokenizer.encode(user_text_input, return_tensors="pt").to(device)
|
| 198 |
+
model.eval()
|
| 199 |
+
with torch.no_grad():
|
| 200 |
+
out = model.generate(input_ids,
|
| 201 |
+
do_sample=True,
|
| 202 |
+
num_beams=2,
|
| 203 |
+
temperature=1.1,
|
| 204 |
+
top_p=0.9,
|
| 205 |
+
max_length=50,
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
generated_text = list(map(tokenizer.decode, out))[0]
|
| 209 |
+
end_time = time.time()
|
| 210 |
+
prediction_time = end_time - start_time
|
| 211 |
+
|
| 212 |
+
st.write(f'Ваше предсказание: {generated_text}')
|