import re
from typing import Optional

from google.generativeai.embedding import embed_content
from google.generativeai.client import configure
from google.generativeai.generative_models import GenerativeModel
from loguru import logger

from .request_limit_manager import RequestLimitManager
from .utils import _safe_truncate
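
# RequestLimitManager is defined elsewhere; from its use below, the assumed
# interface is roughly the following (a sketch of the contract this client
# relies on, not the actual implementation):
#
#     get_current_key_model() -> tuple[str, str]
#         Return the (api_key, model_name) pair currently eligible for use.
#     log_request(key, model, success, retry_delay=None) -> None
#         Record a request outcome; on failure the manager is expected to
#         rotate to the next available key/model pair.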

class GeminiClient:
    def __init__(self):
        self.limit_manager = RequestLimitManager("gemini")
        self._cached_model = None
        self._cached_key = None
        self._cached_model_instance = None

    def _get_model_instance(self, key: str, model: str):
        """
        Cache model instance để tránh recreate mỗi lần.
        """
        if (self._cached_key == key and 
            self._cached_model == model and 
            self._cached_model_instance is not None):
            return self._cached_model_instance
        
        # Configure the SDK and create a fresh model instance
        configure(api_key=key)
        self._cached_model_instance = GenerativeModel(model)
        self._cached_key = key
        self._cached_model = model
        
        logger.info(f"[GEMINI] Created new model instance for key={key[:5]}...{key[-5:]} model={model}")
        return self._cached_model_instance

    def _clear_cache_if_needed(self, new_key: str, new_model: str):
        """
        Chỉ clear cache khi key/model thực sự thay đổi.
        """
        if self._cached_key != new_key or self._cached_model != new_model:
            logger.info(f"[GEMINI] Clearing cache due to key/model change: {self._cached_key}->{new_key}, {self._cached_model}->{new_model}")
            self._cached_model_instance = None
            self._cached_key = None
            self._cached_model = None

    def generate_text(self, prompt: str, **kwargs) -> str:
        last_error = None
        max_retries = 3
        
        for attempt in range(max_retries):
            key = model = None
            try:
                # Get the current key/model pair from the limit manager
                key, model = self.limit_manager.get_current_key_model()
                
                # Reuse the cached model instance where possible
                _model = self._get_model_instance(key, model)
                
                response = _model.generate_content(prompt, **kwargs)
                self.limit_manager.log_request(key, model, success=True)
                
                if hasattr(response, 'usage_metadata'):
                    logger.info(f"[GEMINI][USAGE] Prompt Token Count: {response.usage_metadata.prompt_token_count} - Candidate Token Count: {response.usage_metadata.candidates_token_count} - Total Token Count: {response.usage_metadata.total_token_count}")
                
                if hasattr(response, 'text'):
                    logger.info(f"[GEMINI][TEXT_RESPONSE] {_safe_truncate(response.text)}")
                    return response.text
                elif hasattr(response, 'candidates') and response.candidates:
                    logger.info(f"[GEMINI][CANDIDATES_RESPONSE] {_safe_truncate(response.candidates[0].content.parts[0].text)}")
                    return response.candidates[0].content.parts[0].text
                
                logger.info(f"[GEMINI][RAW_RESPONSE] {response}")
                return str(response)
                
            except Exception as e:
                msg = str(e)
                if "429" in msg or "rate limit" in msg.lower():
                    retry_delay = 60
                    m = re.search(r'retry_delay.*?seconds: (\d+)', msg)
                    if m:
                        retry_delay = int(m.group(1))
                    
                    # Log the failure against the key/model actually in use;
                    # the manager is expected to rotate to the next pair
                    if key and model:
                        self.limit_manager.log_request(key, model, success=False, retry_delay=retry_delay)
                    
                    # The cache is left intact on purpose: _get_model_instance
                    # only rebuilds the instance when the key/model actually
                    # changes, so retries avoid needless recreation
                    
                    logger.warning(f"[GEMINI] Rate limit hit, will retry with new key/model (attempt {attempt + 1}/{max_retries})")
                    last_error = e
                    continue
                else:
                    # Non-rate-limit error: do not retry
                    logger.error(f"[GEMINI] Error generating text: {e}")
                    last_error = e
                    break
        
        raise last_error or RuntimeError("No available Gemini API key/model")

    def count_tokens(self, prompt: str) -> int:
        try:
            key, model = self.limit_manager.get_current_key_model()
            _model = self._get_model_instance(key, model)
            return _model.count_tokens(prompt).total_tokens
        except Exception as e:
            logger.error(f"[GEMINI] Error counting tokens: {e}")
            return 0

    def create_embedding(self, text: str, model: Optional[str] = None, task_type: str = "retrieval_query") -> list:
        last_error = None
        max_retries = 3
        
        for attempt in range(max_retries):
            key = use_model = None
            try:
                key, default_model = self.limit_manager.get_current_key_model()

                # Prefer the model passed in as a parameter; fall back to the
                # manager's default_model only when none is given
                use_model = model if model and model.strip() else default_model
                
                if not use_model:
                    raise ValueError("No model specified for embedding")
                
                logger.info(f"[GEMINI][EMBEDDING] Using model={use_model} (requested={model}, default={default_model}), task_type={task_type}")
                
                configure(api_key=key)
                response = embed_content(
                    model=use_model,
                    content=text,
                    task_type=task_type
                )
                
                self.limit_manager.log_request(key, use_model, success=True)
                logger.info(f"[GEMINI][EMBEDDING][RAW_RESPONSE] {response['embedding'][:10]} ..... {response['embedding'][-10:]}")
                return response['embedding']
                
            except Exception as e:
                msg = str(e)
                if "429" in msg or "rate limit" in msg.lower():
                    retry_delay = 60
                    m_retry = re.search(r'retry_delay.*?seconds: (\d+)', msg)
                    if m_retry:
                        retry_delay = int(m_retry.group(1))
                    
                    # Log the failure and let the manager rotate to a fresh key/model
                    if key and use_model:
                        self.limit_manager.log_request(key, use_model, success=False, retry_delay=retry_delay)
                    
                    logger.warning(f"[GEMINI] Rate limit hit in embedding, will retry with new key/model (attempt {attempt + 1}/{max_retries})")
                    last_error = e
                    continue
                else:
                    logger.error(f"[GEMINI] Error creating embedding: {e}")
                    last_error = e
                    break
        
        raise last_error or RuntimeError("No available Gemini API key/model")
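

# --- Usage sketch (illustrative only; assumes the "gemini" config in
# --- RequestLimitManager holds at least one valid API key/model pair) ---
if __name__ == "__main__":
    client = GeminiClient()

    prompt = "Summarize the trade-offs of caching API client instances."

    # Cheap pre-flight check on prompt size
    logger.info(f"[DEMO] Prompt token count: {client.count_tokens(prompt)}")

    # generate_text handles rate-limit retries and key rotation internally
    print(client.generate_text(prompt))

    # The embedding model name below is an assumption; the fallback default
    # from RequestLimitManager may be a text model, so pass one explicitly.
    embedding = client.create_embedding(
        prompt, model="models/embedding-001", task_type="retrieval_document"
    )
    print(f"Embedding dimension: {len(embedding)}")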