arjeet commited on
Commit
5c41cb8
·
1 Parent(s): 5ba223e

inference update

Browse files
Files changed (2) hide show
  1. inference.py +18 -6
  2. server/app.py +0 -2
inference.py CHANGED
@@ -5,15 +5,27 @@ from openai import OpenAI
5
  from server.cust_env_environment import DocSweeperEnvironment
6
  from models import DocAction
7
 
8
- def run_inference(task_name: str):
9
 
10
- api_base_url = os.environ.get("API_BASE_URL")
11
- model_name = os.environ.get("MODEL_NAME")
12
- hf_token = os.environ.get("HF_TOKEN")
 
 
 
 
 
 
13
 
14
- if not all([api_base_url, model_name, hf_token]):
15
- raise ValueError("Missing required environment variables: API_BASE_URL, MODEL_NAME, or HF_TOKEN")
 
16
 
 
 
 
 
 
 
17
 
18
  client = OpenAI(
19
  api_key=hf_token,
 
5
  from server.cust_env_environment import DocSweeperEnvironment
6
  from models import DocAction
7
 
 
8
 
9
+ IMAGE_NAME = os.getenv("IMAGE_NAME")
10
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
11
+
12
+ API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
13
+ MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
14
+ TASK_NAME = os.getenv("MY_ENV_V4_TASK", "echo")
15
+ BENCHMARK = os.getenv("MY_ENV_V4_BENCHMARK", "my_env_v4")
16
+
17
+ def run_inference(task_name: str):
18
 
19
+ api_base_url = os.environ.get("API_BASE_URL") or API_BASE_URL
20
+ model_name = os.environ.get("MODEL_NAME") or MODEL_NAME
21
+ hf_token = os.environ.get("HF_TOKEN") or API_KEY
22
 
23
+ if not api_base_url:
24
+ raise ValueError("Missinh api base url")
25
+ if not model_name:
26
+ raise ValueError("Missing model name")
27
+ if not hf_token:
28
+ raise ValueError("Missing hf_token")
29
 
30
  client = OpenAI(
31
  api_key=hf_token,
server/app.py CHANGED
@@ -10,8 +10,6 @@ import uvicorn
10
  from models import DocAction, DocObservation
11
  from .cust_env_environment import DocSweeperEnvironment
12
 
13
- # Create the FastAPI app
14
- # Pass the class (factory) instead of an instance for WebSocket session support
15
  app = create_app(DocSweeperEnvironment, DocAction, DocObservation, env_name="doc_sweeper")
16
 
17
  def main():
 
10
  from models import DocAction, DocObservation
11
  from .cust_env_environment import DocSweeperEnvironment
12
 
 
 
13
  app = create_app(DocSweeperEnvironment, DocAction, DocObservation, env_name="doc_sweeper")
14
 
15
  def main():