# --- Hugging Face Space file-viewer residue (status/size/commit metadata, not code) ---
# Spaces: Sleeping
# File size: 7,679 Bytes
# Commits: 40bd192 ac8ee1c 40bd192
# Standard library imports.
import json  # For handling structured payloads from the HF client response.
import time

# Third-party imports.
import streamlit as st
from huggingface_hub import InferenceClient
# BUG FIX: the "google-generativeai" package is imported as "google.generativeai";
# "import google_generativeai" raises ModuleNotFoundError.
import google.generativeai as genai
# -------------------
# API Keys Setup
# -------------------
# Credentials come from Streamlit's built-in secrets store; a missing key
# falls back to "" and is caught by the provider checks below.
huggingface_token = st.secrets.get("HUGGINGFACE_HUB_TOKEN", "")
gemini_api_key = st.secrets.get("GEN_API_KEY", "")

# -------------------
# Configuration
# -------------------
st.set_page_config(page_title="Multi-Provider Chat", layout="wide")
st.title("⚡ Multi-Provider Chat App")

# Hugging Face models offered in the sidebar. Instruction-tuned models
# (Mistral, Zephyr, Gemma) kept failing on the hosted endpoint — they require
# the 'conversational' task or brittle templating — so we offer small,
# reliable base models guaranteed to support plain 'text-generation'.
HF_RECOMMENDED_MODELS = [
    "gpt2",                   # Primary fallback: very stable base model.
    "bigscience/bloom-560m",  # Secondary base model.
]

# -------------------
# Sidebar Settings
# -------------------
st.sidebar.title("⚙️ Settings")
provider = st.sidebar.selectbox("Provider", ["Hugging Face", "Gemini"])
# -------------------
# Provider Setup
# -------------------
client = None
model = None

if provider == "Hugging Face":
    # A token is mandatory for the hosted inference endpoint.
    if not huggingface_token:
        st.error("⚠️ Please set your 'HUGGINGFACE_HUB_TOKEN' in Streamlit secrets.")
        st.stop()

    client = InferenceClient(token=huggingface_token)

    selected_models = st.sidebar.multiselect(
        "Choose HF models",
        HF_RECOMMENDED_MODELS,
        default=[HF_RECOMMENDED_MODELS[0]],
    )
    if not selected_models:
        st.warning("⚠️ Please select at least one Hugging Face model.")
        st.stop()

elif provider == "Gemini":
    if not gemini_api_key:
        st.error("⚠️ Please set your 'GEN_API_KEY' in Streamlit secrets.")
        st.stop()

    genai.configure(api_key=gemini_api_key)

    # Only models exposing the generateContent method can drive a chat session.
    available_models = [
        m.name
        for m in genai.list_models()
        if "generateContent" in m.supported_generation_methods
    ]
    if not available_models:
        st.error("⚠️ No Gemini models available for your API key.")
        st.stop()

    model = st.sidebar.selectbox("Model", available_models)

    # (Re)start the chat session whenever no session exists yet or the
    # selected model has changed since the last run.
    needs_new_session = (
        "gemini_chat" not in st.session_state
        or st.session_state.get("model") != model
    )
    if needs_new_session:
        st.session_state.model = model
        try:
            chat_backend = genai.GenerativeModel(model)
            st.session_state.gemini_chat = chat_backend.start_chat(history=[])
        except Exception as e:
            st.error(f"⚠️ Could not initialize Gemini model: {e}")
            st.stop()
# -------------------
# System Prompt
# -------------------
system_prompt = st.sidebar.text_area(
    "System Prompt",
    "You are a helpful AI assistant. Provide concise and accurate answers.",
)

# -------------------
# Chat History State
# -------------------
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# Reset conversation button
if st.sidebar.button("Reset Conversation"):
    st.session_state["messages"] = []
    # A Gemini chat also keeps its own turn history, so restart that too.
    if provider == "Gemini" and model:
        fresh_backend = genai.GenerativeModel(model)
        st.session_state.gemini_chat = fresh_backend.start_chat(history=[])
    st.rerun()  # Redraw immediately with the cleared transcript.

# -------------------
# Display Chat Messages
# -------------------
for entry in st.session_state.messages:
    with st.chat_message(entry["role"]):
        st.markdown(entry["content"])
# -------------------
# User Input
# -------------------
if user_input := st.chat_input("Type your message..."):
    # 1. Display and save the user message immediately.
    st.chat_message("user").markdown(user_input)
    st.session_state.messages.append({"role": "user", "content": user_input})

    if provider == "Hugging Face":
        # -------------------
        # Hugging Face Logic
        # -------------------
        for m in selected_models:
            # Temporary "generating" placeholder, replaced in step 4 below.
            with st.chat_message("assistant"):
                message_placeholder = st.empty()
                message_placeholder.markdown(f"**{m}** is generating...")
            try:
                bot_text = ""
                # Stop before the model starts hallucinating the next turn.
                stop_sequences = ["assistant:", "user:"]

                # Generic "role: content" chat template — most reliable for a
                # plain text-generation endpoint.
                conv = "\n".join(
                    f"{msg['role']}: {msg['content']}"
                    for msg in st.session_state.messages
                )
                prompt_text = f"{system_prompt}\n\n{conv}\nassistant:"

                # 2. Generate the completion.
                resp = client.text_generation(
                    model=m,
                    prompt=prompt_text,
                    max_new_tokens=256,
                    temperature=0.7,
                    stop_sequences=stop_sequences,
                )

                # 3. Unified parsing of the shapes the client may return:
                # plain string, dict, or list of dicts.
                if isinstance(resp, str):
                    bot_text = resp
                elif isinstance(resp, dict) and "generated_text" in resp:
                    bot_text = resp["generated_text"]
                elif isinstance(resp, list) and resp and "generated_text" in resp[0]:
                    bot_text = resp[0]["generated_text"]
                else:
                    # FIX: previously an unrecognized payload left bot_text as
                    # "" and rendered a silent blank reply; surface the raw
                    # payload instead so the failure is visible.
                    bot_text = str(resp)

                # Strip the echoed prompt when the model returns full text
                # (common behavior for text_generation).
                if bot_text.startswith(prompt_text):
                    bot_text = bot_text[len(prompt_text):].strip()
            except Exception as e:
                # Connection errors or model-deployment issues.
                bot_text = f"⚠️ Error with **{m}**: Model could not generate a response. ({type(e).__name__}: {e})"

            # 4. Display and save the final response (common to all models).
            final_response = f"**{m}**\n\n{bot_text}"
            message_placeholder.markdown(final_response)
            st.session_state.messages.append({"role": "assistant", "content": final_response})
        st.rerun()  # Redraw so history renders consistently after generation.

    elif provider == "Gemini":
        # -------------------
        # Gemini Logic
        # -------------------
        try:
            if user_input.strip():
                with st.spinner("Gemini is thinking..."):
                    resp = st.session_state.gemini_chat.send_message(user_input)
                    bot_text = resp.text
            else:
                bot_text = "⚠️ Please enter a message before sending."
        except Exception as e:
            bot_text = f"⚠️ Gemini could not respond right now. Please try again. ({e})"

        # Display and save the assistant response.
        with st.chat_message("assistant"):
            st.markdown(bot_text)
        st.session_state.messages.append({"role": "assistant", "content": bot_text})
        st.rerun()