Nancy1906 committed on
Commit
380e830
·
verified ·
1 Parent(s): bb7dccf
Files changed (1) hide show
  1. my_tools.py +92 -26
my_tools.py CHANGED
@@ -47,66 +47,132 @@ ChatMessage.message = property(lambda self: self)
47
 
48
  # ---------- GEMINI LLM ----------
49
  class GeminiLLM(LLM):
50
- ...
51
- # ← aquí ya tienes __init__, metadata, chat, achat, complete, acomplete
52
- # ⬇️ pega estos métodos faltantes ⬇️
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  def stream_complete(self, prompt: str, formatted: bool = False, **kwargs):
55
- """Devuelve un generador incremental de CompletionResponse."""
56
- stream = self._model.generate_content(str(prompt), stream=True)
57
 
58
- def gen():
 
59
  acc = ""
60
- from llama_index.core.llms import CompletionResponse # import local
61
  for chunk in stream:
62
- delta = getattr(chunk, "text", "") or (
63
- chunk.parts[0].text if getattr(chunk, "parts", None) else ""
64
- )
65
  if delta:
66
  acc += delta
67
  yield CompletionResponse(text=acc, delta=delta)
68
 
69
- return gen()
70
 
71
  async def astream_complete(self, prompt: str, formatted: bool = False, **kwargs):
72
- # ejecuta la versión síncrona en un hilo
73
  sync_gen = await asyncio.to_thread(self.stream_complete, prompt, formatted=formatted, **kwargs)
74
 
75
- async def async_gen():
76
  for item in sync_gen:
77
  yield item
78
 
79
- return async_gen()
80
 
 
 
 
81
  def stream_chat(self, messages: list[ChatMessage], **kwargs):
82
- hist = [
83
  {"role": "user" if m.role == "user" else "model", "parts": [{"text": str(m.content)}]}
84
  for m in messages[:-1]
85
  ]
86
- last = str(messages[-1].content)
87
- session = self._model.start_chat(history=hist)
88
- stream = session.send_message(last, stream=True)
89
 
90
- def gen():
91
  acc = ""
92
  for chunk in stream:
93
- delta = getattr(chunk, "text", "") or (
94
- chunk.parts[0].text if getattr(chunk, "parts", None) else ""
95
- )
96
  if delta:
97
  acc += delta
98
- yield ChatMessage(role="assistant", content=acc, additional_kwargs={"delta": delta})
 
 
 
 
99
 
100
- return gen()
101
 
102
  async def astream_chat(self, messages: list[ChatMessage], **kwargs):
103
  sync_gen = await asyncio.to_thread(self.stream_chat, messages, **kwargs)
104
 
105
- async def async_gen():
106
  for item in sync_gen:
107
  yield item
108
 
109
- return async_gen()
 
110
 
111
 
112
 
 
47
 
48
  # ---------- GEMINI LLM ----------
49
  class GeminiLLM(LLM):
50
+ """Wrapper mínimo para Gemini 1.5 que satisface la interfaz de Llama-Index."""
 
 
51
 
