Nancy1906 commited on
Commit
26c8de0
·
verified ·
1 Parent(s): bffa65a
Files changed (1) hide show
  1. my_tools.py +32 -11
my_tools.py CHANGED
@@ -41,6 +41,11 @@ class GeminiLLM(LLM):
41
  model_name=self.model_name,
42
  generation_config=self._gen_cfg
43
  )
 
 
 
 
 
44
  if not self.callback_manager.handlers:
45
  self.callback_manager.add_handler(LlamaDebugHandler())
46
 
@@ -54,7 +59,7 @@ class GeminiLLM(LLM):
54
  model_name=self.model_name,
55
  )
56
 
57
- def chat(self, messages, **kwargs):
58
  hist = []
59
  for m in messages[:-1]:
60
  role = "user" if m.role == "user" else "model"
@@ -65,12 +70,12 @@ class GeminiLLM(LLM):
65
  resp = session.send_message(last)
66
  return ChatMessage(role="assistant", content=resp.text)
67
  except Exception as e:
68
- return ChatMessage(role="assistant", content=f"Error Gemini: {e}")
69
 
70
- async def achat(self, messages, **kwargs):
71
  return await asyncio.to_thread(self.chat, messages, **kwargs)
72
 
73
- def stream_complete(self, prompt, formatted=False, **kwargs):
74
  stream = self._model.generate_content(str(prompt), stream=True)
75
  def gen():
76
  acc = ""
@@ -83,10 +88,17 @@ class GeminiLLM(LLM):
83
  yield CompletionResponse(text=acc, delta=delta)
84
  return gen()
85
 
86
- async def astream_complete(self, prompt, formatted=False, **kwargs):
87
- return await asyncio.to_thread(self.stream_complete, prompt, formatted=formatted, **kwargs)
 
 
 
 
 
 
88
 
89
- def stream_chat(self, messages, **kwargs):
 
90
  hist = []
91
  for m in messages[:-1]:
92
  role = "user" if m.role == "user" else "model"
@@ -105,16 +117,25 @@ class GeminiLLM(LLM):
105
  yield ChatMessage(role="assistant", content=acc, additional_kwargs={"delta": delta})
106
  return gen()
107
 
108
- def complete(self, prompt, formatted=False, **kwargs):
 
 
 
 
 
 
 
 
 
 
109
  try:
110
  resp = self._model.generate_content(str(prompt))
111
  return CompletionResponse(text=resp.text)
112
  except Exception as e:
113
- return CompletionResponse(text=f"Error complete: {e}")
114
 
115
- async def acomplete(self, prompt, formatted=False, **kwargs):
116
  return await asyncio.to_thread(self.complete, prompt, formatted=formatted, **kwargs)
117
-
118
  # -------------------------------------------------------------------
119
  # 2) Herramientas
120
  # -------------------------------------------------------------------
 
41
  model_name=self.model_name,
42
  generation_config=self._gen_cfg
43
  )
44
+ # Inicializar callback_manager si no se pasa en kwargs
45
+ if self.callback_manager is None:
46
+ from llama_index.core.callbacks.base import CallbackManager
47
+ self.callback_manager = CallbackManager([])
48
+
49
  if not self.callback_manager.handlers:
50
  self.callback_manager.add_handler(LlamaDebugHandler())
51
 
 
59
  model_name=self.model_name,
60
  )
61
 
62
+ def chat(self, messages: list[ChatMessage], **kwargs): # Añadido tipo para messages
63
  hist = []
64
  for m in messages[:-1]:
65
  role = "user" if m.role == "user" else "model"
 
70
  resp = session.send_message(last)
71
  return ChatMessage(role="assistant", content=resp.text)
72
  except Exception as e:
73
+ return ChatMessage(role="assistant", content=f"Error Gemini chat: {e}")
74
 
75
+ async def achat(self, messages: list[ChatMessage], **kwargs): # Añadido tipo para messages
76
  return await asyncio.to_thread(self.chat, messages, **kwargs)
77
 
78
+ def stream_complete(self, prompt: str, formatted=False, **kwargs): # Añadido tipo para prompt
79
  stream = self._model.generate_content(str(prompt), stream=True)
80
  def gen():
81
  acc = ""
 
88
  yield CompletionResponse(text=acc, delta=delta)
89
  return gen()
90
 
91
+ async def astream_complete(self, prompt: str, formatted=False, **kwargs): # Añadido tipo para prompt
92
+ # Correctamente, esto debería devolver un generador asíncrono.
93
+ # Envolver el generador síncrono es un workaround común.
94
+ sync_gen = await asyncio.to_thread(self.stream_complete, prompt, formatted=formatted, **kwargs)
95
+ async def async_gen_wrapper():
96
+ for item in sync_gen:
97
+ yield item
98
+ return async_gen_wrapper()
99
 
100
+
101
+ def stream_chat(self, messages: list[ChatMessage], **kwargs): # Añadido tipo para messages
102
  hist = []
103
  for m in messages[:-1]:
104
  role = "user" if m.role == "user" else "model"
 
117
  yield ChatMessage(role="assistant", content=acc, additional_kwargs={"delta": delta})
118
  return gen()
119
 
120
+ # --- MÉTODO FALTANTE AÑADIDO AQUÍ ---
121
+ async def astream_chat(self, messages: list[ChatMessage], **kwargs): # Añadido tipo para messages
122
+ # Similar a astream_complete, envolvemos el generador síncrono
123
+ sync_gen = await asyncio.to_thread(self.stream_chat, messages, **kwargs)
124
+ async def async_gen_wrapper():
125
+ for item in sync_gen:
126
+ yield item
127
+ return async_gen_wrapper()
128
+ # --- FIN DEL MÉTODO AÑADIDO ---
129
+
130
+ def complete(self, prompt: str, formatted=False, **kwargs): # Añadido tipo para prompt
131
  try:
132
  resp = self._model.generate_content(str(prompt))
133
  return CompletionResponse(text=resp.text)
134
  except Exception as e:
135
+ return CompletionResponse(text=f"Error Gemini complete: {e}")
136
 
137
+ async def acomplete(self, prompt: str, formatted=False, **kwargs): # Añadido tipo para prompt
138
  return await asyncio.to_thread(self.complete, prompt, formatted=formatted, **kwargs)
 
139
  # -------------------------------------------------------------------
140
  # 2) Herramientas
141
  # -------------------------------------------------------------------