# Hugging Face Spaces page header ("Spaces: Sleeping") — scrape artifact,
# not part of the application code.
"""Streamlit chat app: talk to one of several Hugging Face hosted models
through LangChain's ``HuggingFaceEndpoint`` wrapper."""
import os

import streamlit as st
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint

# The Hugging Face Inference API token comes from Streamlit secrets
# (.streamlit/secrets.toml). Fail with a readable message instead of a
# raw KeyError when it is missing.
if "HF_TOKEN" not in st.secrets:
    st.error("HF_TOKEN is missing from Streamlit secrets.")
    st.stop()
os.environ["HUGGINGFACEHUB_API_TOKEN"] = st.secrets["HF_TOKEN"]
# Display name -> Hugging Face repo id for every selectable model.
_MODEL_CHOICES = [
    ("Mistral-7B-Instruct-v0.2", "mistralai/Mistral-7B-Instruct-v0.2"),
    ("Mistral-7B-Instruct-v0.3", "mistralai/Mistral-7B-Instruct-v0.3"),
    ("GPT-2", "gpt2"),
    ("BLOOM", "bigscience/bloom"),
    ("OPT", "facebook/opt-350m"),
]
models = dict(_MODEL_CHOICES)
# The chat transcript lives in session state so it survives Streamlit reruns.
st.session_state.setdefault("messages", [])

st.title("Multi-Model LLM Chat")

# Let the user pick a backend model and type a message.
selected_model = st.selectbox("Choose a model", list(models.keys()))
user_input = st.text_input("Your message:")
@st.cache_resource
def get_llm(model_name):
    """Build (and cache) a ``HuggingFaceEndpoint`` client.

    Args:
        model_name: A display-name key of the ``models`` mapping.

    Returns:
        A configured ``HuggingFaceEndpoint`` LLM instance. Cached with
        ``st.cache_resource`` so reruns reuse the same client per model.
    """
    return HuggingFaceEndpoint(
        repo_id=models[model_name],
        # Bug fix: the endpoint's generation-length parameter is
        # ``max_new_tokens``; ``max_length`` is not a recognized
        # HuggingFaceEndpoint argument.
        max_new_tokens=128,
        temperature=0.7,
    )

llm = get_llm(selected_model)
# Prompt wrapper: frame the user's text as a Human/Assistant exchange and
# nudge the model toward step-by-step reasoning.
_CHAT_TEMPLATE = (
    "Human: {human_input}\n\n"
    "Assistant: Let's think about this step-by-step:"
)
prompt = PromptTemplate(
    input_variables=["human_input"],
    template=_CHAT_TEMPLATE,
)
# Generate a reply only for *new* input. Bug fix: without this guard every
# Streamlit rerun (changing the model, clicking a button, ...) re-appends
# the same ``user_input`` and re-queries the LLM, duplicating the history
# and wasting API calls.
if user_input and user_input != st.session_state.get("last_input"):
    st.session_state["last_input"] = user_input
    # Record the user's turn in the transcript.
    st.session_state.messages.append({"role": "user", "content": user_input})

    # Query the selected model while showing a progress spinner.
    with st.spinner("Generating response..."):
        full_prompt = prompt.format(human_input=user_input)
        response = llm.invoke(full_prompt)

    # Record the assistant's turn.
    st.session_state.messages.append({"role": "assistant", "content": response})
# Replay the full transcript (user and assistant turns) in order.
for entry in st.session_state.messages:
    role, text = entry["role"], entry["content"]
    with st.chat_message(role):
        st.write(text)
# Reset the conversation. Bug fix: ``st.experimental_rerun()`` is
# deprecated and removed in modern Streamlit; ``st.rerun()`` is the
# supported replacement.
if st.button("Clear Chat"):
    st.session_state.messages = []
    st.rerun()