Tuan197 committed on
Commit
9029d3e
·
verified ·
1 Parent(s): e467991

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +22 -11
agent.py CHANGED
@@ -161,18 +161,29 @@ def build_graph(provider: str = "groq"):
161
  llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
162
  elif provider == "huggingface":
163
  # TODO: Add huggingface endpoint
164
- # llm = ChatHuggingFace(
165
- # llm=HuggingFaceEndpoint(
166
- # repo_id="Meta-DeepLearning/llama-2-7b-chat-hf",
167
- # temperature=0
168
- # )
169
- # )
170
- llm = ChatHuggingFace(
171
- llm=HuggingFaceEndpoint(
172
- url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
173
- temperature=0,
174
- ),
175
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  else:
177
  raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
178
  # Bind tools to LLM
 
161
  llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
162
  elif provider == "huggingface":
163
  # TODO: Add huggingface endpoint
164
+ model_id = "Qwen/Qwen3-32B" # You can change this to any model you want
165
+
166
+ # Load tokenizer and model
167
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
168
+ model = AutoModelForCausalLM.from_pretrained(
169
+ model_id,
170
+ device_map="auto", # Automatically determine device placement
171
+ torch_dtype="auto" # Use appropriate precision for device
 
 
 
172
  )
173
+
174
+ # Create text-generation pipeline
175
+ pipe = pipeline(
176
+ "text-generation",
177
+ model=model,
178
+ tokenizer=tokenizer,
179
+ max_length=2048,
180
+ temperature=0,
181
+ top_p=0.95,
182
+ repetition_penalty=1.15
183
+ )
184
+
185
+ # Create LangChain HuggingFacePipeline object
186
+ llm = HuggingFacePipeline(pipeline=pipe)
187
  else:
188
  raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
189
  # Bind tools to LLM