Josedcape committed on
Commit
ee4bd26
·
verified ·
1 Parent(s): a400aef

Update src/utils/llm.py

Browse files
Files changed (1) hide show
  1. src/utils/llm.py +67 -101
src/utils/llm.py CHANGED
@@ -1,101 +1,67 @@
1
- from openai import OpenAI
2
- import pdb
3
- from langchain_openai import ChatOpenAI
4
- from langchain_core.globals import get_llm_cache
5
- from langchain_core.language_models.base import (
6
- BaseLanguageModel,
7
- LangSmithParams,
8
- LanguageModelInput,
9
- )
10
- from langchain_core.load import dumpd, dumps
11
- from langchain_core.messages import (
12
- AIMessage,
13
- SystemMessage,
14
- AnyMessage,
15
- BaseMessage,
16
- BaseMessageChunk,
17
- HumanMessage,
18
- convert_to_messages,
19
- message_chunk_to_message,
20
- )
21
- from langchain_core.outputs import (
22
- ChatGeneration,
23
- ChatGenerationChunk,
24
- ChatResult,
25
- LLMResult,
26
- RunInfo,
27
- )
28
- from langchain_core.output_parsers.base import OutputParserLike
29
- from langchain_core.runnables import Runnable, RunnableConfig
30
- from langchain_core.tools import BaseTool
31
-
32
- from typing import (
33
- TYPE_CHECKING,
34
- Any,
35
- Callable,
36
- Literal,
37
- Optional,
38
- Union,
39
- cast,
40
- )
41
-
42
- class DeepSeekR1ChatOpenAI(ChatOpenAI):
43
-
44
- def __init__(self, *args: Any, **kwargs: Any) -> None:
45
- super().__init__(*args, **kwargs)
46
- self.client = OpenAI(
47
- base_url=kwargs.get("base_url"),
48
- api_key=kwargs.get("api_key")
49
- )
50
-
51
- async def ainvoke(
52
- self,
53
- input: LanguageModelInput,
54
- config: Optional[RunnableConfig] = None,
55
- *,
56
- stop: Optional[list[str]] = None,
57
- **kwargs: Any,
58
- ) -> AIMessage:
59
- message_history = []
60
- for input_ in input:
61
- if isinstance(input_, SystemMessage):
62
- message_history.append({"role": "system", "content": input_.content})
63
- elif isinstance(input_, AIMessage):
64
- message_history.append({"role": "assistant", "content": input_.content})
65
- else:
66
- message_history.append({"role": "user", "content": input_.content})
67
-
68
- response = self.client.chat.completions.create(
69
- model=self.model_name,
70
- messages=messages
71
- )
72
-
73
- reasoning_content = response.choices[0].message.reasoning_content
74
- content = response.choices[0].message.content
75
- return AIMessage(content=content, reasoning_content=reasoning_content)
76
-
77
- def invoke(
78
- self,
79
- input: LanguageModelInput,
80
- config: Optional[RunnableConfig] = None,
81
- *,
82
- stop: Optional[list[str]] = None,
83
- **kwargs: Any,
84
- ) -> AIMessage:
85
- message_history = []
86
- for input_ in input:
87
- if isinstance(input_, SystemMessage):
88
- message_history.append({"role": "system", "content": input_.content})
89
- elif isinstance(input_, AIMessage):
90
- message_history.append({"role": "assistant", "content": input_.content})
91
- else:
92
- message_history.append({"role": "user", "content": input_.content})
93
-
94
- response = self.client.chat.completions.create(
95
- model=self.model_name,
96
- messages=message_history
97
- )
98
-
99
- reasoning_content = response.choices[0].message.reasoning_content
100
- content = response.choices[0].message.content
101
- return AIMessage(content=content, reasoning_content=reasoning_content)
 
1
import asyncio
from typing import Any, Optional

from openai import OpenAI
from langchain_openai import ChatOpenAI
from langchain_core.language_models.base import LanguageModelInput
from langchain_core.messages import (
    AIMessage,
    SystemMessage,
    BaseMessage,
)

11
class DeepSeekR1ChatOpenAI(ChatOpenAI):
    """ChatOpenAI variant for DeepSeek-R1 style OpenAI-compatible endpoints.

    Calls the chat-completions API through a raw ``openai`` client instead of
    the LangChain request path, so the DeepSeek-specific ``reasoning_content``
    field on the response message can be surfaced to callers (it is attached
    to the returned ``AIMessage.additional_kwargs`` when present).
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Raw client used for the actual HTTP calls; configured from the
        # same base_url/api_key the caller passed to this wrapper.
        self.client = OpenAI(
            base_url=kwargs.get("base_url"),
            api_key=kwargs.get("api_key"),
        )

    @staticmethod
    def _to_message_history(input: LanguageModelInput) -> list[dict]:
        """Map LangChain message objects onto OpenAI chat-format dicts.

        Anything that is neither a SystemMessage nor an AIMessage is treated
        as user input (HumanMessage and unknown types alike), matching the
        original fall-through behavior.
        """
        history = []
        for message in input:
            if isinstance(message, SystemMessage):
                role = "system"
            elif isinstance(message, AIMessage):
                role = "assistant"
            else:
                role = "user"
            history.append({"role": role, "content": message.content})
        return history

    def _complete(
        self,
        input: LanguageModelInput,
        stop: Optional[list[str]],
    ) -> AIMessage:
        """Single blocking round-trip shared by invoke() and ainvoke()."""
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=self._to_message_history(input),
            stop=stop,  # previously accepted but silently ignored
        )
        message = response.choices[0].message
        # reasoning_content is DeepSeek-specific; use getattr so other
        # OpenAI-compatible backends (which lack the field) still work.
        reasoning = getattr(message, "reasoning_content", None)
        extra = {"reasoning_content": reasoning} if reasoning is not None else {}
        return AIMessage(content=message.content, additional_kwargs=extra)

    async def ainvoke(
        self,
        input: LanguageModelInput,
        config: Optional[dict] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        """Async entry point.

        The raw OpenAI client call is blocking, so it is pushed onto a
        worker thread with asyncio.to_thread to avoid stalling the event
        loop (the previous implementation blocked inside the coroutine).
        """
        return await asyncio.to_thread(self._complete, input, stop)

    def invoke(
        self,
        input: LanguageModelInput,
        config: Optional[dict] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        """Synchronous entry point; see _complete for the request details."""
        return self._complete(input, stop)