nathanael-fijalkow committed on
Commit
615a63b
·
1 Parent(s): 7a36b3c

Updated agent with local option

Browse files
Files changed (1) hide show
  1. agent.py +33 -7
agent.py CHANGED
@@ -35,6 +35,10 @@ from huggingface_hub import InferenceClient
35
  # Load environment variables
36
  load_dotenv()
37
 
 
 
 
 
38
  # =============================================================================
39
  # LLM Configuration - DO NOT MODIFY
40
  # =============================================================================
@@ -42,12 +46,25 @@ load_dotenv()
42
  # Model to use (fixed for fair evaluation)
43
  LLM_MODEL = "Qwen/Qwen2.5-72B-Instruct"
44
 
45
- # Initialize the LLM client (uses HF_TOKEN from environment)
46
- _hf_token = os.getenv("HF_TOKEN")
47
- if not _hf_token:
48
- raise ValueError("HF_TOKEN not found. Set it in your .env file.")
 
 
49
 
50
- LLM_CLIENT = InferenceClient(token=_hf_token)
 
 
 
 
 
 
 
 
 
 
 
51
 
52
 
53
  def call_llm(prompt: str, system_prompt: str, seed: int, max_tokens: int = 300) -> str:
@@ -74,7 +91,16 @@ def call_llm(prompt: str, system_prompt: str, seed: int, max_tokens: int = 300)
74
  {"role": "system", "content": system_prompt},
75
  {"role": "user", "content": prompt},
76
  ]
77
-
 
 
 
 
 
 
 
 
 
78
  response = LLM_CLIENT.chat.completions.create(
79
  model=LLM_MODEL,
80
  messages=messages,
@@ -82,7 +108,7 @@ def call_llm(prompt: str, system_prompt: str, seed: int, max_tokens: int = 300)
82
  max_tokens=max_tokens,
83
  seed=seed,
84
  )
85
-
86
  return response.choices[0].message.content
87
 
88
 
 
35
  # Load environment variables
36
  load_dotenv()
37
 
38
# Set USE_LOCAL_MODEL=1 (also accepts "true"/"yes", case-insensitive) in your
# .env to run a locally downloaded model instead of the HF Inference API.
# Fix: the flag is now lowercased before comparison — previously "True"/"YES"
# (common .env spellings) were silently treated as false.
USE_LOCAL_MODEL = os.getenv("USE_LOCAL_MODEL", "0").strip().lower() in ("1", "true", "yes")
LOCAL_MODEL_ID = os.getenv("LOCAL_MODEL_ID", "Qwen/Qwen2.5-3B-Instruct")

# =============================================================================
# LLM Configuration - DO NOT MODIFY
# =============================================================================

# Model to use (fixed for fair evaluation)
LLM_MODEL = "Qwen/Qwen2.5-72B-Instruct"

# Initialize the LLM client based on mode.
# Exactly one of (_local_pipeline, LLM_CLIENT) is non-None after this block;
# callers (e.g. call_llm) branch on USE_LOCAL_MODEL / _local_pipeline.
_local_pipeline = None

if USE_LOCAL_MODEL:
    # Imported lazily so remote-API users don't need torch/transformers installed.
    import torch
    from transformers import pipeline as _hf_pipeline

    _local_pipeline = _hf_pipeline(
        "text-generation",
        model=LOCAL_MODEL_ID,
        torch_dtype=torch.bfloat16,  # halves memory vs fp32; assumes bf16-capable hardware — TODO confirm
        device_map="auto",           # let accelerate place layers on available devices
    )
    LLM_CLIENT = None
else:
    # Remote mode requires an HF token (read from .env via load_dotenv above).
    _hf_token = os.getenv("HF_TOKEN")
    if not _hf_token:
        raise ValueError("HF_TOKEN not found. Set it in your .env file.")
    LLM_CLIENT = InferenceClient(token=_hf_token)
68
 
69
 
70
  def call_llm(prompt: str, system_prompt: str, seed: int, max_tokens: int = 300) -> str:
 
91
  {"role": "system", "content": system_prompt},
92
  {"role": "user", "content": prompt},
93
  ]
94
+
95
+ if USE_LOCAL_MODEL and _local_pipeline is not None:
96
+ outputs = _local_pipeline(
97
+ messages,
98
+ max_new_tokens=max_tokens,
99
+ temperature=0.0001, # Near-deterministic (0.0 unsupported by some backends)
100
+ do_sample=True,
101
+ )
102
+ return outputs[0]["generated_text"][-1]["content"]
103
+
104
  response = LLM_CLIENT.chat.completions.create(
105
  model=LLM_MODEL,
106
  messages=messages,
 
108
  max_tokens=max_tokens,
109
  seed=seed,
110
  )
111
+
112
  return response.choices[0].message.content
113
 
114