# Multi-Provider Chat — Streamlit app (Hugging Face Inference API + Google Gemini)
import streamlit as st
from huggingface_hub import InferenceClient
# FIX: the `google-generativeai` package is imported as `google.generativeai`;
# `import google_generativeai` raises ModuleNotFoundError.
import google.generativeai as genai
import time
import json  # Kept for inspecting raw HF client responses when debugging

# -------------------
# API Keys Setup
# -------------------
# Streamlit's built-in secrets store; .get() returns "" instead of raising
# when a key is missing, so the provider branches can show a friendly error.
huggingface_token = st.secrets.get("HUGGINGFACE_HUB_TOKEN", "")
gemini_api_key = st.secrets.get("GEN_API_KEY", "")

# -------------------
# Configuration
# -------------------
st.set_page_config(page_title="Multi-Provider Chat", layout="wide")
st.title("⚡ Multi-Provider Chat App")

# Hugging Face models known to work with the plain 'text-generation' task.
# Instruction-tuned models (Mistral, Zephyr, Gemma) were failing due to server
# restrictions requiring the 'conversational' task or brittle chat templating,
# so small, reliable base models guaranteed to support 'text-generation'
# are used instead.
HF_RECOMMENDED_MODELS = [
    "gpt2",                    # Primary fallback: very stable base model
    "bigscience/bloom-560m",   # Secondary base model
]
# -------------------
# Sidebar Settings
# -------------------
st.sidebar.title("⚙️ Settings")
provider = st.sidebar.selectbox("Provider", ["Hugging Face", "Gemini"])

# -------------------
# Provider Setup
# -------------------
client = None
model = None

if provider == "Hugging Face":
    # The Inference API requires a token; bail out with guidance if absent.
    if not huggingface_token:
        st.error("⚠️ Please set your 'HUGGINGFACE_HUB_TOKEN' in Streamlit secrets.")
        st.stop()

    client = InferenceClient(token=huggingface_token)

    selected_models = st.sidebar.multiselect(
        "Choose HF models",
        HF_RECOMMENDED_MODELS,
        default=[HF_RECOMMENDED_MODELS[0]],
    )
    if not selected_models:
        st.warning("⚠️ Please select at least one Hugging Face model.")
        st.stop()

elif provider == "Gemini":
    if not gemini_api_key:
        st.error("⚠️ Please set your 'GEN_API_KEY' in Streamlit secrets.")
        st.stop()

    genai.configure(api_key=gemini_api_key)

    # Offer only models that support the generateContent method.
    available_models = [
        entry.name
        for entry in genai.list_models()
        if "generateContent" in entry.supported_generation_methods
    ]
    if not available_models:
        st.error("⚠️ No Gemini models available for your API key.")
        st.stop()

    model = st.sidebar.selectbox("Model", available_models)

    # (Re)create the chat session on first run or whenever the model changes.
    needs_new_chat = (
        "gemini_chat" not in st.session_state
        or st.session_state.get("model") != model
    )
    if needs_new_chat:
        st.session_state.model = model
        try:
            chat_backend = genai.GenerativeModel(model)
            st.session_state.gemini_chat = chat_backend.start_chat(history=[])
        except Exception as e:
            st.error(f"⚠️ Could not initialize Gemini model: {e}")
            st.stop()
# -------------------
# System Prompt
# -------------------
system_prompt = st.sidebar.text_area(
    "System Prompt",
    "You are a helpful AI assistant. Provide concise and accurate answers."
)

# -------------------
# Chat History State
# -------------------
if "messages" not in st.session_state:
    st.session_state.messages = []

# Reset conversation button
if st.sidebar.button("Reset Conversation"):
    st.session_state.messages = []
    # Also discard the Gemini server-side chat history when applicable.
    if provider == "Gemini" and model:
        st.session_state.gemini_chat = genai.GenerativeModel(model).start_chat(history=[])
    st.rerun()  # Redraw immediately so the cleared history is reflected.

# -------------------
# Display Chat Messages
# -------------------
for entry in st.session_state.messages:
    with st.chat_message(entry["role"]):
        st.markdown(entry["content"])
# -------------------
# User Input
# -------------------
if user_input := st.chat_input("Type your message..."):
    # 1. Echo and persist the user's turn right away.
    st.chat_message("user").markdown(user_input)
    st.session_state.messages.append({"role": "user", "content": user_input})

    if provider == "Hugging Face":
        # -------------------
        # Hugging Face Logic
        # -------------------
        for hf_model in selected_models:
            with st.chat_message("assistant"):
                # Temporary placeholder, replaced once generation finishes.
                placeholder = st.empty()
                placeholder.markdown(f"**{hf_model}** is generating...")

                try:
                    reply = ""
                    # Halt generation when the model starts the next turn.
                    halt_markers = ["assistant:", "user:"]

                    # Generic "role: content" transcript — the most reliable
                    # format for the plain text-generation endpoint.
                    transcript = "\n".join(
                        f"{msg['role']}: {msg['content']}"
                        for msg in st.session_state.messages
                    )
                    prompt_text = f"{system_prompt}\n\n{transcript}\nassistant:"

                    # 2. Generate the completion.
                    raw = client.text_generation(
                        model=hf_model,
                        prompt=prompt_text,
                        max_new_tokens=256,
                        temperature=0.7,
                        stop_sequences=halt_markers,
                    )

                    # 3. Unified parsing: the client may hand back a plain
                    # string, a dict, or a list of dicts.
                    if isinstance(raw, str):
                        reply = raw
                    elif isinstance(raw, dict) and "generated_text" in raw:
                        reply = raw["generated_text"]
                    elif isinstance(raw, list) and raw and "generated_text" in raw[0]:
                        reply = raw[0]["generated_text"]

                    # Some models echo the prompt back; strip it if present.
                    if reply.startswith(prompt_text):
                        reply = reply[len(prompt_text):].strip()
                except Exception as e:
                    # Covers connection errors and model-deployment failures.
                    reply = f"⚠️ Error with **{hf_model}**: Model could not generate a response. ({type(e).__name__}: {e})"

                # 4. Replace the placeholder and record the answer.
                final_response = f"**{hf_model}**\n\n{reply}"
                placeholder.markdown(final_response)
                st.session_state.messages.append({"role": "assistant", "content": final_response})
        st.rerun()  # Redraw so the stored history renders consistently.

    elif provider == "Gemini":
        # -------------------
        # Gemini Logic
        # -------------------
        try:
            if not user_input.strip():
                bot_text = "⚠️ Please enter a message before sending."
            else:
                with st.spinner("Gemini is thinking..."):
                    gemini_reply = st.session_state.gemini_chat.send_message(user_input)
                    bot_text = gemini_reply.text
        except Exception as e:
            bot_text = f"⚠️ Gemini could not respond right now. Please try again. ({e})"

        # Display and record the assistant response.
        with st.chat_message("assistant"):
            st.markdown(bot_text)
        st.session_state.messages.append({"role": "assistant", "content": bot_text})
        st.rerun()