# Source: Hugging Face Space "app.py" by charesz — commit ac8ee1c (verified).
# (The three lines above were web-page residue from the file viewer, not code;
# kept here as a comment so the file parses as Python.)
# -------------------
# Imports
# -------------------
# Standard library
import json  # structured handling of HF client responses
import time

# Third-party
import streamlit as st
from huggingface_hub import InferenceClient
# FIX: the pip package "google-generativeai" is imported as
# "google.generativeai" — the name "google_generativeai" is not an importable
# module and raised ModuleNotFoundError at startup.
import google.generativeai as genai

# -------------------
# API Keys Setup
# -------------------
# Use Streamlit's built-in secrets handling. Missing keys fall back to ""
# so the provider branches below can show a friendly error instead of raising.
huggingface_token = st.secrets.get("HUGGINGFACE_HUB_TOKEN", "")
gemini_api_key = st.secrets.get("GEN_API_KEY", "")
# -------------------
# Configuration
# -------------------
# NOTE: st.set_page_config must be the first Streamlit command the script runs.
st.set_page_config(page_title="Multi-Provider Chat", layout="wide")
st.title("⚡ Multi-Provider Chat App")

# List of recommended Hugging Face models that work well for chat via InferenceClient.
# All instruction-tuned models (Mistral, Zephyr, Gemma) are failing due to server
# restrictions requiring the 'conversational' task or brittle templating.
# Switching to small, reliable base models guaranteed to support 'text-generation'.
HF_RECOMMENDED_MODELS = [
    "gpt2",                    # New primary fallback: Very stable base model
    "bigscience/bloom-560m",   # Kept as secondary base model
]

# -------------------
# Sidebar Settings
# -------------------
st.sidebar.title("⚙️ Settings")
# Top-level provider switch — every branch below keys off this value.
provider = st.sidebar.selectbox("Provider", ["Hugging Face", "Gemini"])
# -------------------
# Provider Setup
# -------------------
client = None  # huggingface_hub.InferenceClient when provider == "Hugging Face"
model = None   # selected Gemini model name when provider == "Gemini"

if provider == "Hugging Face":
    # Fail fast with a visible message when the token secret is absent;
    # st.stop() halts the script for this rerun.
    if not huggingface_token:
        st.error("⚠️ Please set your 'HUGGINGFACE_HUB_TOKEN' in Streamlit secrets.")
        st.stop()
    # Initialize the client
    client = InferenceClient(token=huggingface_token)
    # The user may fan one prompt out to several models at once.
    selected_models = st.sidebar.multiselect(
        "Choose HF models",
        HF_RECOMMENDED_MODELS,
        default=[HF_RECOMMENDED_MODELS[0]]
    )
    if not selected_models:
        st.warning("⚠️ Please select at least one Hugging Face model.")
        st.stop()
elif provider == "Gemini":
    if not gemini_api_key:
        st.error("⚠️ Please set your 'GEN_API_KEY' in Streamlit secrets.")
        st.stop()
    genai.configure(api_key=gemini_api_key)
    # Fetch available models that support the generateContent method
    available_models = [
        m.name for m in genai.list_models() if "generateContent" in m.supported_generation_methods
    ]
    if not available_models:
        st.error("⚠️ No Gemini models available for your API key.")
        st.stop()
    model = st.sidebar.selectbox("Model", available_models)
    # (Re)create the chat session whenever the selected model changes, so the
    # stored history always belongs to the model actually in use.
    if "gemini_chat" not in st.session_state or st.session_state.get("model") != model:
        st.session_state.model = model
        try:
            gemini_model = genai.GenerativeModel(model)
            st.session_state.gemini_chat = gemini_model.start_chat(history=[])
        except Exception as e:
            st.error(f"⚠️ Could not initialize Gemini model: {e}")
            st.stop()
# -------------------
# System Prompt
# -------------------
# Editable instruction text; it is prepended to the Hugging Face prompt below.
system_prompt = st.sidebar.text_area(
    "System Prompt",
    "You are a helpful AI assistant. Provide concise and accurate answers."
)

