abtsousa commited on
Commit
60d1fd6
·
1 Parent(s): 9465031

Replace ChatOpenAI with ChatDeepSeek in model.

Browse files
Files changed (1) hide show
  1. agent/nodes.py +18 -11
agent/nodes.py CHANGED
@@ -8,6 +8,7 @@ from langchain_core.messages import BaseMessage
8
  from langchain_core.language_models.base import LanguageModelInput
9
  from langchain_google_genai import ChatGoogleGenerativeAI
10
  from langchain_openai import ChatOpenAI
 
11
  from pydantic import BaseModel, Field, SecretStr
12
  from agent.prompts import get_system_prompt
13
  from agent.state import State
@@ -38,19 +39,25 @@ def _get_model() -> BaseChatModel:
38
  # model="gemini-2.5-pro"
39
  # )
40
 
41
- api_key = os.getenv(API_KEY_ENV_VAR)
42
 
43
- return ChatOpenAI(
44
- api_key=SecretStr(api_key) if api_key else None,
45
- base_url=API_BASE_URL,
46
- model=MODEL_NAME,
 
 
 
 
 
 
 
 
 
 
 
47
  temperature=MODEL_TEMPERATURE if MODEL_TEMPERATURE else 0.0,
48
- rate_limiter=_rate_limiter,
49
- metadata={
50
- "reasoning": {
51
- "effort": "high" # Use high reasoning effort
52
- }
53
- }
54
  )
55
 
56
  def _get_tools() -> list[BaseTool]:
 
8
  from langchain_core.language_models.base import LanguageModelInput
9
  from langchain_google_genai import ChatGoogleGenerativeAI
10
  from langchain_openai import ChatOpenAI
11
+ from langchain_deepseek import ChatDeepSeek
12
  from pydantic import BaseModel, Field, SecretStr
13
  from agent.prompts import get_system_prompt
14
  from agent.state import State
 
39
  # model="gemini-2.5-pro"
40
  # )
41
 
42
+ # api_key = os.getenv(API_KEY_ENV_VAR)
43
 
44
+ # return ChatOpenAI(
45
+ # api_key=SecretStr(api_key) if api_key else None,
46
+ # base_url=API_BASE_URL,
47
+ # model=MODEL_NAME,
48
+ # temperature=MODEL_TEMPERATURE if MODEL_TEMPERATURE else 0.0,
49
+ # rate_limiter=_rate_limiter,
50
+ # metadata={
51
+ # "reasoning": {
52
+ # "effort": "high" # Use high reasoning effort
53
+ # }
54
+ # }
55
+ # )
56
+
57
+ return ChatDeepSeek(
58
+ model="deepseek-chat",
59
  temperature=MODEL_TEMPERATURE if MODEL_TEMPERATURE else 0.0,
60
+ max_retries=2,
 
 
 
 
 
61
  )
62
 
63
  def _get_tools() -> list[BaseTool]: