File size: 6,882 Bytes
634b5dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# llm_handler.py
import google.generativeai as genai
from config import GOOGLE_API_KEY, GENERATIVE_MODEL, EMBEDDING_MODEL
import streamlit as st # For displaying errors or warnings if needed

# One-time Gemini client setup, executed when this module is imported.
if not GOOGLE_API_KEY:
    # No key available: only warn on the console here. The original note says
    # the Streamlit UI in app.py deals with the missing key.
    print("Warning: GOOGLE_API_KEY is not set. LLM features will not work.")
else:
    try:
        genai.configure(api_key=GOOGLE_API_KEY)
    except Exception as e:
        # Report the failure both in the Streamlit page and on stdout so it
        # also shows up in the server logs.
        st.error(f"Failed to configure Gemini API: {e}")
        print(f"Failed to configure Gemini API: {e}")


def get_gemini_response(prompt_text, system_instruction=None):
    """Send a prompt to the configured Gemini model and return its text reply.

    Args:
        prompt_text: Prompt content passed to ``generate_content``.
        system_instruction: Optional system instruction for the model. Any
            falsy value (``None``, ``""``) is forwarded as ``None``.

    Returns:
        The ``text`` attribute of the model response, or ``None`` when the
        API key is missing or the request fails. Failures are reported through
        ``st.error`` and echoed to stdout; no exception is propagated.
    """
    if not GOOGLE_API_KEY:
        st.error("Gemini API Key未配置,无法获取模型响应。请在Hugging Face Space Secrets中设置 GOOGLE_API_KEY。")
        return None
    try:
        model = genai.GenerativeModel(
            GENERATIVE_MODEL,
            system_instruction=system_instruction or None,
        )
        response = model.generate_content(prompt_text)
        return response.text
    except Exception as e:
        # Boundary handler: the UI expects None instead of a raised exception.
        # Look at both the optional `message` attribute and str(e): not every
        # exception type carries `message`, and formatting it into a string
        # keeps the substring test safe even when the attribute is not a str.
        detail = f"{getattr(e, 'message', '')} {e}"
        if "API key not valid" in detail:
            error_message = "Gemini API Key无效或权限不足。请检查Hugging Face Space Secrets中的GOOGLE_API_KEY。"
        else:
            error_message = f"与Gemini通信时出错: {e}"
        st.error(error_message)
        print(error_message)  # For server logs
        return None

# Calling genai.embed_content directly is often the simplest option with ChromaDB,
# but if you need a callable to pass as ChromaDB's embedding_function parameter:
class GeminiEmbeddingFunctionForChroma(genai.embedding.EmbeddingFunction):
    """Callable that embeds a batch of documents with ``genai.embed_content``.

    NOTE(review): the base class and the annotation types below are looked up
    on ``genai.embedding`` when this module is imported -- confirm that the
    installed google-generativeai release actually exposes them.
    """

    def __call__(self, input: genai.embedding.EmbedContentRequest) -> genai.embedding.EmbedContentResponse:
        """Return one embedding per document in ``input``.

        Args:
            input: The documents to embed. A single string is treated as a
                one-document batch; items that are not strings are converted
                with ``str()``.

        Returns:
            The ``'embedding'`` entry of the ``genai.embed_content`` result,
            ``[]`` for empty input, or a list with one ``None`` per document
            when the embedding call raises.
        """
        # Normalise to a list of strings. A bare string must be wrapped:
        # otherwise it is embedded as a single flat vector and the error path
        # below would count its characters instead of documents.
        if input is None:
            docs_to_embed = []
        elif isinstance(input, str):
            docs_to_embed = [input]
        else:
            docs_to_embed = [doc if isinstance(doc, str) else str(doc) for doc in input]

        if not docs_to_embed:
            # Same shape as every other return path: a list of embeddings.
            return []

        try:
            # Embed the whole batch in one call, using the document-side
            # retrieval task type.
            result = genai.embed_content(
                model=EMBEDDING_MODEL,
                content=docs_to_embed,
                task_type="RETRIEVAL_DOCUMENT"
            )
            return result['embedding']
        except Exception as e:
            error_message = f"获取文本嵌入时出错: {e}"
            st.error(error_message)
            print(error_message)
            # Best effort: keep the output aligned with the input length.
            return [None] * len(docs_to_embed)

# --- Alternative simpler embedding function for ChromaDB ---
# This is often easier to integrate if ChromaDB's embedding_function
# parameter expects a function that takes a list of texts.
from chromadb import Documents, EmbeddingFunction, Embeddings

class GeminiChromaEF(EmbeddingFunction):
    """ChromaDB embedding function backed by ``genai.embed_content``.

    Inputs that cannot be embedded (non-string items, missing API key, or a
    failed API call) receive an all-zero placeholder vector so the returned
    list always has one entry per input item.

    NOTE(review): some chromadb releases validate that ``__call__`` names its
    parameter ``input``; confirm against the pinned chromadb version.
    """

    def __init__(
        self,
        model_name: str = EMBEDDING_MODEL,
        task_type: str = "RETRIEVAL_DOCUMENT",
        dimension: int = 768,
    ):
        """Store the embedding settings.

        Args:
            model_name: Model identifier forwarded to ``genai.embed_content``.
            task_type: Task type forwarded to ``genai.embed_content``.
            dimension: Length of the zero placeholder vectors. The default of
                768 is the size previously hard-coded here; it must match the
                output size of ``model_name``.
        """
        self._model_name = model_name
        self._task_type = task_type
        self._dimension = dimension
        if not GOOGLE_API_KEY:
            print("Warning: GOOGLE_API_KEY not set. Embedding function might fail.")

    def _placeholder(self) -> list:
        """Return a fresh all-zero vector of the configured dimension."""
        # getattr keeps subclasses that override __init__ without setting
        # _dimension working with the historical size.
        return [0.0] * getattr(self, "_dimension", 768)

    def __call__(self, input_texts: Documents) -> Embeddings:
        """Embed ``input_texts`` and return one vector per input item.

        Args:
            input_texts: Documents to embed; items that are not strings get a
                placeholder vector instead of an embedding.

        Returns:
            A list of embeddings aligned with ``input_texts``. It is ``[]``
            for empty input, and all placeholders when the API key is missing
            or the embedding call raises.
        """
        if not GOOGLE_API_KEY:
            st.error("Gemini API Key未配置,无法生成文本嵌入。")
            print("Gemini API Key not configured for embeddings.")
            return [self._placeholder() for _ in input_texts]

        if not input_texts:
            return []
        try:
            # Only strings are sent to the API.
            valid_texts = [text for text in input_texts if isinstance(text, str)]
            if not valid_texts:
                return [self._placeholder() for _ in input_texts]

            result = genai.embed_content(
                model=self._model_name,
                content=valid_texts,
                task_type=self._task_type
            )
            # Assumes the API returns one embedding per valid text, in order.
            # Keying by text lets duplicates share a vector and lets skipped
            # inputs fall back to a placeholder.
            embeddings_by_text = dict(zip(valid_texts, result['embedding']))
            return [
                embeddings_by_text[text]
                if isinstance(text, str) and text in embeddings_by_text
                else self._placeholder()
                for text in input_texts
            ]
        except Exception as e:
            error_message = f"获取文本嵌入时出错 (GeminiChromaEF): {e}"
            st.error(error_message)
            print(error_message)
            # Best effort: keep the output aligned with the input length.
            return [self._placeholder() for _ in input_texts]