# -------------------
# Chat History State
# -------------------
# The transcript must survive Streamlit's top-to-bottom reruns, so it lives
# in session_state rather than a plain variable.
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# Reset wipes the local transcript; for Gemini we also rebuild the chat
# session so its server-side history is discarded as well.
if st.sidebar.button("Reset Conversation"):
    st.session_state["messages"] = []
    if provider == "Gemini" and model:
        fresh_model = genai.GenerativeModel(model)
        st.session_state["gemini_chat"] = fresh_model.start_chat(history=[])
    st.rerun()  # redraw immediately with an empty transcript

# -------------------
# Display Chat Messages
# -------------------
# Replay the stored transcript on every rerun.
for entry in st.session_state["messages"]:
    with st.chat_message(entry["role"]):
        st.markdown(entry["content"])
# -------------------
# User Input
# -------------------
# st.chat_input returns None until the user submits, so the walrus skips this
# whole handler on ordinary reruns.
if user_input := st.chat_input("Type your message..."):
    # 1. Display and save user message immediately
    st.chat_message("user").markdown(user_input)
    st.session_state.messages.append({"role": "user", "content": user_input})
    # -------------------
    # Hugging Face Logic
    # -------------------
    if provider == "Hugging Face":
        # One assistant bubble per selected model.
        for m in selected_models:
            # Display a temporary "generating" message
            with st.chat_message("assistant"):
                message_placeholder = st.empty()
                message_placeholder.markdown(f"**{m}** is generating...")
                try:
                    bot_text = ""
                    # Use simple stop sequences for chat formatting, including "assistant:" itself
                    stop_sequences = ["assistant:", "user:"]
                    prompt_text = ""
                    # --- Generic Chat Template (Most reliable for text-generation endpoint) ---
                    # This uses the simple "role: content" format which is often robust.
                    conv = "\n".join([f"{msg['role']}: {msg['content']}" for msg in st.session_state.messages])
                    prompt_text = f"{system_prompt}\n\n{conv}\nassistant:"
                    # 2. Generate response using text_generation
                    resp = client.text_generation(
                        model=m,
                        prompt=prompt_text,
                        max_new_tokens=256,
                        temperature=0.7,
                        stop_sequences=stop_sequences
                    )
                    # 3. Unified parsing — depending on client version/endpoint the
                    # response may be a plain string, a dict, or a list of dicts.
                    if isinstance(resp, str):
                        bot_text = resp
                    elif isinstance(resp, dict) and "generated_text" in resp:
                        bot_text = resp["generated_text"]
                    elif isinstance(resp, list) and resp and "generated_text" in resp[0]:
                        bot_text = resp[0]["generated_text"]
                    # Clean up prompt from response if model echoes it (common behavior for text_generation)
                    if bot_text.startswith(prompt_text):
                        bot_text = bot_text[len(prompt_text):].strip()
                except Exception as e:
                    # Catching connection errors or specific API deployment issues;
                    # the error is surfaced as the bot's reply rather than crashing the app.
                    bot_text = f"⚠️ Error with **{m}**: Model could not generate a response. ({type(e).__name__}: {e})"
                # 4. Display and save final response (common logic for all models)
                final_response = f"**{m}**\n\n{bot_text}"
                # Update the temporary placeholder with the final response
                message_placeholder.markdown(final_response)
                # Save the final response to chat history
                st.session_state.messages.append({"role": "assistant", "content": final_response})
        st.rerun()  # Rerun to update the display properly after generation
    # -------------------
    # Gemini Logic
    # -------------------
    elif provider == "Gemini":
        # NOTE(review): system_prompt is never applied on this branch —
        # presumably it should be passed as system_instruction when building
        # the GenerativeModel; confirm intended behavior.
        try:
            if user_input.strip():
                with st.spinner("Gemini is thinking..."):
                    resp = st.session_state.gemini_chat.send_message(user_input)
                    bot_text = resp.text
            else:
                bot_text = "⚠️ Please enter a message before sending."
        except Exception as e:
            bot_text = f"⚠️ Gemini could not respond right now. Please try again. ({e})"
        # Display and save assistant response
        with st.chat_message("assistant"):
            st.markdown(bot_text)
        st.session_state.messages.append({"role": "assistant", "content": bot_text})
        st.rerun()