File size: 3,685 Bytes
eaa1a45
98d4acd
 
135dce9
dc55ab4
f43cff4
 
dc55ab4
f43cff4
a93224f
 
f43cff4
 
a93224f
f43cff4
a93224f
 
 
f43cff4
dc55ab4
a93224f
135dce9
a93224f
 
98d4acd
a93224f
f43cff4
98d4acd
 
f43cff4
98d4acd
 
 
a93224f
f43cff4
a93224f
 
 
f43cff4
 
 
a93224f
98d4acd
135dce9
f43cff4
a93224f
135dce9
f43cff4
98d4acd
 
 
f43cff4
 
 
 
 
 
 
 
98d4acd
135dce9
 
 
f43cff4
a93224f
135dce9
 
 
 
 
f43cff4
 
 
98d4acd
 
f43cff4
98d4acd
 
 
 
 
 
f43cff4
98d4acd
 
 
 
a93224f
98d4acd
a93224f
98d4acd
 
 
 
 
 
f43cff4
98d4acd
f43cff4
98d4acd
f43cff4
98d4acd
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import torch

# Настройка интерфейса
st.set_page_config(page_title="HiperAI Ultra Pro", page_icon="🏎️")

# Исправленный CSS (без ошибок в аргументах)
st.markdown("""
    <style>
    .stChatMessage { background-color: #1e2129 !important; border-radius: 10px; padding: 10px; margin-bottom: 5px; }
    .stChatInput { border-radius: 20px; }
    </style>
    """, unsafe_allow_html=True)

st.title("🏎️ HiperAI Ultra Speed")

# Загрузка модели и токенайзера
@st.cache_resource
def load_optimized_model():
    model_id = "Qwen/Qwen2.5-1.5B-Instruct"
    
    # Загружаем токенайзер (нужен sentencepiece)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    
    # Загружаем модель с оптимизацией под CPU
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float32, 
        low_cpu_mem_usage=True,
        device_map="cpu"
    )

    # Включаем Optimum BetterTransformer
    try:
        from optimum.bettertransformer import BetterTransformer
        model = BetterTransformer.transform(model)
        st.sidebar.success("🚀 Optimum Speedup: ON")
    except Exception:
        st.sidebar.info("Optimum: Normal Mode")
        
    return tokenizer, model

with st.spinner("Прогрев нейросети..."):
    tokenizer, model = load_optimized_model()

# Инициализация чата
if "messages" not in st.session_state:
    st.session_state.messages = []

# Боковая панель
with st.sidebar:
    st.title("⚙️ Настройки")
    if st.button("🗑️ Очистить историю"):
        st.session_state.messages = []
        st.rerun()

# Отображение последних сообщений
for message in st.session_state.messages[-10:]:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Логика ввода
if prompt := st.chat_input("Спроси HiperAI..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        # Подготовка контекста (системный промпт + последние 5 фраз)
        history = [{"role": "system", "content": "Ты HiperAI, отвечаешь быстро и на русском."}]
        history += st.session_state.messages[-5:]
        
        inputs = tokenizer.apply_chat_template(
            history, 
            add_generation_prompt=True, 
            return_tensors="pt"
        ).to(model.device)

        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        
        # Генерация с использованием режима инференса для скорости
        generation_kwargs = dict(
            input_ids=inputs,
            streamer=streamer,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            use_cache=True
        )

        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        def stream_output():
            full_response = ""
            for new_text in streamer:
                full_response += new_text
                yield new_text
            st.session_state.messages.append({"role": "assistant", "content": full_response})

        st.write_stream(stream_output)