File size: 8,460 Bytes
7136781
d53a92a
 
 
 
7136781
d53a92a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7136781
d53a92a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7136781
d53a92a
 
 
 
 
7136781
d53a92a
 
 
 
 
 
7136781
d53a92a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7136781
d53a92a
 
 
 
 
 
 
 
 
 
 
 
 
7136781
d53a92a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7136781
d53a92a
7136781
d53a92a
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import gradio as gr
from decoderOnly import TransformerRun
from transformers import AutoTokenizer
import torch
import os

class ChatBot:
    def __init__(self, model_path="."):
        """
        Инициализация бота.
        В Space файлы модели должны находиться в корневой директории.
        """
        print(f"Загрузка модели из: {model_path}")
        
        try:
            # Загружаем токенизатор
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            print("Токенизатор загружен успешно.")
            
            # Если у токенизатора нет pad_token, устанавливаем его
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token if self.tokenizer.eos_token else "[PAD]"
                print(f"Установлен pad_token: {self.tokenizer.pad_token}")
            
            # Создаем модель с параметрами токенизатора
            self.model = TransformerRun(
                vocabSize=len(self.tokenizer), 
                maxLong=256,
                sizeVector=128,
                block=2
            )
            
            # Загружаем веса модели (в Space файл будет в корне)
            weights_path = f"{model_path}/model_weights.pth"
            if not os.path.exists(weights_path):
                # Пробуем найти веса без подпапки
                weights_path = "model_weights.pth"
                
            print(f"Загрузка весов из: {weights_path}")
            self.model.load_state_dict(
                torch.load(weights_path, map_location='cpu', weights_only=True)
            )
            
            # Настраиваем устройство
            self.device = torch.device("cpu")
            self.model.to(self.device)
            self.model.eval()
            print("Модель загружена и готова к работе!")
            
        except Exception as e:
            print(f"Ошибка при загрузке модели: {e}")
            raise

    def generate(self, prompt, max_length=100, temperature=0.5, top_k=50):
        """
        Генерация ответа на промпт пользователя.
        """
        try:
            if not prompt or prompt.strip() == "":
                return "Пожалуйста, введите сообщение."
                
            print(f"Генерация ответа для промпта: '{prompt[:50]}...'")
            
            # Токенизируем промпт
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=200)
            input_ids = inputs["input_ids"].to(self.device)
            
            # Если последовательность пустая после токенизации
            if input_ids.size(1) == 0:
                return "Не удалось обработать запрос. Попробуйте другие слова."
                
            generated_ids = input_ids.clone()

            with torch.no_grad():
                for _ in range(max_length):
                    # Прямой проход модели
                    outputs = self.model(generated_ids)
                    logits = outputs[0, -1, :] / temperature  # учитываем температуру

                    # Top-k sampling
                    if top_k > 0:
                        topk_values, topk_indices = torch.topk(logits, min(top_k, logits.size(-1)))
                        probs = torch.zeros_like(logits).scatter(0, topk_indices, torch.softmax(topk_values, dim=-1))
                    else:
                        probs = torch.softmax(logits, dim=-1)

                    # Выбираем следующий токен
                    next_token = torch.multinomial(probs, num_samples=1)
                    
                    # Добавляем к сгенерированной последовательности
                    generated_ids = torch.cat([generated_ids, next_token.unsqueeze(0)], dim=1)
                    
                    # Останавливаемся на EOS или PAD
                    stop_tokens = []
                    if self.tokenizer.eos_token_id is not None:
                        stop_tokens.append(self.tokenizer.eos_token_id)
                    if self.tokenizer.pad_token_id is not None:
                        stop_tokens.append(self.tokenizer.pad_token_id)
                        
                    if next_token.item() in stop_tokens:
                        print(f"Остановка на токене: {next_token.item()}")
                        break

            # Декодируем обратно в текст
            response = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
            
            # Убираем оригинальный промпт из ответа
            if response.startswith(prompt):
                response = response[len(prompt):].strip()
                
            print(f"Сгенерирован ответ длиной {len(response)} символов.")
            return response
            
        except Exception as e:
            print(f"Ошибка при генерации: {e}")
            return f"Произошла ошибка: {str(e)}"

def create_interface():
    """
    Создание Gradio интерфейса.
    """
    try:
        # Инициализируем бота
        # В Space модель будет находиться в корневой директории
        bot = ChatBot(model_path=".")
        print("Интерфейс запускается...")
        
        def respond(message, history):
            """
            Функция для обработки сообщений в интерфейсе чата.
            """
            # history содержит предыдущие сообщения в формате [[user1, bot1], [user2, bot2], ...]
            # Мы будем генерировать ответ только на последнее сообщение
            response = bot.generate(
                prompt=message,
                max_length=100,
                temperature=0.7,
                top_k=50
            )
            return response
        
        # Создаем интерфейс чата
        demo = gr.ChatInterface(
            fn=respond,
            title="BasicSmall ChatBot",
            description="Демонстрация модели MarkProMaster229/BasicSmall. Напишите сообщение и нажмите Submit.",
            examples=["Привет!", "Расскажи что-нибудь интересное", "Как дела?"],
            theme="soft"
        )
        
        return demo
        
    except Exception as e:
        print(f"Критическая ошибка при создании интерфейса: {e}")
        
        # Создаем простой интерфейс с сообщением об ошибке
        def error_response(message, history):
            return f"Модель не загружена. Ошибка: {str(e)}"
            
        return gr.ChatInterface(
            fn=error_response,
            title="BasicSmall ChatBot (Ошибка)",
            description="Не удалось загрузить модель. Проверьте файлы модели."
        )

# Создаем и запускаем интерфейс
if __name__ == "__main__":
    # Устанавливаем уровень логирования
    import logging
    logging.basicConfig(level=logging.INFO)
    
    # Создаем интерфейс
    demo = create_interface()
    
    # Запускаем с параметрами для Hugging Face Spaces
    demo.launch(
        server_name="0.0.0.0",  # Обязательно для Spaces
        server_port=7860,        # Стандартный порт для Spaces
        share=False              # Не создавать публичную ссылку (в Spaces это не нужно)
    )