jedick committed on
Commit
8cb166c
·
1 Parent(s): 9477d3a

Download model before running workflow

Browse files
Files changed (2) hide show
  1. app.py +6 -1
  2. main.py +8 -0
app.py CHANGED
@@ -4,7 +4,7 @@ from graph import BuildGraph
4
  from retriever import db_dir
5
  from langgraph.checkpoint.memory import MemorySaver
6
  from dotenv import load_dotenv
7
- from main import openai_model, model_id
8
  from util import get_sources, get_start_end_months
9
  from mods.tool_calling_llm import extract_think
10
  import requests
@@ -211,6 +211,11 @@ def to_workflow(request: gr.Request, *args):
211
  # Add session_hash to arguments
212
  new_args = args + (request.session_hash,)
213
  if compute_mode == "local":
 
 
 
 
 
214
  for value in run_workflow_local(*new_args):
215
  yield value
216
  if compute_mode == "remote":
 
4
  from retriever import db_dir
5
  from langgraph.checkpoint.memory import MemorySaver
6
  from dotenv import load_dotenv
7
+ from main import openai_model, model_id, DownloadChatModel
8
  from util import get_sources, get_start_end_months
9
  from mods.tool_calling_llm import extract_think
10
  import requests
 
211
  # Add session_hash to arguments
212
  new_args = args + (request.session_hash,)
213
  if compute_mode == "local":
214
+ # If graph hasn't been instantiated, download model before running workflow
215
+ graph = graph_instances[compute_mode].get(request.session_hash)
216
+ if graph is None:
217
+ DownloadChatModel()
218
+ # Call the workflow function with the @spaces.GPU decorator
219
  for value in run_workflow_local(*new_args):
220
  yield value
221
  if compute_mode == "remote":
main.py CHANGED
@@ -5,6 +5,7 @@ from langchain_core.output_parsers import StrOutputParser
5
  from langgraph.checkpoint.memory import MemorySaver
6
  from langchain_core.messages import ToolMessage
7
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
8
  from datetime import datetime
9
  from dotenv import load_dotenv
10
  import os
@@ -128,6 +129,13 @@ def ProcessDirectory(path, compute_mode):
128
  print(f"Chroma: no change for {file_path}")
129
 
130
 
 
 
 
 
 
 
 
131
  def GetChatModel(compute_mode):
132
  """
133
  Get a chat model.
 
5
  from langgraph.checkpoint.memory import MemorySaver
6
  from langchain_core.messages import ToolMessage
7
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
8
+ from huggingface_hub import snapshot_download
9
  from datetime import datetime
10
  from dotenv import load_dotenv
11
  import os
 
129
  print(f"Chroma: no change for {file_path}")
130
 
131
 
132
+ def DownloadChatModel():
133
+ """
134
+ Downloads a chat model to the local Hugging Face cache.
135
+ """
136
+ snapshot_download(model_id)
137
+
138
+
139
  def GetChatModel(compute_mode):
140
  """
141
  Get a chat model.