Spaces:
Paused
Paused
| # llm_handler.py | |
| import google.generativeai as genai | |
| from config import GOOGLE_API_KEY, GENERATIVE_MODEL, EMBEDDING_MODEL | |
| import streamlit as st # For displaying errors or warnings if needed | |
def _configure_gemini() -> None:
    """Configure the google.generativeai SDK with the key from config.

    Runs once at import time. Failures are reported both to the Streamlit
    UI (if the app is running) and to stdout for the server logs.
    """
    if not GOOGLE_API_KEY:
        # A missing key is also surfaced by the Streamlit UI in app.py.
        print("Warning: GOOGLE_API_KEY is not set. LLM features will not work.")
        return
    try:
        genai.configure(api_key=GOOGLE_API_KEY)
    except Exception as exc:
        st.error(f"Failed to configure Gemini API: {exc}")  # Show error in Streamlit if app is running
        print(f"Failed to configure Gemini API: {exc}")  # Print to console for server logs


_configure_gemini()
def get_gemini_response(prompt_text, system_instruction=None):
    """Return the text of a Gemini completion for ``prompt_text``.

    Args:
        prompt_text: The user prompt sent to the generative model.
        system_instruction: Optional system instruction; passed through as-is
            (the SDK accepts ``None``).

    Returns:
        The generated text, or ``None`` when the API key is missing or the
        request fails (the error is shown via Streamlit and printed to the
        server log).
    """
    if not GOOGLE_API_KEY:
        st.error("Gemini API Key未配置,无法获取模型响应。请在Hugging Face Space Secrets中设置 GOOGLE_API_KEY。")
        return None
    try:
        model = genai.GenerativeModel(
            GENERATIVE_MODEL,
            system_instruction=system_instruction,  # None is a valid value for the SDK
        )
        response = model.generate_content(prompt_text)
        return response.text
    except Exception as e:
        error_message = f"与Gemini通信时出错: {e}"
        # BUG FIX: Python 3 exceptions have no reliable `.message` attribute,
        # so the original `hasattr(e, 'message')` guard made this branch dead.
        # Inspect str(e) so invalid-key errors are actually detected.
        if "API key not valid" in str(e):
            error_message = "Gemini API Key无效或权限不足。请检查Hugging Face Space Secrets中的GOOGLE_API_KEY。"
        st.error(error_message)
        print(error_message)  # For server logs
        return None
| # Using genai.embed_content directly is often simpler for ChromaDB | |
| # but if you need a callable for ChromaDB's embedding_functions parameter: | |
class GeminiEmbeddingFunctionForChroma:
    """Callable embedding function for ChromaDB backed by the Gemini API.

    ChromaDB's ``embedding_function`` parameter expects a callable that takes
    a list of document strings and returns a list of embedding vectors.

    NOTE(review): the original inherited from ``genai.embedding.EmbeddingFunction``
    and annotated ``__call__`` with ``genai.embedding.EmbedContentRequest`` /
    ``EmbedContentResponse`` — none of these exist in google.generativeai, so
    the class definition raised at import. The bogus base class and
    annotations have been removed.
    """

    def __call__(self, input):
        """Embed a batch of documents; ``input`` should be a list of strings."""
        # BUG FIX: the original condition was inverted — it used `input`
        # unchanged precisely when it was NOT a list of strings, and coerced
        # when it already was one. Coerce only the non-conforming case.
        if isinstance(input, list) and all(isinstance(doc, str) for doc in input):
            docs_to_embed = input
        else:
            docs_to_embed = [str(item) for item in input]
        if not docs_to_embed:
            # BUG FIX: the original returned a dict here while every other
            # path returns a list; ChromaDB expects a list of embeddings.
            return []
        try:
            # Embed the whole batch in one call; `task_type` matters for
            # retrieval-quality embeddings.
            result = genai.embed_content(
                model=EMBEDDING_MODEL,
                content=docs_to_embed,
                task_type="RETRIEVAL_DOCUMENT",
            )
            return result['embedding']  # One vector per input document.
        except Exception as e:
            error_message = f"获取文本嵌入时出错: {e}"
            st.error(error_message)
            print(error_message)
            # One placeholder per input so callers can align results positionally.
            return [None] * len(docs_to_embed)
| # --- Alternative simpler embedding function for ChromaDB --- | |
| # This is often easier to integrate if ChromaDB's embedding_function | |
| # parameter expects a function that takes a list of texts. | |
| from chromadb import Documents, EmbeddingFunction, Embeddings | |
class GeminiChromaEF(EmbeddingFunction):
    """ChromaDB ``EmbeddingFunction`` that delegates to Gemini's embed_content.

    Invalid (non-string) inputs and API failures are replaced with zero-vector
    placeholders so the returned list always matches ``len(input_texts)``.
    """

    def __init__(self, model_name: str = EMBEDDING_MODEL,
                 task_type: str = "RETRIEVAL_DOCUMENT",
                 embedding_dim: int = 768):
        """Store embedding configuration.

        Args:
            model_name: Gemini embedding model identifier.
            task_type: Embedding task type ("RETRIEVAL_DOCUMENT" for indexing).
            embedding_dim: Length of placeholder vectors used when embedding
                fails; 768 matches "models/embedding-001". (Generalized from
                the original's hard-coded 768 — default keeps old behavior.)
        """
        self._model_name = model_name
        self._task_type = task_type
        self._embedding_dim = embedding_dim
        if not GOOGLE_API_KEY:
            print("Warning: GOOGLE_API_KEY not set. Embedding function might fail.")

    def _placeholder(self) -> list:
        """Zero vector used when an embedding cannot be produced."""
        return [0.0] * self._embedding_dim

    def __call__(self, input_texts: Documents) -> Embeddings:
        """Return one embedding per entry of ``input_texts``, in order."""
        if not GOOGLE_API_KEY:
            st.error("Gemini API Key未配置,无法生成文本嵌入。")
            print("Gemini API Key not configured for embeddings.")
            return [self._placeholder() for _ in input_texts]
        if not input_texts:
            return []
        try:
            # Documents should be a list of str, but filter defensively.
            valid_texts = [text for text in input_texts if isinstance(text, str)]
            if not valid_texts:
                return [self._placeholder() for _ in input_texts]
            result = genai.embed_content(
                model=self._model_name,
                content=valid_texts,
                task_type=self._task_type,
            )
            # Map embeddings back positionally. Consuming them in order avoids
            # the original's text-keyed dict, which silently collapsed
            # duplicate documents into a single entry.
            embeddings_in_order = iter(result['embedding'])
            final_embeddings = []
            for text in input_texts:
                if isinstance(text, str):
                    final_embeddings.append(next(embeddings_in_order))
                else:
                    # Placeholder for non-string inputs that were filtered out.
                    final_embeddings.append(self._placeholder())
            return final_embeddings
        except Exception as e:
            error_message = f"获取文本嵌入时出错 (GeminiChromaEF): {e}"
            st.error(error_message)
            print(error_message)
            # General failure: placeholder for every input so lengths match.
            return [self._placeholder() for _ in input_texts]