File size: 3,900 Bytes
a79ad35
bf233d5
a79ad35
 
 
 
 
 
 
9a022db
 
a79ad35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95


import torch
import time
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

# Конфигурация моделей
MODEL_CONFIGS = {
    "GigaChat-like": "ai-forever/rugpt2large",
    "ChatGPT-like": "ai-forever/rugpt3large_based_on_gpt2",
    "DeepSeek-like": "ai-forever/rugpt3small_based_on_gpt2"
}

# Устройство
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Загрузка моделей
models = {}
for label, name in MODEL_CONFIGS.items():
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForCausalLM.from_pretrained(name)
    model.to(device)
    model.eval()
    models[label] = (tokenizer, model)

# Загрузка датасета (не используется напрямую, но может быть полезен)
dataset = load_dataset("ZhenDOS/alpha_bank_data", split="train")

# CoT-промпты
def cot_prompt_1(text):
    return f"Клиент задал вопрос: {text}\nПодумай шаг за шагом и объясни, как бы ты ответил на это обращение от лица банка."

def cot_prompt_2(text):
    return f"Вопрос клиента: {text}\nРазложи на части, что именно спрашивает клиент, и предложи логичный ответ с пояснениями."

# Генерация
def generate_all_responses(question):
    results = {}
    for model_name, (tokenizer, model) in models.items():
        results[model_name] = {}
        for i, prompt_func in enumerate([cot_prompt_1, cot_prompt_2], start=1):
            prompt = prompt_func(question)
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            
            start_time = time.time()
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=200,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    eos_token_id=tokenizer.eos_token_id
                )
            end_time = time.time()
            
            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
            response = response.replace(prompt, "").strip()
            duration = round(end_time - start_time, 2)
            
            results[model_name][f"CoT Промпт {i}"] = {
                "response": response,
                "time": f"{duration} сек."
            }
    return results

# Отображение
def display_responses(question):
    all_responses = generate_all_responses(question)
    output = ""
    for model_name, prompts in all_responses.items():
        output += f"\n### Модель: {model_name}\n"
        for prompt_label, content in prompts.items():
            output += f"\n**{prompt_label}** ({content['time']}):\n{content['response']}\n"
    return output.strip()

# Интерфейс
demo = gr.Interface(
    fn=display_responses,
    inputs=gr.Textbox(lines=4, label="Введите клиентский вопрос"),
    outputs=gr.Markdown(label="Ответы от разных моделей"),
    title="Alpha Bank Assistant — сравнение моделей",
    description="Сравнение CoT-ответов от GigaChat, ChatGPT и DeepSeek-подобных моделей на обращение клиента.",
    examples=[
        "Как восстановить доступ в мобильный банк?",
        "Почему с меня списали комиссию за обслуживание карты?",
        "Какие условия по потребительскому кредиту?",
    ]
)

if __name__ == "__main__":
    demo.launch()