adding smolagent
Browse files
app.py
CHANGED
|
@@ -11,6 +11,9 @@ from sentence_transformers import SentenceTransformer, util, CrossEncoder
|
|
| 11 |
from langchain.llms.base import LLM
|
| 12 |
import google.generativeai as genai
|
| 13 |
|
|
|
|
|
|
|
|
|
|
| 14 |
###############################################################################
|
| 15 |
# 1) Logging Setup
|
| 16 |
###############################################################################
|
|
@@ -199,7 +202,36 @@ class QuestionSanityChecker:
|
|
| 199 |
sanity_checker = QuestionSanityChecker(llm)
|
| 200 |
|
| 201 |
###############################################################################
|
| 202 |
-
# 7)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
###############################################################################
|
| 204 |
class AnswerExpander:
|
| 205 |
def __init__(self, llm: GeminiLLM):
|
|
@@ -230,7 +262,7 @@ class AnswerExpander:
|
|
| 230 |
answer_expander = AnswerExpander(llm)
|
| 231 |
|
| 232 |
###############################################################################
|
| 233 |
-
#
|
| 234 |
###############################################################################
|
| 235 |
def handle_query(query: str) -> str:
|
| 236 |
if not query or not isinstance(query, str) or len(query.strip()) == 0:
|
|
@@ -247,14 +279,26 @@ def handle_query(query: str) -> str:
|
|
| 247 |
if not retrieved:
|
| 248 |
return "I'm sorry, I couldn't find an answer to your question."
|
| 249 |
|
| 250 |
-
#
|
| 251 |
top_score = retrieved[0][1] # Assuming the list is sorted descending
|
| 252 |
similarity_threshold = 0.3 # Adjust this threshold based on empirical results
|
| 253 |
|
| 254 |
if top_score < similarity_threshold:
|
| 255 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
|
| 257 |
-
# Proceed with answer expansion
|
| 258 |
responses = [ans[0] for ans in retrieved]
|
| 259 |
expanded_answer = answer_expander.expand(query, responses)
|
| 260 |
return expanded_answer
|
|
@@ -264,18 +308,12 @@ def handle_query(query: str) -> str:
|
|
| 264 |
return "An error occurred while processing your request."
|
| 265 |
|
| 266 |
###############################################################################
|
| 267 |
-
#
|
| 268 |
###############################################################################
|
| 269 |
def gradio_interface(query: str):
|
| 270 |
try:
|
| 271 |
response = handle_query(query)
|
| 272 |
-
formatted_response =
|
| 273 |
-
f"**Daily Wellness AI**\n\n"
|
| 274 |
-
f"{response}\n\n"
|
| 275 |
-
"Disclaimer: This is general wellness information, "
|
| 276 |
-
"not a substitute for professional medical advice.\n\n"
|
| 277 |
-
"Wishing you a calm and wonderful day!"
|
| 278 |
-
)
|
| 279 |
return formatted_response
|
| 280 |
except Exception as e:
|
| 281 |
logger.error(f"Error in Gradio interface: {e}")
|
|
@@ -296,13 +334,14 @@ interface = gr.Interface(
|
|
| 296 |
examples=[
|
| 297 |
"What is box breathing and how does it help reduce anxiety?",
|
| 298 |
"Provide a daily wellness schedule incorporating box breathing techniques.",
|
| 299 |
-
"What are some tips for maintaining good posture while working at a desk?"
|
|
|
|
| 300 |
],
|
| 301 |
allow_flagging="never"
|
| 302 |
)
|
| 303 |
|
| 304 |
###############################################################################
|
| 305 |
-
#
|
| 306 |
###############################################################################
|
| 307 |
if __name__ == "__main__":
|
| 308 |
try:
|
|
|
|
| 11 |
from langchain.llms.base import LLM
|
| 12 |
import google.generativeai as genai
|
| 13 |
|
| 14 |
+
# Import smolagents components
|
| 15 |
+
from smolagents import CodeAgent, LiteLLMModel, DuckDuckGoSearchTool, ManagedAgent
|
| 16 |
+
|
| 17 |
###############################################################################
|
| 18 |
# 1) Logging Setup
|
| 19 |
###############################################################################
|
|
|
|
| 202 |
sanity_checker = QuestionSanityChecker(llm)
|
| 203 |
|
| 204 |
###############################################################################
|
| 205 |
+
# 7) smolagents Integration: GROQ Model and Web Search
|
| 206 |
+
###############################################################################
|
| 207 |
+
# Initialize the smolagents' LiteLLMModel with GROQ model
|
| 208 |
+
smol_model = LiteLLMModel("groq/llama3-8b-8192")
|
| 209 |
+
|
| 210 |
+
# Instantiate the DuckDuckGo search tool
|
| 211 |
+
search_tool = DuckDuckGoSearchTool()
|
| 212 |
+
|
| 213 |
+
# Create the web agent with the search tool
|
| 214 |
+
web_agent = CodeAgent(
|
| 215 |
+
tools=[search_tool],
|
| 216 |
+
model=smol_model
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
# Define the managed web agent
|
| 220 |
+
managed_web_agent = ManagedAgent(
|
| 221 |
+
agent=web_agent,
|
| 222 |
+
name="web_search",
|
| 223 |
+
description="Runs a web search for you. Provide your query as an argument."
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
# Create the manager agent with managed web agent and additional tools if needed
|
| 227 |
+
manager_agent = CodeAgent(
|
| 228 |
+
tools=[], # Add additional tools here if required
|
| 229 |
+
model=smol_model,
|
| 230 |
+
managed_agents=[managed_web_agent]
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
###############################################################################
|
| 234 |
+
# 8) Answer Expansion
|
| 235 |
###############################################################################
|
| 236 |
class AnswerExpander:
|
| 237 |
def __init__(self, llm: GeminiLLM):
|
|
|
|
| 262 |
answer_expander = AnswerExpander(llm)
|
| 263 |
|
| 264 |
###############################################################################
|
| 265 |
+
# 9) Query Handling
|
| 266 |
###############################################################################
|
| 267 |
def handle_query(query: str) -> str:
|
| 268 |
if not query or not isinstance(query, str) or len(query.strip()) == 0:
|
|
|
|
| 279 |
if not retrieved:
|
| 280 |
return "I'm sorry, I couldn't find an answer to your question."
|
| 281 |
|
| 282 |
+
# Check similarity threshold
|
| 283 |
top_score = retrieved[0][1] # Assuming the list is sorted descending
|
| 284 |
similarity_threshold = 0.3 # Adjust this threshold based on empirical results
|
| 285 |
|
| 286 |
if top_score < similarity_threshold:
|
| 287 |
+
# Perform web search using manager_agent
|
| 288 |
+
logger.info("Similarity score below threshold. Performing web search.")
|
| 289 |
+
web_search_response = manager_agent.run(query)
|
| 290 |
+
logger.debug(f"Web search response: {web_search_response}")
|
| 291 |
+
|
| 292 |
+
# Optionally, process the web_search_response if needed
|
| 293 |
+
# For simplicity, return the web search response directly
|
| 294 |
+
return (
|
| 295 |
+
f"**Daily Wellness AI**\n\n"
|
| 296 |
+
f"{web_search_response}\n\n"
|
| 297 |
+
"Disclaimer: This information is retrieved from the web and is not a substitute for professional medical advice.\n\n"
|
| 298 |
+
"Wishing you a calm and wonderful day!"
|
| 299 |
+
)
|
| 300 |
|
| 301 |
+
# Proceed with answer expansion using retrieved_answers
|
| 302 |
responses = [ans[0] for ans in retrieved]
|
| 303 |
expanded_answer = answer_expander.expand(query, responses)
|
| 304 |
return expanded_answer
|
|
|
|
| 308 |
return "An error occurred while processing your request."
|
| 309 |
|
| 310 |
###############################################################################
|
| 311 |
+
# 10) Gradio Interface
|
| 312 |
###############################################################################
|
| 313 |
def gradio_interface(query: str):
|
| 314 |
try:
|
| 315 |
response = handle_query(query)
|
| 316 |
+
formatted_response = response # Response is already formatted
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
return formatted_response
|
| 318 |
except Exception as e:
|
| 319 |
logger.error(f"Error in Gradio interface: {e}")
|
|
|
|
| 334 |
examples=[
|
| 335 |
"What is box breathing and how does it help reduce anxiety?",
|
| 336 |
"Provide a daily wellness schedule incorporating box breathing techniques.",
|
| 337 |
+
"What are some tips for maintaining good posture while working at a desk?",
|
| 338 |
+
"Who is the CEO of Hugging Face?" # Example of an out-of-context question
|
| 339 |
],
|
| 340 |
allow_flagging="never"
|
| 341 |
)
|
| 342 |
|
| 343 |
###############################################################################
|
| 344 |
+
# 11) Launch Gradio
|
| 345 |
###############################################################################
|
| 346 |
if __name__ == "__main__":
|
| 347 |
try:
|