Update backend/agent_instance.py
Browse files- backend/agent_instance.py +13 -4
backend/agent_instance.py
CHANGED
|
@@ -4,21 +4,30 @@ import os
|
|
| 4 |
# β
Add src to Python path
|
| 5 |
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
|
| 6 |
|
| 7 |
-
from txagent.txagent import TxAgent
|
| 8 |
|
| 9 |
def init_agent():
|
| 10 |
-
# β
Use
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
os.makedirs(model_cache_dir, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
| 13 |
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
|
| 14 |
os.environ["HF_HOME"] = model_cache_dir
|
| 15 |
|
|
|
|
| 16 |
model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
|
| 17 |
rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
|
| 18 |
tool_files_dict = {
|
| 19 |
-
"new_tool": os.path.
|
| 20 |
}
|
| 21 |
|
|
|
|
| 22 |
agent = TxAgent(
|
| 23 |
model_name=model_name,
|
| 24 |
rag_model_name=rag_model_name,
|
|
|
|
| 4 |
# β
Add src to Python path
|
| 5 |
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
|
| 6 |
|
| 7 |
+
from txagent.txagent import TxAgent
|
| 8 |
|
| 9 |
def init_agent():
|
| 10 |
+
# β
Use Hugging Face persistent storage
|
| 11 |
+
base_dir = "/data"
|
| 12 |
+
model_cache_dir = os.path.join(base_dir, "hf_cache")
|
| 13 |
+
tool_cache_dir = os.path.join(base_dir, "tool_cache")
|
| 14 |
+
|
| 15 |
+
# β
Ensure the folders exist
|
| 16 |
os.makedirs(model_cache_dir, exist_ok=True)
|
| 17 |
+
os.makedirs(tool_cache_dir, exist_ok=True)
|
| 18 |
+
|
| 19 |
+
# β
Set environment variables so models stay cached after restart
|
| 20 |
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
|
| 21 |
os.environ["HF_HOME"] = model_cache_dir
|
| 22 |
|
| 23 |
+
# β
Paths to model + tool definitions
|
| 24 |
model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
|
| 25 |
rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
|
| 26 |
tool_files_dict = {
|
| 27 |
+
"new_tool": os.path.join(tool_cache_dir, "new_tool.json")
|
| 28 |
}
|
| 29 |
|
| 30 |
+
# β
Init agent with config
|
| 31 |
agent = TxAgent(
|
| 32 |
model_name=model_name,
|
| 33 |
rag_model_name=rag_model_name,
|