update
agent.py CHANGED
@@ -4,6 +4,7 @@ import time
 from typing import Optional
 
 from langchain.chat_models import init_chat_model
+
 from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
 from langchain_community.tools import TavilySearchResults
 from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, AnyMessage
@@ -11,7 +12,7 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
 from langgraph.graph import add_messages, START, END, StateGraph
 from langchain_core.tools import tool
 from langgraph.prebuilt import ToolNode
-
+from pydantic import SecretStr
 
 from typing_extensions import TypedDict, Annotated
 
@@ -30,6 +31,13 @@ def get_llm():
     #return init_chat_model("llama-3.3-70b-versatile", model_provider="groq")
 
     return init_chat_model("gemini-2.0-flash", model_provider="google_genai")
+    #return AzureChatOpenAI(
+    #    api_key=SecretStr(os.environ["AZURE_OPENAI_API_KEY"]),
+    #    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
+    #    azure_deployment="gpt-4o-mini",
+    #    api_version=os.environ["AZURE_OPENAI_API_VERSION"],
+    #)
+
 
 def get_graph(llm):
     with open('prompts/system_prompt.md', 'r', encoding='utf-8') as markdown_file:
@@ -207,6 +215,8 @@ def get_graph(llm):
 
     def call_model(state: State):
         print("\n-------------------- Agent has been called -----------------------------------\n")
+        print("Waiting for 5 seconds...")
+        time.sleep(5)
         # get all messages from the state
         messages = state["messages"]
         # append instruction message
@@ -216,12 +226,14 @@ def get_graph(llm):
         # invoke LLM
        response = llm_with_tools.invoke(prompt_answer)
         print("Agent has made a decision:\n", response.content, response.tool_calls)
-
-        time.sleep(4)
+
 
         return {"messages": [response], "aggregate": ["Agent"]}
 
     def get_answer(state: State):
+        print("\n-------------------- Generating Answer -----------------------------------\n")
+        print("Waiting for 5 seconds...")
+        time.sleep(5)
         # get all messages from the state
         messages = state["messages"]
         # add prompt message
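For context, the commit keeps Gemini as the active model and leaves the Azure branch commented out inside get_llm(). Below is a minimal sketch of how the two providers could be switched at runtime; the USE_AZURE_OPENAI toggle and the langchain_openai import are assumptions for illustration, not part of this commit.

import os

from langchain.chat_models import init_chat_model
from langchain_openai import AzureChatOpenAI  # assumed dependency for the Azure branch
from pydantic import SecretStr


def get_llm():
    # Hypothetical toggle; the committed code hard-codes Gemini and keeps Azure commented out.
    if os.environ.get("USE_AZURE_OPENAI") == "1":
        return AzureChatOpenAI(
            api_key=SecretStr(os.environ["AZURE_OPENAI_API_KEY"]),
            azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
            azure_deployment="gpt-4o-mini",
            api_version=os.environ["AZURE_OPENAI_API_VERSION"],
        )
    return init_chat_model("gemini-2.0-flash", model_provider="google_genai")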
app.py CHANGED
@@ -126,8 +126,8 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
     answers_payload = []
     print(f"Running agent on {len(questions_data)} questions...")
     for item in questions_data:
-        print("Waiting for
-        time.sleep(
+        print("Waiting for 5 seconds...")
+        time.sleep(5)
         task_id = item.get("task_id")
         question_text = item.get("question")
         content_type = None
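The hard-coded 5-second pauses (before each agent call in agent.py and before each question in app.py) act as a crude rate limiter, capping outbound LLM calls at roughly 12 per minute. A small sketch of the same throttle with a configurable delay follows; the AGENT_REQUEST_DELAY variable and the throttle() helper are illustrative, not part of this repository.

import os
import time

# Delay between outbound LLM calls, defaulting to the 5 seconds used in this commit.
REQUEST_DELAY_SECONDS = float(os.environ.get("AGENT_REQUEST_DELAY", "5"))


def throttle() -> None:
    # Sleep before each call so the agent stays under per-minute provider rate limits.
    print(f"Waiting for {REQUEST_DELAY_SECONDS:g} seconds...")
    time.sleep(REQUEST_DELAY_SECONDS)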