# Berlin_roomates / agent_core.py
# Author: subba5076 — commit 191b350 ("Update agent_core.py", verified)
# agent_core.py
from smolagents import CodeAgent, WebSearchTool, InferenceClientModel, VisitWebpageTool, Tool, tool
import random
# -----------------------------
# MODEL SETUP (FREE + TOOL CALLING)
# -----------------------------
# Shared LLM backend handed to the CodeAgent constructed further down.
# No model id or provider is passed, so InferenceClientModel's own defaults apply.
# NOTE(review): the banner's "FREE + TOOL CALLING" claim depends on the default
# model chosen by the installed smolagents version and on the Hugging Face
# account/quota in use — confirm before relying on it.
model = InferenceClientModel()
# -----------------------------
# MENU SUGGESTION TOOL
# -----------------------------
@tool
def suggest_menu(occasion: str) -> str:
    """
    Suggest one randomly chosen chicken dish for a casual or formal party.

    Args:
        occasion: Party type, either 'casual' or 'formal' (case-insensitive,
            surrounding whitespace ignored).

    Returns:
        A line naming the suggested dish, or a hint listing the accepted
        values when the occasion is not recognized.
    """
    casual_dishes = [
        "Chicken 65",
        "Chicken Biryani",
        "Pepper Chicken Fry",
        "Andhra Chicken Roast",
        "Chicken Ghee Roast",
    ]
    formal_dishes = [
        "Hyderabadi Chicken Dum Biryani",
        "Butter Chicken",
        "Chicken Tandoori",
        "Chicken Chettinad",
        "Chicken Mughlai",
    ]
    # Arguments come from an LLM and often vary in case/spacing ("Casual ", "FORMAL").
    normalized = occasion.strip().lower()
    if normalized == "casual":
        return f"Casual Menu Dish: {random.choice(casual_dishes)}"
    if normalized == "formal":
        return f"Formal Menu Dish: {random.choice(formal_dishes)}"
    return "Please choose 'casual' or 'formal'."
# -----------------------------
# CATERING SERVICE TOOL
# -----------------------------
@tool
def catering_service_tool(query: str) -> str:
    """
    Return the name of the highest-rated catering option.

    Args:
        query: Free-text catering request. Currently ignored; the result is
            always the top-rated name from a fixed, hard-coded rating table.

    Returns:
        The name with the highest rating.
    """
    services = {
        "Srikanth": 4.9,
        "Nitish": 4.8,
        "Subrahmanya": 4.7,
        "Rahul": 4.6,
    }
    # Ratings are static, so the answer does not depend on `query`.
    return max(services, key=services.get)
# -----------------------------
# INDIAN MOVIE THEME TOOL
# -----------------------------
class IndianMovieThemeTool(Tool):
    """smolagents tool mapping a mood keyword to an Indian movie/series party theme."""

    name = "indian_movie_theme_generator"
    description = "Suggests Indian movie/series-based party themes."
    inputs = {
        "category": {
            "type": "string",
            "description": "Theme type (e.g., 'mass', 'romantic', 'gangster', 'pan-india').",
        }
    }
    output_type = "string"

    def forward(self, category: str):
        """Return the theme text for `category` (case-insensitive), else "Theme not found."."""
        catalogue = {
            "mass": "KGF Theme Night: Gold, black outfits, heavy bass music, and dramatic lighting.",
            "romantic": "YJHD Theme: Colorful decor, travel vibes, and chill Bollywood playlist.",
            "gangster": "Gangs of Wasseypur Theme: Rustic setup, desi snacks, and intense Bollywood tracks.",
            "pan-india": "RRR Theme: Traditional outfits, energetic Telugu/Hindi songs, and festive lighting.",
        }
        wanted = category.lower()
        if wanted in catalogue:
            return catalogue[wanted]
        return "Theme not found."
# -----------------------------
# PLAYLIST TOOL
# -----------------------------
@tool
def playlist_tool(theme: str) -> str:
    """
    Look up a short song playlist for a known movie party theme.

    Args:
        theme: Theme name or description; matched case-insensitively against
            the known themes KGF, YJHD, Gangs of Wasseypur and RRR.

    Returns:
        The playlist as a newline-separated string. If no known theme name
        occurs in the input, a dict with "error" and "reason" keys is
        returned instead.
    """
    playlists = {
        "KGF": ["Salaam Rocky Bhai", "Garbadhi", "Dheera Dheera"],
        "YJHD": ["Badtameez Dil", "Kabira", "Ilahi"],
        "Gangs of Wasseypur": ["Jiya Tu", "Hunter", "Keh Ke Lunga"],
        "RRR": ["Naatu Naatu", "Komuram Bheemudo", "Dosti"],
    }
    lowered_theme = theme.lower()  # loop-invariant; computed once
    for key, songs in playlists.items():
        if key.lower() in lowered_theme:
            return f"Playlist for {key}:\n" + "\n".join(songs)
    # NOTE: intentionally a dict despite the `-> str` hint — smart_playlist_resolver
    # detects a miss via isinstance(result, dict) and the "error" key.
    return {
        "error": "playlist_not_found",
        "reason": f"No playlist found for theme '{theme}'.",
    }
# -----------------------------
# SMART PLAYLIST RESOLVER
# -----------------------------
@tool
def smart_playlist_resolver(theme: str) -> str:
    """
    Resolve a playlist for a theme, proposing a web search when none is known.

    Args:
        theme: Theme name or description to find songs for.

    Returns:
        The playlist string from playlist_tool when the theme is known.
        Otherwise a dict whose "tool" entry names the web search tool
        ("web_search") and whose "query" entry is the search text to run.
    """
    result = playlist_tool(theme=theme)
    # playlist_tool reports a miss with a dict rather than raising.
    if isinstance(result, dict) and result.get("error") == "playlist_not_found":
        # "web_search" is the registered name of smolagents' WebSearchTool;
        # the previous "websearch" pointed the LLM at a non-existent tool.
        return {
            "tool": "web_search",
            "query": f"top 5 {theme} songs",
        }
    return result
# -----------------------------
# SHOPPING LIST TOOL
# -----------------------------
@tool
def shopping_list_tool(menu_type: str) -> str:
    """
    Return a fixed shopping list for a casual or formal party menu.

    Args:
        menu_type: Menu style, either 'casual' or 'formal' (case-insensitive,
            surrounding whitespace ignored).

    Returns:
        A multi-line shopping list, or a hint naming the accepted values
        when the menu type is not recognized.
    """
    shopping_lists = {
        "casual": (
            "Shopping List (Casual):\n"
            "- 2 kg Chicken\n"
            "- Biryani Rice\n"
            "- Curd\n"
            "- Masala Powders\n"
            "- Onions, Tomatoes\n"
            "- Soft Drinks"
        ),
        "formal": (
            "Shopping List (Formal):\n"
            "- 3 kg Chicken\n"
            "- Naan Ingredients\n"
            "- Biryani Rice\n"
            "- Butter & Cream\n"
            "- Dessert Items\n"
            "- Mocktail Mixes"
        ),
    }
    # Arguments come from an LLM and often vary in case/spacing.
    return shopping_lists.get(menu_type.strip().lower(), "Choose 'casual' or 'formal'.")
# -----------------------------
# CLEANING DUTY TOOL
# -----------------------------
@tool
def cleaning_rotation_tool(dummy: str) -> str:
    """
    Pick a roommate for today's cleaning duty.

    Args:
        dummy: Unused placeholder argument; any string (for example "today") is accepted.

    Returns:
        A sentence naming the randomly selected roommate.
    """
    people = ["Srikanth", "Nitish", "Subrahmanya", "Rahul"]
    # Uniform random pick on every call, so repeated calls may name different people.
    return f"Cleaning duty today: {random.choice(people)}"
# -----------------------------
# COOKING DUTY TOOL
# -----------------------------
@tool
def cooking_rotation_tool(dummy: str) -> str:
    """
    Pick a roommate to cook today.

    Args:
        dummy: Unused placeholder argument; any string (for example "today") is accepted.

    Returns:
        A sentence naming the randomly selected cook.
    """
    people = ["Srikanth", "Nitish", "Subrahmanya", "Rahul"]
    # Uniform random pick on every call, so repeated calls may name different people.
    return f"Today's cook: {random.choice(people)}"
# -----------------------------
# BUDGET TOOL
# -----------------------------
@tool
def budget_tool(total_cost: float, num_people: int = 4) -> str:
    """
    Split a shared cost evenly between people.

    Args:
        total_cost: Total amount to split, in euros.
        num_people: Number of people sharing the cost. Defaults to 4, matching
            the four roommates named elsewhere in this module.

    Returns:
        A line with the per-person share rounded to two decimals, or an
        explanatory message if num_people is not positive.
    """
    # Guard against ZeroDivisionError / nonsensical negative splits; a message
    # (rather than an exception) matches how the other tools report bad input.
    if num_people <= 0:
        return "Number of people must be a positive integer."
    per_person = total_cost / num_people
    return f"Each person pays: €{per_person:.2f}"
# -----------------------------
# AGENT SETUP
# -----------------------------
# Every tool the agent may call, in registration order: the two built-in web
# tools first, then the custom tools defined above.
_party_tools = [
    WebSearchTool(),
    VisitWebpageTool(),
    suggest_menu,
    catering_service_tool,
    IndianMovieThemeTool(),
    playlist_tool,
    smart_playlist_resolver,
    shopping_list_tool,
    cleaning_rotation_tool,
    cooking_rotation_tool,
    budget_tool,
]

# Module-level agent shared by run_agent(); capped at 10 reasoning steps.
agent = CodeAgent(
    tools=_party_tools,
    model=model,
    max_steps=10,
    verbosity_level=2,
)
def run_agent(query: str):
    """Run the module-level CodeAgent on `query` and return whatever `agent.run` produces."""
    answer = agent.run(query)
    return answer