# Hugging Face Spaces app (the Space status page reported "Runtime error").
import logging
import warnings

# Silence noisy warnings/logs BEFORE importing the chatty third-party libs.
warnings.filterwarnings("ignore")
logging.getLogger("streamlit").setLevel(logging.ERROR)

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
@st.cache_resource
def load_model():
    """Load the Polish GPT-2 tokenizer and model once per server process.

    Returns:
        (tokenizer, model): the ``radlab/polish-gpt2-small-v2`` tokenizer and
        causal-LM model.

    ``@st.cache_resource`` is the fix: Streamlit re-executes the whole script
    on every widget interaction, so without caching the model would be
    re-instantiated (and potentially re-downloaded) on every message.
    """
    model_name = "radlab/polish-gpt2-small-v2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return tokenizer, model
tokenizer, model = load_model()

# set_page_config must be the first Streamlit UI command in the script.
st.set_page_config(page_title="Polski Chatbot AI", page_icon="🤖")
st.title("🤖 Polski Chatbot AI")
st.caption("Model: radlab/polish-gpt2-small-v2")

# The whole conversation is kept as one plain-text prompt string that
# survives reruns via session_state.
if "history" not in st.session_state:
    st.session_state.history = ""

user_input = st.text_input("Wpisz wiadomość:", "")
if st.button("Wyślij") and user_input.strip() != "":
    # Append the user turn plus an "AI:" cue that prompts the model to answer.
    st.session_state.history += f"Użytkownik: {user_input}\nAI:"
    input_ids = tokenizer.encode(
        st.session_state.history,
        return_tensors="pt",
        truncation=True,
        max_length=1024,
    )
    # Inference only — no_grad avoids building an autograd graph.
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_new_tokens=80,  # same budget as max_length=input_len + 80
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
        )
    # FIX: decode only the newly generated tokens. The original sliced the
    # decoded string with output_text[len(history):], which breaks whenever
    # tokenize→decode does not round-trip the prompt byte-for-byte, or when
    # the prompt was truncated at 1024 tokens (slice offset then points into
    # the wrong place).
    new_tokens = output[0][input_ids.shape[1]:]
    model_reply = tokenizer.decode(new_tokens, skip_special_tokens=True)
    # Cut off where the model starts inventing the next user turn.
    model_reply = model_reply.split("Użytkownik:")[0].strip()
    st.session_state.history += f" {model_reply}\n"
# Show the full conversation transcript in a scrollable box.
st.subheader("🗨️ Historia rozmów")
transcript = st.session_state.history.strip()
st.text_area("📖", transcript, height=300)
if st.button("🧹 Wyczyść historię"):
    st.session_state.history = ""
    # FIX: rerun immediately. The transcript text_area above was already
    # rendered earlier in this pass, so without a rerun the stale history
    # stays on screen until the user's next interaction.
    st.rerun()