Spaces:
Sleeping
Sleeping
Update agent.py
Browse files
agent.py
CHANGED
|
@@ -7,7 +7,7 @@ import os
|
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
from langgraph.checkpoint.memory import InMemorySaver
|
| 9 |
from langgraph.prebuilt import create_react_agent
|
| 10 |
-
from tools import (search_academic_source,
|
| 11 |
save_output, read_output, cluster_and_visualize,
|
| 12 |
get_pajais_taxonomy, read_pdf_text, enrich_doi,
|
| 13 |
read_word_text, import_from_scratch)
|
|
@@ -15,7 +15,7 @@ from prompt import build_prompt
|
|
| 15 |
|
| 16 |
load_dotenv()
|
| 17 |
|
| 18 |
-
TOOLS = [search_academic_source,
|
| 19 |
read_output, cluster_and_visualize, get_pajais_taxonomy,
|
| 20 |
read_pdf_text, enrich_doi, read_word_text, import_from_scratch]
|
| 21 |
|
|
@@ -23,7 +23,12 @@ def _llm():
|
|
| 23 |
provider = os.getenv("LLM_PROVIDER", "groq").lower()
|
| 24 |
if provider == "mistral":
|
| 25 |
from langchain_mistralai import ChatMistralAI
|
| 26 |
-
return ChatMistralAI(model=os.getenv("MISTRAL_BIG", "mistral-small-latest"), temperature=0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
from langchain_groq import ChatGroq
|
| 28 |
return ChatGroq(model=os.getenv("GROQ_BIG", "llama-3.3-70b-versatile"), temperature=0)
|
| 29 |
|
|
|
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
from langgraph.checkpoint.memory import InMemorySaver
|
| 9 |
from langgraph.prebuilt import create_react_agent
|
| 10 |
+
from tools import (search_academic_source, save_papers, get_paper_batch,
|
| 11 |
save_output, read_output, cluster_and_visualize,
|
| 12 |
get_pajais_taxonomy, read_pdf_text, enrich_doi,
|
| 13 |
read_word_text, import_from_scratch)
|
|
|
|
| 15 |
|
| 16 |
load_dotenv()
|
| 17 |
|
| 18 |
+
TOOLS = [search_academic_source, save_papers, get_paper_batch, save_output,
|
| 19 |
read_output, cluster_and_visualize, get_pajais_taxonomy,
|
| 20 |
read_pdf_text, enrich_doi, read_word_text, import_from_scratch]
|
| 21 |
|
|
|
|
| 23 |
provider = os.getenv("LLM_PROVIDER", "groq").lower()
|
| 24 |
if provider == "mistral":
|
| 25 |
from langchain_mistralai import ChatMistralAI
|
| 26 |
+
return ChatMistralAI(model=os.getenv("MISTRAL_BIG", "mistral-small-latest"), temperature=0)
|
| 27 |
+
elif provider == "huggingface":
|
| 28 |
+
from langchain_huggingface import HuggingFaceEndpoint
|
| 29 |
+
repo_id = os.getenv("HF_MODEL_ID", "mistralai/Mistral-7B-Instruct-v0.3")
|
| 30 |
+
token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
|
| 31 |
+
return HuggingFaceEndpoint(repo_id=repo_id, huggingfacehub_api_token=token, temperature=0.01)
|
| 32 |
from langchain_groq import ChatGroq
|
| 33 |
return ChatGroq(model=os.getenv("GROQ_BIG", "llama-3.3-70b-versatile"), temperature=0)
|
| 34 |
|