shahidshaikh commited on
Commit
a5cc352
·
verified ·
1 Parent(s): 06414ba

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +8 -3
agent.py CHANGED
@@ -7,7 +7,7 @@ import os
7
  from dotenv import load_dotenv
8
  from langgraph.checkpoint.memory import InMemorySaver
9
  from langgraph.prebuilt import create_react_agent
10
- from tools import (search_academic_source, download_pdf, save_papers, get_paper_batch,
11
  save_output, read_output, cluster_and_visualize,
12
  get_pajais_taxonomy, read_pdf_text, enrich_doi,
13
  read_word_text, import_from_scratch)
@@ -15,7 +15,7 @@ from prompt import build_prompt
15
 
16
  load_dotenv()
17
 
18
- TOOLS = [search_academic_source, download_pdf, save_papers, get_paper_batch, save_output,
19
  read_output, cluster_and_visualize, get_pajais_taxonomy,
20
  read_pdf_text, enrich_doi, read_word_text, import_from_scratch]
21
 
@@ -23,7 +23,12 @@ def _llm():
23
  provider = os.getenv("LLM_PROVIDER", "groq").lower()
24
  if provider == "mistral":
25
  from langchain_mistralai import ChatMistralAI
26
- return ChatMistralAI(model=os.getenv("MISTRAL_BIG", "mistral-small-latest"), temperature=0, max_retries=6)
 
 
 
 
 
27
  from langchain_groq import ChatGroq
28
  return ChatGroq(model=os.getenv("GROQ_BIG", "llama-3.3-70b-versatile"), temperature=0)
29
 
 
7
  from dotenv import load_dotenv
8
  from langgraph.checkpoint.memory import InMemorySaver
9
  from langgraph.prebuilt import create_react_agent
10
+ from tools import (search_academic_source, save_papers, get_paper_batch,
11
  save_output, read_output, cluster_and_visualize,
12
  get_pajais_taxonomy, read_pdf_text, enrich_doi,
13
  read_word_text, import_from_scratch)
 
15
 
16
  load_dotenv()
17
 
18
+ TOOLS = [search_academic_source, save_papers, get_paper_batch, save_output,
19
  read_output, cluster_and_visualize, get_pajais_taxonomy,
20
  read_pdf_text, enrich_doi, read_word_text, import_from_scratch]
21
 
 
23
  provider = os.getenv("LLM_PROVIDER", "groq").lower()
24
  if provider == "mistral":
25
  from langchain_mistralai import ChatMistralAI
26
+ return ChatMistralAI(model=os.getenv("MISTRAL_BIG", "mistral-small-latest"), temperature=0)
27
+ elif provider == "huggingface":
28
+ from langchain_huggingface import HuggingFaceEndpoint
29
+ repo_id = os.getenv("HF_MODEL_ID", "mistralai/Mistral-7B-Instruct-v0.3")
30
+ token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
31
+ return HuggingFaceEndpoint(repo_id=repo_id, huggingfacehub_api_token=token, temperature=0.01)
32
  from langchain_groq import ChatGroq
33
  return ChatGroq(model=os.getenv("GROQ_BIG", "llama-3.3-70b-versatile"), temperature=0)
34