52
+ model_name: str = Field(default="models/gemini-1.5-flash-latest")
53
+ temperature: float = Field(default=0.0)
54
+
55
+ # -- inicialización -----------------------------------------------------
56
+ def __init__(self, **kwargs):
57
+ super().__init__(**kwargs)
58
+
59
+ api_key = os.getenv("GEMINI_API_KEY")
60
+ if not api_key:
61
+ raise ValueError("GEMINI_API_KEY no configurada en variables de entorno")
62
+ genai.configure(api_key=api_key)
63
+
64
+ self._model = genai.GenerativeModel(
65
+ model_name=self.model_name,
66
+ generation_config=genai.types.GenerationConfig(
67
+ temperature=float(self.temperature)
68
+ ),
69
+ )
70
+
71
+ # callback manager defensivo
72
+ if self.callback_manager is None:
73
+ from llama_index.core.callbacks.base import CallbackManager
74
+ self.callback_manager = CallbackManager([])
75
+ if not self.callback_manager.handlers:
76
+ self.callback_manager.add_handler(LlamaDebugHandler())
77
+
78
+ # -- metadatos ----------------------------------------------------------
79
+ @property
80
+ def metadata(self) -> LLMMetadata: # type: ignore[override]
81
+ return LLMMetadata(
82
+ context_window=1_048_576,
83
+ num_output=8192,
84
+ is_chat_model=True,
85
+ is_function_calling_model=True,
86
+ model_name=self.model_name,
87
+ )
88
+
89
+ # ----------------------------------------------------------------------
90
+ # 1️⃣ CHAT SINCRONO
91
+ # ----------------------------------------------------------------------
92
+ def chat(self, messages: list[ChatMessage], **kwargs) -> ChatMessage: # type: ignore[override]
93
+ history = [
94
+ {"role": "user" if m.role == "user" else "model", "parts": [{"text": str(m.content)}]}
95
+ for m in messages[:-1]
96
+ ]
97
+ session = self._model.start_chat(history=history)
98
+ reply = session.send_message(str(messages[-1].content))
99
+ return ChatMessage(role="assistant", content=reply.text)
100
+
101
+ # 1-bis CHAT ASINCRONO
102
+ async def achat(self, messages: list[ChatMessage], **kwargs) -> ChatMessage: # type: ignore[override]
103
+ return await asyncio.to_thread(self.chat, messages, **kwargs)
104
+
105
+ # ----------------------------------------------------------------------
106
+ # 2️⃣ COMPLETE SINCRONO (prompt plano)
107
+ # ----------------------------------------------------------------------
108
+ def complete(self, prompt: str, formatted: bool = False, **kwargs) -> CompletionResponse: # type: ignore[override]
109
+ resp = self._model.generate_content(prompt)
110
+ return CompletionResponse(text=resp.text)
111
+
112
+ # 2-bis COMPLETE ASINCRONO
113
+ async def acomplete(self, prompt: str, formatted: bool = False, **kwargs) -> CompletionResponse: # type: ignore[override]
114
+ return await asyncio.to_thread(self.complete, prompt, formatted=formatted, **kwargs)
115
+
116
+ # ----------------------------------------------------------------------
117
+ # 3️⃣ STREAMING DE COMPLETIONS
118
+ # ----------------------------------------------------------------------
119
  def stream_complete(self, prompt: str, formatted: bool = False, **kwargs):
120
+ stream = self._model.generate_content(prompt, stream=True)
 
121
 
122
+ def generator():
123
+ from llama_index.core.llms import CompletionResponse
124
  acc = ""
 
125
  for chunk in stream:
126
+ delta = getattr(chunk, "text", "") or (chunk.parts[0].text if chunk.parts else "")
 
 
127
  if delta:
128
  acc += delta
129
  yield CompletionResponse(text=acc, delta=delta)
130
 
131
+ return generator()
132
 
133
  async def astream_complete(self, prompt: str, formatted: bool = False, **kwargs):
 
134
  sync_gen = await asyncio.to_thread(self.stream_complete, prompt, formatted=formatted, **kwargs)
135
 
136
+ async def agen():
137
  for item in sync_gen:
138
  yield item
139
 
140
+ return agen()
141
 
142
+ # ----------------------------------------------------------------------
143
+ # 4️⃣ STREAMING DE CHAT
144
+ # ----------------------------------------------------------------------
145
  def stream_chat(self, messages: list[ChatMessage], **kwargs):
146
+ history = [
147
  {"role": "user" if m.role == "user" else "model", "parts": [{"text": str(m.content)}]}
148
  for m in messages[:-1]
149
  ]
150
+ session = self._model.start_chat(history=history)
151
+ stream = session.send_message(str(messages[-1].content), stream=True)
 
152
 
153
+ def generator():
154
  acc = ""
155
  for chunk in stream:
156
+ delta = getattr(chunk, "text", "") or (chunk.parts[0].text if chunk.parts else "")
 
 
157
  if delta:
158
  acc += delta
159
+ yield ChatMessage(
160
+ role="assistant",
161
+ content=acc,
162
+ additional_kwargs={"delta": delta},
163
+ )
164
 
165
+ return generator()
166
 
167
  async def astream_chat(self, messages: list[ChatMessage], **kwargs):
168
  sync_gen = await asyncio.to_thread(self.stream_chat, messages, **kwargs)
169
 
170
+ async def agen():
171
  for item in sync_gen:
172
  yield item
173
 
174
+ return agen()
175
+
176
 
177
 
178