Nancy1906 commited on
Commit
bb7dccf
·
verified ·
1 Parent(s): d059535
Files changed (1) hide show
  1. my_tools.py +56 -59
my_tools.py CHANGED
@@ -47,70 +47,67 @@ ChatMessage.message = property(lambda self: self)
47
 
48
  # ---------- GEMINI LLM ----------
49
  class GeminiLLM(LLM):
50
- model_name: str = Field(default="models/gemini-1.5-flash-latest")
51
- temperature: float = Field(default=0.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- _model = None
 
 
 
 
 
 
 
 
54
 
55
- class Config:
56
- extra = "allow"
57
 
58
- def __init__(self, **kwargs):
59
- super().__init__(**kwargs)
60
- api_key = os.getenv("GEMINI_API_KEY")
61
- if not api_key:
62
- raise ValueError("GEMINI_API_KEY not set in environment")
63
- genai.configure(api_key=api_key)
64
 
65
- self._model = genai.GenerativeModel(
66
- model_name=self.model_name, generation_config=genai.types.GenerationConfig(temperature=self.temperature)
67
- )
68
- if self.callback_manager is None:
69
- from llama_index.core.callbacks.base import CallbackManager
70
-
71
- self.callback_manager = CallbackManager([])
72
- if not self.callback_manager.handlers:
73
- self.callback_manager.add_handler(LlamaDebugHandler())
74
-
75
- # ----- metadata -----
76
- @property
77
- def metadata(self):
78
- return LLMMetadata(
79
- context_window=1_048_576,
80
- num_output=8192,
81
- is_chat_model=True,
82
- is_function_calling_model=True,
83
- model_name=self.model_name,
84
- )
85
 
86
- # ----- sync chat -----
87
- def chat(self, messages: list[ChatMessage], **kwargs) -> ChatMessage:
88
- history = [
89
- {"role": ("user" if m.role == "user" else "model"), "parts": [{"text": str(m.content)}]}
90
- for m in messages[:-1]
91
- ]
92
- last_user_msg = str(messages[-1].content)
93
- session = self._model.start_chat(history=history)
94
- try:
95
- response = session.send_message(last_user_msg)
96
- return ChatMessage(role="assistant", content=response.text)
97
- except Exception as exc:
98
- return ChatMessage(role="assistant", content=f"Error Gemini chat: {exc}")
99
-
100
- # ----- async chat -----
101
- async def achat(self, messages: list[ChatMessage], **kwargs):
102
- return await asyncio.to_thread(self.chat, messages, **kwargs)
103
-
104
- # ----- completion helpers (rarely used) -----
105
- def complete(self, prompt: str, formatted: bool = False, **kwargs):
106
- try:
107
- resp = self._model.generate_content(str(prompt))
108
- return CompletionResponse(text=resp.text)
109
- except Exception as exc:
110
- return CompletionResponse(text=f"Error Gemini complete: {exc}")
111
-
112
- async def acomplete(self, prompt: str, formatted: bool = False, **kwargs):
113
- return await asyncio.to_thread(self.complete, prompt, formatted=formatted, **kwargs)
114
 
115
 
116
  # ---------- TOOLING ----------
 
47
 
48
  # ---------- GEMINI LLM ----------
49
  class GeminiLLM(LLM):
50
+ ...
51
+ # aquí ya tienes __init__, metadata, chat, achat, complete, acomplete
52
+ # ⬇️ pega estos métodos faltantes ⬇️
53
+
54
+ def stream_complete(self, prompt: str, formatted: bool = False, **kwargs):
55
+ """Devuelve un generador incremental de CompletionResponse."""
56
+ stream = self._model.generate_content(str(prompt), stream=True)
57
+
58
+ def gen():
59
+ acc = ""
60
+ from llama_index.core.llms import CompletionResponse # import local
61
+ for chunk in stream:
62
+ delta = getattr(chunk, "text", "") or (
63
+ chunk.parts[0].text if getattr(chunk, "parts", None) else ""
64
+ )
65
+ if delta:
66
+ acc += delta
67
+ yield CompletionResponse(text=acc, delta=delta)
68
+
69
+ return gen()
70
+
71
+ async def astream_complete(self, prompt: str, formatted: bool = False, **kwargs):
72
+ # ejecuta la versión síncrona en un hilo
73
+ sync_gen = await asyncio.to_thread(self.stream_complete, prompt, formatted=formatted, **kwargs)
74
+
75
+ async def async_gen():
76
+ for item in sync_gen:
77
+ yield item
78
+
79
+ return async_gen()
80
+
81
+ def stream_chat(self, messages: list[ChatMessage], **kwargs):
82
+ hist = [
83
+ {"role": "user" if m.role == "user" else "model", "parts": [{"text": str(m.content)}]}
84
+ for m in messages[:-1]
85
+ ]
86
+ last = str(messages[-1].content)
87
+ session = self._model.start_chat(history=hist)
88
+ stream = session.send_message(last, stream=True)
89
 
90
+ def gen():
91
+ acc = ""
92
+ for chunk in stream:
93
+ delta = getattr(chunk, "text", "") or (
94
+ chunk.parts[0].text if getattr(chunk, "parts", None) else ""
95
+ )
96
+ if delta:
97
+ acc += delta
98
+ yield ChatMessage(role="assistant", content=acc, additional_kwargs={"delta": delta})
99
 
100
+ return gen()
 
101
 
102
+ async def astream_chat(self, messages: list[ChatMessage], **kwargs):
103
+ sync_gen = await asyncio.to_thread(self.stream_chat, messages, **kwargs)
 
 
 
 
104
 
105
+ async def async_gen():
106
+ for item in sync_gen:
107
+ yield item
108
+
109
+ return async_gen()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
 
113
  # ---------- TOOLING ----------