| import streamlit as st |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| import torch |
|
|
| st.set_page_config(page_title="ИТМО Магистратура Чат-бот", page_icon="🎓") |
| st.title("🎓 Чат-бот про магистратуру ИТМО") |
|
|
| MODEL_NAME = "sberbank-ai/rugpt3small_based_on_gpt2" |
|
|
|
|
| @st.cache_resource |
| def load_model(): |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) |
| model = AutoModelForCausalLM.from_pretrained(MODEL_NAME) |
| if torch.cuda.is_available(): |
| model = model.to('cuda') |
| return tokenizer, model |
|
|
| tokenizer, model = load_model() |
|
|
| if "history" not in st.session_state: |
| st.session_state.history = [] |
|
|
| SYSTEM_PROMPT = """Вы являетесь виртуальным помощником для абитуриентов магистратуры Университета ИТМО. Отвечаете на вопросы о магистерских программах ИТМО.""" |
|
|
| user_input = st.text_input("Введите ваш вопрос про магистратуру ИТМО:") |
|
|
| if user_input: |
| input_text = SYSTEM_PROMPT + "\n" + user_input |
| inputs = tokenizer(input_text, return_tensors="pt") |
|
|
| if torch.cuda.is_available(): |
| inputs = {k: v.to('cuda') for k, v in inputs.items()} |
|
|
| outputs = model.generate(**inputs, max_length=500, do_sample=True, temperature=0.7, pad_token_id=tokenizer.eos_token_id) |
| response = tokenizer.decode(outputs[0], skip_special_tokens=True) |
| reply = response[len(input_text):].strip() |
|
|
| st.session_state.history.append((user_input, reply)) |
|
|
| for i, (q, a) in enumerate(st.session_state.history): |
| st.markdown(f"**Вы:** {q}") |
| st.markdown(f"**Бот:** {a}") |