Nancy1906 commited on
Commit
06d204c
·
verified ·
1 Parent(s): 26c8de0
Files changed (1) hide show
  1. my_tools.py +64 -19
my_tools.py CHANGED
@@ -27,21 +27,57 @@ class GeminiLLM(LLM):
27
  model_name: str = Field(default="models/gemini-1.5-flash-latest")
28
  temperature: float = Field(default=0.0)
29
 
 
 
 
 
 
30
  class Config:
31
  extra = "allow"
32
 
33
  def __init__(self, **kwargs):
34
- super().__init__(**kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  key = os.getenv("GEMINI_API_KEY")
36
  if not key:
37
  raise ValueError("GEMINI_API_KEY no configurada")
38
  genai.configure(api_key=key)
39
- self._gen_cfg = genai.types.GenerationConfig(temperature=self.temperature)
 
40
  self._model = genai.GenerativeModel(
41
- model_name=self.model_name,
42
  generation_config=self._gen_cfg
43
  )
44
- # Inicializar callback_manager si no se pasa en kwargs
45
  if self.callback_manager is None:
46
  from llama_index.core.callbacks.base import CallbackManager
47
  self.callback_manager = CallbackManager([])
@@ -51,15 +87,28 @@ class GeminiLLM(LLM):
51
 
52
  @property
53
  def metadata(self):
 
 
 
 
 
 
 
 
 
54
  return LLMMetadata(
55
  context_window=1048576,
56
  num_output=8192,
57
  is_chat_model=True,
58
  is_function_calling_model=True,
59
- model_name=self.model_name,
60
  )
61
 
62
- def chat(self, messages: list[ChatMessage], **kwargs): # Añadido tipo para messages
 
 
 
 
63
  hist = []
64
  for m in messages[:-1]:
65
  role = "user" if m.role == "user" else "model"
@@ -72,10 +121,10 @@ class GeminiLLM(LLM):
72
  except Exception as e:
73
  return ChatMessage(role="assistant", content=f"Error Gemini chat: {e}")
74
 
75
- async def achat(self, messages: list[ChatMessage], **kwargs): # Añadido tipo para messages
76
  return await asyncio.to_thread(self.chat, messages, **kwargs)
77
 
78
- def stream_complete(self, prompt: str, formatted=False, **kwargs): # Añadido tipo para prompt
79
  stream = self._model.generate_content(str(prompt), stream=True)
80
  def gen():
81
  acc = ""
@@ -88,17 +137,14 @@ class GeminiLLM(LLM):
88
  yield CompletionResponse(text=acc, delta=delta)
89
  return gen()
90
 
91
- async def astream_complete(self, prompt: str, formatted=False, **kwargs): # Añadido tipo para prompt
92
- # Correctamente, esto debería devolver un generador asíncrono.
93
- # Envolver el generador síncrono es un workaround común.
94
  sync_gen = await asyncio.to_thread(self.stream_complete, prompt, formatted=formatted, **kwargs)
95
  async def async_gen_wrapper():
96
  for item in sync_gen:
97
  yield item
98
  return async_gen_wrapper()
99
 
100
-
101
- def stream_chat(self, messages: list[ChatMessage], **kwargs): # Añadido tipo para messages
102
  hist = []
103
  for m in messages[:-1]:
104
  role = "user" if m.role == "user" else "model"
@@ -117,25 +163,24 @@ class GeminiLLM(LLM):
117
  yield ChatMessage(role="assistant", content=acc, additional_kwargs={"delta": delta})
118
  return gen()
119
 
120
- # --- MÉTODO FALTANTE AÑADIDO AQUÍ ---
121
- async def astream_chat(self, messages: list[ChatMessage], **kwargs): # Añadido tipo para messages
122
- # Similar a astream_complete, envolvemos el generador síncrono
123
  sync_gen = await asyncio.to_thread(self.stream_chat, messages, **kwargs)
124
  async def async_gen_wrapper():
125
  for item in sync_gen:
126
  yield item
127
  return async_gen_wrapper()
128
- # --- FIN DEL MÉTODO AÑADIDO ---
129
 
130
- def complete(self, prompt: str, formatted=False, **kwargs): # Añadido tipo para prompt
131
  try:
132
  resp = self._model.generate_content(str(prompt))
133
  return CompletionResponse(text=resp.text)
134
  except Exception as e:
135
  return CompletionResponse(text=f"Error Gemini complete: {e}")
136
 
137
- async def acomplete(self, prompt: str, formatted=False, **kwargs): # Añadido tipo para prompt
138
  return await asyncio.to_thread(self.complete, prompt, formatted=formatted, **kwargs)
 
 
139
  # -------------------------------------------------------------------
140
  # 2) Herramientas
141
  # -------------------------------------------------------------------
 
27
  model_name: str = Field(default="models/gemini-1.5-flash-latest")
28
  temperature: float = Field(default=0.0)
29
 
30
+ # Atributos para el modelo y config de generación.
31
+ # Pydantic los ignorará si no son Fields y Config.extra = "allow" (lo cual tienes)
32
+ _model: object = None
33
+ _gen_cfg: object = None
34
+
35
  class Config:
36
  extra = "allow"
37
 
38
  def __init__(self, **kwargs):
39
+ super().__init__(**kwargs) # Pydantic procesa campos y kwargs
40
+
41
+ # --- INICIO DE LA CORRECCIÓN PARA FieldInfo ---
42
+ # Obtener el valor resuelto de model_name explícitamente
43
+ # Primero, intentar con el atributo de instancia (que Pydantic debería haber establecido)
44
+ actual_model_name = self.model_name
45
+
46
+ # Si sigue siendo un FieldInfo (o no es un string), obtener el valor default del campo
47
+ if not isinstance(actual_model_name, str):
48
+ # Acceder a la definición del campo de la clase para obtener su default
49
+ # self.__fields__ es un dict de los campos Pydantic de la clase
50
+ model_field_definition = self.__fields__.get("model_name")
51
+ if model_field_definition and hasattr(model_field_definition, 'default'):
52
+ actual_model_name = model_field_definition.default
53
+
54
+ # Como última salvaguarda, si todo falla, usar un string literal (no ideal)
55
+ if not isinstance(actual_model_name, str):
56
+ # print("ADVERTENCIA: model_name no se pudo resolver a un string, usando valor literal.")
57
+ actual_model_name = "models/gemini-1.5-flash-latest"
58
+
59
+ # Lo mismo para temperature, aunque es menos probable que sea un FieldInfo aquí
60
+ actual_temperature = self.temperature
61
+ if not isinstance(actual_temperature, (float, int)):
62
+ temp_field_definition = self.__fields__.get("temperature")
63
+ if temp_field_definition and hasattr(temp_field_definition, 'default'):
64
+ actual_temperature = temp_field_definition.default
65
+ if not isinstance(actual_temperature, (float, int)):
66
+ # print("ADVERTENCIA: temperature no se pudo resolver a un float, usando 0.0.")
67
+ actual_temperature = 0.0
68
+ # --- FIN DE LA CORRECCIÓN PARA FieldInfo ---
69
+
70
  key = os.getenv("GEMINI_API_KEY")
71
  if not key:
72
  raise ValueError("GEMINI_API_KEY no configurada")
73
  genai.configure(api_key=key)
74
+
75
+ self._gen_cfg = genai.types.GenerationConfig(temperature=actual_temperature)
76
  self._model = genai.GenerativeModel(
77
+ model_name=actual_model_name, # Usar el valor de string resuelto
78
  generation_config=self._gen_cfg
79
  )
80
+
81
  if self.callback_manager is None:
82
  from llama_index.core.callbacks.base import CallbackManager
83
  self.callback_manager = CallbackManager([])
 
87
 
88
  @property
89
  def metadata(self):
90
+ # También asegurar que model_name es un string aquí
91
+ actual_model_name_meta = self.model_name
92
+ if not isinstance(actual_model_name_meta, str):
93
+ model_field_def_meta = self.__fields__.get("model_name")
94
+ if model_field_def_meta and hasattr(model_field_def_meta, 'default'):
95
+ actual_model_name_meta = model_field_def_meta.default
96
+ if not isinstance(actual_model_name_meta, str):
97
+ actual_model_name_meta = "models/gemini-1.5-flash-latest" # Fallback
98
+
99
  return LLMMetadata(
100
  context_window=1048576,
101
  num_output=8192,
102
  is_chat_model=True,
103
  is_function_calling_model=True,
104
+ model_name=actual_model_name_meta, # Usar el valor de string resuelto
105
  )
106
 
107
+ # ... (todos los demás métodos: chat, achat, stream_complete, astream_complete, stream_chat, astream_chat, complete, acomplete)
108
+ # DEBEN ESTAR EXACTAMENTE COMO EN TU ÚLTIMA VERSIÓN FUNCIONAL DEL CÓDIGO QUE ME PEGASTE.
109
+ # Los copio de tu último fragmento para asegurar consistencia:
110
+
111
+ def chat(self, messages: list[ChatMessage], **kwargs):
112
  hist = []
113
  for m in messages[:-1]:
114
  role = "user" if m.role == "user" else "model"
 
121
  except Exception as e:
122
  return ChatMessage(role="assistant", content=f"Error Gemini chat: {e}")
123
 
124
+ async def achat(self, messages: list[ChatMessage], **kwargs):
125
  return await asyncio.to_thread(self.chat, messages, **kwargs)
126
 
127
+ def stream_complete(self, prompt: str, formatted=False, **kwargs):
128
  stream = self._model.generate_content(str(prompt), stream=True)
129
  def gen():
130
  acc = ""
 
137
  yield CompletionResponse(text=acc, delta=delta)
138
  return gen()
139
 
140
+ async def astream_complete(self, prompt: str, formatted=False, **kwargs):
 
 
141
  sync_gen = await asyncio.to_thread(self.stream_complete, prompt, formatted=formatted, **kwargs)
142
  async def async_gen_wrapper():
143
  for item in sync_gen:
144
  yield item
145
  return async_gen_wrapper()
146
 
147
+ def stream_chat(self, messages: list[ChatMessage], **kwargs):
 
148
  hist = []
149
  for m in messages[:-1]:
150
  role = "user" if m.role == "user" else "model"
 
163
  yield ChatMessage(role="assistant", content=acc, additional_kwargs={"delta": delta})
164
  return gen()
165
 
166
+ async def astream_chat(self, messages: list[ChatMessage], **kwargs):
 
 
167
  sync_gen = await asyncio.to_thread(self.stream_chat, messages, **kwargs)
168
  async def async_gen_wrapper():
169
  for item in sync_gen:
170
  yield item
171
  return async_gen_wrapper()
 
172
 
173
+ def complete(self, prompt: str, formatted=False, **kwargs):
174
  try:
175
  resp = self._model.generate_content(str(prompt))
176
  return CompletionResponse(text=resp.text)
177
  except Exception as e:
178
  return CompletionResponse(text=f"Error Gemini complete: {e}")
179
 
180
+ async def acomplete(self, prompt: str, formatted=False, **kwargs):
181
  return await asyncio.to_thread(self.complete, prompt, formatted=formatted, **kwargs)
182
+
183
+ # --- Fin de la clase GeminiLLM ---
184
  # -------------------------------------------------------------------
185
  # 2) Herramientas
186
  # -------------------------------------------------------------------