VietCat committed on
Commit
3c8c274
·
1 Parent(s): dfd32d8

refactor llm/embedding flow

Browse files
Files changed (1) hide show
  1. app/gemini_client.py +45 -0
app/gemini_client.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from google.generativeai.embedding import embed_content
2
+ from google.generativeai.client import configure
3
+ from google.generativeai.generative_models import GenerativeModel
4
+ from loguru import logger
5
+
6
class GeminiClient:
    """Small wrapper around the google.generativeai SDK.

    Exposes text generation, token counting and embedding creation behind
    one object that is configured once with an API key.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "gemini-1.5-flash-latest",
        embedding_model: "str | None" = None,
    ):
        """Configure the SDK and build the generative model handle.

        Args:
            api_key: Key passed to the SDK's ``configure`` call.
            model: Model name used for ``generate_text`` and ``count_tokens``.
            embedding_model: Model name used for ``create_embedding``. When
                omitted it falls back to ``model``, which is what this class
                always did before the argument existed.
                NOTE(review): the default chat model is presumably not an
                embedding-capable model -- callers that need embeddings
                should pass an embedding model name here; confirm against
                the deployed configuration.
        """
        self.api_key = api_key
        self.model = model
        # Keep the embedding model separate from the generative one so that a
        # single client can serve both flows.
        self.embedding_model = embedding_model or model
        configure(api_key=api_key)
        self._model = GenerativeModel(model)

    def generate_text(self, prompt: str, **kwargs) -> str:
        """Generate text for ``prompt`` and return it as a string.

        Args:
            prompt: Prompt forwarded to the model's ``generate_content``.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``generate_content``.

        Returns:
            ``response.text`` when the response has that attribute, otherwise
            the text of the first part of the first candidate, otherwise
            ``str(response)``.

        Raises:
            Exception: Any error raised while generating or reading the
                response is logged and re-raised unchanged.
        """
        try:
            response = self._model.generate_content(prompt, **kwargs)
            logger.info(f"[GEMINI][RAW_RESPONSE] {response}")
            # The response may be a GenerativeModelResponse: take its text
            # when available, otherwise fall back to a string form.
            if hasattr(response, 'text'):
                return response.text
            elif hasattr(response, 'candidates') and response.candidates:
                return response.candidates[0].content.parts[0].text
            return str(response)
        except Exception as e:
            logger.error(f"[GEMINI] Error: {e}")
            raise

    def count_tokens(self, prompt: str) -> int:
        """Return the model's token count for ``prompt``.

        This is best-effort: any failure is logged and reported as ``0``
        instead of being raised.
        """
        try:
            return self._model.count_tokens(prompt).total_tokens
        except Exception as e:
            logger.error(f"[GEMINI] Token count error: {e}")
            return 0

    def create_embedding(self, text: str) -> list:
        """Create an embedding for ``text`` with task type ``retrieval_document``.

        Uses ``self.embedding_model`` (which defaults to ``self.model``).

        Returns:
            The value stored under the ``'embedding'`` key of the response.

        Raises:
            Exception: Any error from the embedding call is logged and
                re-raised unchanged.
        """
        try:
            response = embed_content(
                model=self.embedding_model,
                content=text,
                task_type="retrieval_document"
            )
            logger.info(f"[GEMINI][EMBEDDING][RAW_RESPONSE] {response}")
            return response['embedding']
        except Exception as e:
            logger.error(f"[GEMINI][EMBEDDING] Error: {e}")
            raise