| | import streamlit as st |
| | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
| | import torch |
| |
|
@st.cache_resource
def load_model():
    """Load the mT5-base tokenizer and sequence-to-sequence model.

    Decorated with ``st.cache_resource`` so Streamlit reuses the returned
    objects instead of reloading the checkpoint.

    Returns:
        A ``(tokenizer, model)`` tuple for the ``google/mt5-base`` checkpoint.
    """
    checkpoint = "google/mt5-base"
    # NOTE(review): padding_side="left" only matters for padded batches; right
    # padding is the usual choice for encoder inputs — confirm this is intended.
    loaded_tokenizer = AutoTokenizer.from_pretrained(checkpoint, padding_side="left", use_fast=False)
    loaded_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    return loaded_tokenizer, loaded_model
| |
|
st.title("Український Чат-бот")

# Seed per-session keys only when they are absent, so existing values
# (the chat log, a pending message) are preserved across script reruns.
for _state_key, _initial_value in (("history", []), ("user_input", "")):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value

tokenizer, model = load_model()
| |
|
def send_message(max_context_messages=6, max_input_tokens=512):
    """Generate a bot reply to the pending user message and record the exchange.

    Reads ``st.session_state.user_input``; does nothing when it is empty.
    Otherwise builds a single prompt from the most recent chat history plus
    the new message, generates a reply with the module-level ``model``, appends
    ``[user message, reply]`` to ``st.session_state.history`` and clears
    ``st.session_state.user_input``.

    Args:
        max_context_messages: How many trailing history entries (user and bot
            messages alike) to prepend as context. ``0`` or less means no history.
        max_input_tokens: Upper bound on encoder input length. When the prompt
            is longer, only its last ``max_input_tokens`` tokens are kept so the
            newest text survives. ``0`` or less disables the limit.
    """
    user_text = st.session_state.user_input
    if not user_text:
        return

    history = st.session_state.history
    # Encode ONE sequence. Passing `history + [user_text]` as a list makes the
    # tokenizer build a batch, and generate() then answers every entry; outputs[0]
    # would be the reply to history[0] instead of the current message.
    recent = history[-max_context_messages:] if max_context_messages > 0 else []
    prompt = "\n".join([*recent, user_text])

    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]
    if max_input_tokens > 0:
        # Keep the tail: right-side truncation would drop the newest message,
        # while the tail keeps it (and the trailing end-of-sequence token).
        input_ids = input_ids[:, -max_input_tokens:]
        attention_mask = attention_mask[:, -max_input_tokens:]

    with torch.no_grad():
        outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=100)
    # Batch size is 1, so row 0 is the reply to this prompt.
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    history.extend([user_text, response])
    st.session_state.user_input = ""
| |
|
def update_user_input():
    """Stage the text box value (widget key ``temp_user_input``) as ``user_input``.

    Used as the text input's ``on_change`` callback; ``send_message`` later
    consumes ``st.session_state.user_input``.
    """
    staged_text = st.session_state.temp_user_input
    st.session_state.user_input = staged_text
| |
|
st.text_input("Ви:", key="temp_user_input", on_change=update_user_input)

# Explicit send: processes whatever update_user_input staged in user_input.
send_clicked = st.button("Надіслати")
if send_clicked:
    send_message()

# Auto-send when the text box holds text not yet handled by this path.
# last_input records the value handled here; after a button send this path may
# still call send_message(), but that is a no-op because user_input was reset.
typed_text = st.session_state.get("temp_user_input")
last_handled = st.session_state.get("last_input", "")
if typed_text and typed_text != last_handled:
    st.session_state["last_input"] = typed_text
    send_message()

# History alternates user message / bot reply, starting with the user.
for position, entry in enumerate(st.session_state.history):
    if position % 2 == 0:
        st.write(f"Ви: {entry}")
    else:
        st.write(f"Бот: {entry}")