Spaces:
Paused
Paused
File size: 6,882 Bytes
634b5dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
# llm_handler.py
import google.generativeai as genai
from config import GOOGLE_API_KEY, GENERATIVE_MODEL, EMBEDDING_MODEL
import streamlit as st # For displaying errors or warnings if needed
# Configure the Gemini API once at import time, if a key is available.
if GOOGLE_API_KEY:
    try:
        genai.configure(api_key=GOOGLE_API_KEY)
    except Exception as exc:
        # Best-effort: surface the failure in the Streamlit UI (if an app is
        # running) and also print for the server logs.
        st.error(f"Failed to configure Gemini API: {exc}")
        print(f"Failed to configure Gemini API: {exc}")
else:
    # The missing-key case is handled by the Streamlit UI in app.py;
    # only log a warning here.
    print("Warning: GOOGLE_API_KEY is not set. LLM features will not work.")
def get_gemini_response(prompt_text, system_instruction=None):
    """Return the text of a Gemini completion for ``prompt_text``.

    Args:
        prompt_text: The user prompt sent to the generative model.
        system_instruction: Optional system prompt configuring model behavior.

    Returns:
        The generated text, or ``None`` if the API key is missing or the
        call fails (the error is shown via Streamlit and printed to logs).
    """
    if not GOOGLE_API_KEY:
        st.error("Gemini API Key未配置,无法获取模型响应。请在Hugging Face Space Secrets中设置 GOOGLE_API_KEY。")
        return None
    try:
        model = genai.GenerativeModel(
            GENERATIVE_MODEL,
            system_instruction=system_instruction if system_instruction else None
        )
        response = model.generate_content(prompt_text)
        return response.text
    except Exception as e:
        error_message = f"与Gemini通信时出错: {e}"
        # BUG FIX: most Python exceptions have no `.message` attribute, so the
        # original `hasattr(e, 'message')` guard made this branch effectively
        # dead. Inspect str(e) instead to detect an invalid API key.
        if "API key not valid" in str(e):
            error_message = "Gemini API Key无效或权限不足。请检查Hugging Face Space Secrets中的GOOGLE_API_KEY。"
        st.error(error_message)
        print(error_message)  # For server logs
        return None
# Using genai.embed_content directly is often simpler for ChromaDB
# but if you need a callable for ChromaDB's embedding_functions parameter:
class GeminiEmbeddingFunctionForChroma:
    """Callable embedding function for ChromaDB backed by the Gemini API.

    ChromaDB's ``embedding_function`` parameter expects a callable that takes
    a list of document strings and returns a list of embedding vectors.

    NOTE(review): the original subclassed ``genai.embedding.EmbeddingFunction``,
    which does not exist in google.generativeai and would crash at import;
    this is now a plain callable class.
    """

    def __call__(self, input):
        """Embed a batch of documents.

        Args:
            input: A list of document strings (ChromaDB passes a list of texts).

        Returns:
            A list of embedding vectors, one per document; ``[]`` for empty
            input; a list of ``None`` of matching length on API failure.
        """
        # BUG FIX: the original condition was inverted — it passed `input`
        # through unchanged only when it was NOT a list of strings, and
        # coerced it when it WAS valid. Normalize correctly instead:
        if isinstance(input, list) and all(isinstance(doc, str) for doc in input):
            docs_to_embed = input
        else:
            # Fallback: coerce whatever we received into a list of strings.
            docs_to_embed = [str(item) for item in input]

        if not docs_to_embed:
            # BUG FIX: ChromaDB expects a list of embeddings here, not a dict
            # like the original's {"embedding": []}.
            return []

        try:
            # task_type="RETRIEVAL_DOCUMENT" matters for retrieval quality.
            result = genai.embed_content(
                model=EMBEDDING_MODEL,
                content=docs_to_embed,
                task_type="RETRIEVAL_DOCUMENT"
            )
            return result['embedding']  # One embedding per document.
        except Exception as e:
            error_message = f"获取文本嵌入时出错: {e}"
            st.error(error_message)
            print(error_message)
            # Keep the output length aligned with the input on failure.
            return [None] * len(docs_to_embed)
# --- Alternative simpler embedding function for ChromaDB ---
# This is often easier to integrate if ChromaDB's embedding_function
# parameter expects a function that takes a list of texts.
from chromadb import Documents, EmbeddingFunction, Embeddings
class GeminiChromaEF(EmbeddingFunction):
    """ChromaDB ``EmbeddingFunction`` that delegates to the Gemini API.

    Invalid (non-string) inputs and API failures yield zero-vector
    placeholder embeddings so the output always matches the input length.
    """

    # "models/embedding-001" produces 768-dimensional vectors — used to size
    # placeholder embeddings. TODO confirm if the model is changed in config.
    _EMBED_DIM = 768

    def __init__(self, model_name: str = EMBEDDING_MODEL, task_type: str = "RETRIEVAL_DOCUMENT"):
        """Store the embedding model name and task type for later calls."""
        self._model_name = model_name
        self._task_type = task_type
        if not GOOGLE_API_KEY:
            # Don't raise here — app.py surfaces the missing key in the UI.
            print("Warning: GOOGLE_API_KEY not set. Embedding function might fail.")

    def __call__(self, input_texts: Documents) -> Embeddings:
        """Return one embedding per entry of ``input_texts``.

        Non-string entries get a zero-vector placeholder; a wholesale API
        failure returns placeholders for every entry.
        """
        if not GOOGLE_API_KEY:
            st.error("Gemini API Key未配置,无法生成文本嵌入。")
            print("Gemini API Key not configured for embeddings.")
            return [[0.0] * self._EMBED_DIM for _ in input_texts]
        if not input_texts:
            return []
        # Filter out any None / non-string inputs, though the Documents type
        # should already be a list of strings.
        valid_texts = [text for text in input_texts if isinstance(text, str)]
        if not valid_texts:
            return [[0.0] * self._EMBED_DIM for _ in input_texts]
        try:
            result = genai.embed_content(
                model=self._model_name,
                content=valid_texts,
                task_type=self._task_type
            )
            # BUG FIX: map embeddings back POSITIONALLY instead of through a
            # text-keyed dict. The dict silently assumed texts were unique and
            # re-associated results by content; zipping an iterator against the
            # original positions handles duplicates and ordering correctly.
            emb_iter = iter(result['embedding'])
            final_embeddings = []
            for text in input_texts:
                if isinstance(text, str):
                    final_embeddings.append(next(emb_iter))
                else:
                    final_embeddings.append([0.0] * self._EMBED_DIM)
            return final_embeddings
        except Exception as e:
            error_message = f"获取文本嵌入时出错 (GeminiChromaEF): {e}"
            st.error(error_message)
            print(error_message)
            # General failure: placeholder embeddings for every input.
            return [[0.0] * self._EMBED_DIM for _ in input_texts]