# LLM-Powered / llm_handler.py
# (Hugging Face Space upload metadata: forzen — "Upload 11 files" — 634b5dc verified)
# llm_handler.py
# Third-party
import google.generativeai as genai
import streamlit as st  # For displaying errors or warnings if needed
from chromadb import Documents, EmbeddingFunction, Embeddings

# Local
from config import GOOGLE_API_KEY, GENERATIVE_MODEL, EMBEDDING_MODEL
# Configure the Gemini client once at import time, if a key is available.
if GOOGLE_API_KEY:
    try:
        genai.configure(api_key=GOOGLE_API_KEY)
    except Exception as e:
        msg = f"Failed to configure Gemini API: {e}"
        st.error(msg)   # shown in the Streamlit UI if the app is running
        print(msg)      # mirrored to the console for server logs
else:
    # A missing key is surfaced in the Streamlit UI by app.py.
    print("Warning: GOOGLE_API_KEY is not set. LLM features will not work.")
def get_gemini_response(prompt_text, system_instruction=None):
    """Return the Gemini model's text response for ``prompt_text``.

    Args:
        prompt_text: The prompt sent to the generative model.
        system_instruction: Optional system-level instruction for the model.

    Returns:
        The generated text, or ``None`` when the API key is missing or the
        request fails (errors are reported via Streamlit and stdout).
    """
    if not GOOGLE_API_KEY:
        st.error("Gemini API Key未配置,无法获取模型响应。请在Hugging Face Space Secrets中设置 GOOGLE_API_KEY。")
        return None
    try:
        model = genai.GenerativeModel(
            GENERATIVE_MODEL,
            system_instruction=system_instruction if system_instruction else None
        )
        response = model.generate_content(prompt_text)
        return response.text
    except Exception as e:
        error_message = f"与Gemini通信时出错: {e}"
        # BUG FIX: google-api exceptions do not reliably expose a `.message`
        # attribute, so the original `hasattr(e, 'message')` guard made this
        # branch effectively dead. Inspect the stringified exception instead.
        if "API key not valid" in str(e):
            error_message = "Gemini API Key无效或权限不足。请检查Hugging Face Space Secrets中的GOOGLE_API_KEY。"
        st.error(error_message)
        print(error_message)  # For server logs
        return None
# Using genai.embed_content directly is often simpler for ChromaDB
# but if you need a callable for ChromaDB's embedding_functions parameter:
class GeminiEmbeddingFunctionForChroma(EmbeddingFunction):
    """Callable embedding function for ChromaDB backed by the Gemini API.

    ChromaDB invokes this with a list of document strings and expects a list
    of embedding vectors back, one per document.

    NOTE(review): the original inherited from ``genai.embedding.EmbeddingFunction``,
    which does not exist in google-generativeai (class creation would raise
    AttributeError). It now derives from ChromaDB's ``EmbeddingFunction``
    base, imported at the top of the file.
    """

    def __call__(self, input: Documents) -> Embeddings:
        # BUG FIX: the original condition was inverted — it used the raw
        # input when it was NOT a list of strings and coerced it when it
        # already was. Use valid lists of strings directly; coerce otherwise.
        if isinstance(input, list) and all(isinstance(doc, str) for doc in input):
            docs_to_embed = input
        else:
            # Fallback: coerce whatever ChromaDB handed us to strings.
            docs_to_embed = [str(item) for item in input]
        if not docs_to_embed:
            # BUG FIX: return an empty list of embeddings, not a dict —
            # every other path in this method returns a list.
            return []
        try:
            # Embed the whole batch in one call; `task_type` matters for
            # retrieval-quality embeddings.
            result = genai.embed_content(
                model=EMBEDDING_MODEL,
                content=docs_to_embed,
                task_type="RETRIEVAL_DOCUMENT"
            )
            return result['embedding']  # list of embeddings, one per document
        except Exception as e:
            error_message = f"获取文本嵌入时出错: {e}"
            st.error(error_message)
            print(error_message)
            # Keep the output length aligned with the input so the caller
            # can tell which documents failed.
            return [None] * len(docs_to_embed)
# --- Alternative simpler embedding function for ChromaDB ---
# This is often easier to integrate if ChromaDB's embedding_function
# parameter expects a function that takes a list of texts.
from chromadb import Documents, EmbeddingFunction, Embeddings
class GeminiChromaEF(EmbeddingFunction):
    """Gemini-backed embedding function matching ChromaDB's interface.

    Takes a list of document strings and returns one embedding per input.
    Non-string inputs and API failures yield zero-vector placeholders so the
    output length always matches the input length.
    """

    def __init__(self, model_name: str = EMBEDDING_MODEL,
                 task_type: str = "RETRIEVAL_DOCUMENT",
                 embedding_dim: int = 768):
        # embedding_dim: size of the placeholder vectors; 768 matches
        # "models/embedding-001". Parameterized (was hard-coded four times)
        # so other embedding models can be used — default keeps old behavior.
        self._model_name = model_name
        self._task_type = task_type
        self._embedding_dim = embedding_dim
        if not GOOGLE_API_KEY:
            print("Warning: GOOGLE_API_KEY not set. Embedding function might fail.")

    def _placeholder(self) -> list:
        # Zero vector substituted when a text cannot be embedded.
        return [0.0] * self._embedding_dim

    def __call__(self, input_texts: Documents) -> Embeddings:
        if not GOOGLE_API_KEY:
            st.error("Gemini API Key未配置,无法生成文本嵌入。")
            print("Gemini API Key not configured for embeddings.")
            return [self._placeholder() for _ in input_texts]
        if not input_texts:
            return []
        try:
            # Only strings can be embedded; anything else gets a placeholder.
            valid_texts = [text for text in input_texts if isinstance(text, str)]
            if not valid_texts:
                return [self._placeholder() for _ in input_texts]
            result = genai.embed_content(
                model=self._model_name,
                content=valid_texts,
                task_type=self._task_type
            )
            # BUG FIX: map embeddings back positionally. The original built a
            # {text: embedding} dict, which collapses duplicate texts; a
            # positional iterator pairs each valid text with its own vector.
            # (Assumes genai returns one embedding per valid text, in order —
            # the original made the same assumption.)
            embeddings_iter = iter(result['embedding'])
            final_embeddings = []
            for text in input_texts:
                if isinstance(text, str):
                    final_embeddings.append(next(embeddings_iter))
                else:
                    final_embeddings.append(self._placeholder())
            return final_embeddings
        except Exception as e:
            error_message = f"获取文本嵌入时出错 (GeminiChromaEF): {e}"
            st.error(error_message)
            print(error_message)
            # General failure: placeholder for every input to keep lengths aligned.
            return [self._placeholder() for _ in input_texts]