Files changed (1) hide show
  1. models/generate_text.py +48 -0
models/generate_text.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
3
+ from safetensors import safe_open
4
+
5
+ # Функция для загрузки весов модели из файла safetensors
6
+ def load_model_weights(model, safetensors_path):
7
+ with safe_open(safetensors_path, framework="pt", device="cpu") as f:
8
+ for key in f.keys():
9
+ if key in model.state_dict():
10
+ try:
11
+ model.state_dict()[key].copy_(f.get_tensor(key))
12
+ except RuntimeError as e:
13
+ print(f"Error copying key {key}: {e}")
14
+ return model
15
+
16
+ # Загрузка токенизатора GPT-2
17
+ tokenizer = GPT2Tokenizer.from_pretrained("sberbank-ai/rugpt3small_based_on_gpt2")
18
+
19
+ # Добавление специального токена для заполнения
20
+ tokenizer.add_special_tokens({'pad_token': '[PAD]'})
21
+
22
+ # Загрузка модели GPT-2
23
+ model = GPT2LMHeadModel.from_pretrained("sberbank-ai/rugpt3small_based_on_gpt2")
24
+
25
+ # Изменение размера токенов в модели после добавления специального токена
26
+ model.resize_token_embeddings(len(tokenizer))
27
+
28
+ # Загрузка весов из safetensors
29
+ model = load_model_weights(model, "models/model_lenin_zametki.safetensors")
30
+
31
+ # Streamlit приложение
32
+ def generate_text(prompt, length, num_generations, temperature, top_k, top_p):
33
+ inputs = tokenizer.encode(prompt, return_tensors="pt")
34
+ outputs = []
35
+
36
+ for _ in range(num_generations):
37
+ output = model.generate(
38
+ inputs,
39
+ max_length=length,
40
+ temperature=temperature,
41
+ top_k=top_k,
42
+ top_p=top_p,
43
+ num_return_sequences=1
44
+ )
45
+ text = tokenizer.decode(output[0], skip_special_tokens=True)
46
+ outputs.append(text)
47
+
48
+ return outputs