# RAG-DietAssistant / main3.py
# Beable's picture
# Upload 6 files
# a144435 verified
# main.py
# Reads a multi-turn chat from the console → selects the most relevant tool cards via RAG →
# instructs the Groq model with a single "JSON-only" contract →
# if tool arguments are missing, the model asks a QUESTION in the form { "final": "<which fields are needed?>" } →
# once the arguments are complete, it returns { "action": "<tool>", "args": {...} }.
# NOTE: Existing logic is preserved; only new tool (generate_workout_plan) is integrated.
import os
import json
import numpy as np
from typing import Dict, Any, List
from sentence_transformers import SentenceTransformer
from groq import Client # Groq Python SDK
from diet_tool import generate_diet_plan, get_diet_recommendations, calculate_nutrition_info
# 1) Groq API key - on Hugging Face Spaces, configure it as a secret
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    # Fail fast at import time: nothing below can work without credentials.
    raise ValueError(
        "GROQ_API_KEY environment variable is required. "
        "Please set it in your Hugging Face Space secrets."
    )

# Shared Groq client used for the chat-completion requests in this module.
client = Client(api_key=GROQ_API_KEY)

# 2) Embedding model and tool cards (RAG)
embed_model = SentenceTransformer("all-MiniLM-L6-v2")
# (tool name, description) pairs. The description is the text that gets
# embedded for retrieval and is also shown to the model.
# NOTE(review): tool_param_schemas also declares "expand_diet_database", but it
# has no card here, so retrieval never surfaces it - confirm this is intended.
_TOOL_CARD_ENTRIES = (
    (
        "one_rm_calculator",
        "Calculate 1-rep-max (1RM) from a lifted weight (kg) and reps using the "
        "Wathan equation, and return JSON with the 1RM plus predicted weights for 1-10 reps.",
    ),
    (
        "get_user_profile",
        "Fetch the current user's profile summary.",
    ),
    (
        "analyze_exercise_video",
        "Analyze a given exercise video and return coaching feedback.",
    ),
    (
        "calculate_body_fat",
        "Calculates user's body fat percentage using U.S. Navy formula. Arguments: user_id (string), sex, height_cm, weight_kg, neck_cm, waist_cm, hip_cm.",
    ),
    (
        "upsert_profile",
        "Creates or updates user profile in database. Arguments: user_id, sex, height_cm, weight_kg, neck_cm, waist_cm, hip_cm.",
    ),
    (
        "generate_workout_plan",
        "Creates block-periodized, auto-regulated, VBT/HRV compatible, multi-disciplinary training plan (with user profile + goal + weekly days count).",
    ),
    (
        "generate_diet_plan",
        "Creates personalized diet plan based on user profile and goals (daily/weekly plan, calorie calculation, macronutrient distribution).",
    ),
    (
        "get_diet_recommendations",
        "Creates diet recommendations and nutrition advice for user profile (BMR, TDEE, macro targets).",
    ),
    (
        "calculate_nutrition_info",
        "Calculates nutritional values for a specific food (calories, protein, carbs, fat, fiber).",
    ),
)

# One {"name", "description"} dict per tool, in the order listed above.
tool_cards = [
    {"name": card_name, "description": card_description}
    for card_name, card_description in _TOOL_CARD_ENTRIES
]
# 2.b) Tool parameter schemas (to properly direct and validate the model)
def _string_enum(*choices: str) -> Dict[str, Any]:
    """Build a JSON-Schema string property restricted to ``choices``."""
    return {"type": "string", "enum": list(choices)}


def _string_array() -> Dict[str, Any]:
    """Build a JSON-Schema property describing an array of strings."""
    return {"type": "array", "items": {"type": "string"}}


def _bounded(kind: str, low: int, high: int) -> Dict[str, Any]:
    """Build a numeric JSON-Schema property of ``kind`` limited to [low, high]."""
    return {"type": kind, "minimum": low, "maximum": high}


def _body_measurement_properties(user_id: Dict[str, Any]) -> Dict[str, Any]:
    """Build the property map shared by the body-fat and profile-upsert tools."""
    props: Dict[str, Any] = {"user_id": user_id, "sex": {"type": "string"}}
    for field in ("height_cm", "weight_kg", "neck_cm", "waist_cm", "hip_cm"):
        props[field] = {"type": "number"}
    return props


def _diet_user_profile_schema() -> Dict[str, Any]:
    """Build the nested ``user_profile`` schema used by both diet tools."""
    return {
        "type": "object",
        "properties": {
            "age": _bounded("integer", 15, 100),
            "sex": _string_enum("male", "female"),
            "height_cm": _bounded("number", 100, 250),
            "weight_kg": _bounded("number", 30, 300),
            "activity_level": _string_enum("sedentary", "light", "moderate", "active", "very_active"),
            "goal": _string_enum("weight_loss", "muscle_gain", "maintenance", "keto", "mediterranean"),
            "dietary_restrictions": _string_array(),
            "allergies": _string_array(),
            "preferences": _string_array(),
        },
        "required": ["age", "sex", "height_cm", "weight_kg", "activity_level", "goal"],
    }


# Maps each tool name to the JSON Schema of its arguments. The schema is sent
# to the model verbatim and its "required" list drives the missing-field check.
tool_param_schemas: Dict[str, Dict[str, Any]] = {
    "one_rm_calculator": {
        "type": "object",
        "properties": {
            "weight_kg": {"type": "number", "description": "Weight lifted in kilograms."},
            "reps": {**_bounded("integer", 1, 10), "description": "Number of repetitions (1-10)."},
            # Nullable: the exercise may be null/omitted when the user gave none.
            "exercise": {
                "oneOf": [
                    _string_enum(
                        "Bench Press", "Squat", "Deadlift", "Overhead Press",
                        "Barbell Row", "Weighted Dips", "Weighted Pull Ups",
                    ),
                    {"type": "null"},
                ],
                "description": (
                    "Exercise name. Leave null or omit if the user did not specify. Only Bench Press, Squat, "
                    "Deadlift, Overhead Press, Barbell Row, Weighted Dips, and Weighted Pull Ups are supported. "
                    "You can translate the user-provided exercise name if needed, e.g. Barfiks -> Weighted Pull Ups, Bench -> Bench Press."
                ),
            },
        },
        "required": ["weight_kg", "reps"],
    },
    "get_user_profile": {"type": "object", "properties": {}, "required": []},
    "analyze_exercise_video": {
        "type": "object",
        "properties": {
            "videoUrl": {"type": "string", "description": "Publicly accessible video URL"},
            "exercise": {"type": "string", "description": "Exercise name, e.g., 'Squat'"},
        },
        "required": ["videoUrl", "exercise"],
    },
    "calculate_body_fat": {
        "type": "object",
        "properties": _body_measurement_properties(
            {"type": "string", "description": "User identifier"}
        ),
        "required": ["user_id", "height_cm", "weight_kg", "neck_cm", "waist_cm", "hip_cm", "sex"],
    },
    "upsert_profile": {
        "type": "object",
        "properties": _body_measurement_properties({"type": "string"}),
        # hip_cm is optional for this tool
        "required": ["user_id", "sex", "height_cm", "weight_kg", "neck_cm", "waist_cm"],
    },
    "generate_workout_plan": {
        "type": "object",
        "properties": {
            "user_profile": {
                "type": "object",
                "description": "User profile/summary; may include injuries, recent_1RM etc.",
            },
            "goal": _string_enum("hypertrophy", "strength", "fat_loss", "general_fitness"),
            "days_per_week": _bounded("integer", 1, 7),
            "sport": _string_enum(
                "general", "powerlifting", "olympic_weightlifting", "crossfit",
                "bodybuilding", "endurance", "strongman", "calisthenics", "combat_sports",
            ),
            "training_level": _string_enum("novice", "intermediate", "advanced"),
            "sex": _string_enum("male", "female"),
            "cycle_phase": _string_enum("follicular", "luteal", "na"),
            "weekly_volume_pref": _string_enum("low", "moderate", "high"),
            "block_type": _string_enum("accumulation", "intensification", "peaking", "deload"),
            "mesocycle_length": {"type": "integer"},
            "equipment": _string_array(),
            "weak_points": _string_array(),
            "sticking_points": {"type": "object"},
            "auto_accessories": {"type": "boolean"},
            "vbt_available": {"type": "boolean"},
            "readiness_score": _bounded("number", 0, 10),
            "constraints": {"type": "object"},
            "swap_exercise_if_unavailable": {"type": "boolean"},
        },
        "required": ["user_profile", "goal", "days_per_week"],
    },
    "generate_diet_plan": {
        "type": "object",
        "properties": {
            "user_profile": _diet_user_profile_schema(),
            "plan_type": {**_string_enum("daily", "weekly"), "default": "daily"},
        },
        "required": ["user_profile"],
    },
    "get_diet_recommendations": {
        "type": "object",
        "properties": {"user_profile": _diet_user_profile_schema()},
        "required": ["user_profile"],
    },
    "calculate_nutrition_info": {
        "type": "object",
        "properties": {
            "food_name": {"type": "string", "description": "Food name"},
            "portion_grams": {**_bounded("number", 1, 1000), "description": "Portion amount (grams)"},
        },
        "required": ["food_name", "portion_grams"],
    },
    "expand_diet_database": {
        "type": "object",
        "properties": {
            "max_per_category": {
                **_bounded("integer", 1, 20),
                "description": "Maximum number of foods per category",
                "default": 5,
            },
            "output_filename": {
                "type": "string",
                "description": "Output filename (saved to outputs/ folder)",
                "default": "usda_foods_database.json",
            },
        },
        "required": [],
    },
}
# Embed every tool description once at import time; retrieve_tools compares
# query embeddings against these vectors.
descs = [card["description"] for card in tool_cards]
embs = embed_model.encode(descs)
# 3) Simple cosine-similarity top-K retrieval (RAG)
def retrieve_tools(query: str, k: int = 5):
    """Return the ``k`` tool cards whose descriptions best match ``query``.

    The query is embedded with ``embed_model`` and scored against the
    precomputed description embeddings ``embs`` using cosine similarity.
    Cards are returned best match first.
    """
    query_vec = embed_model.encode([query])[0]
    dot_products = embs @ query_vec
    # The small epsilon keeps the division defined if a norm is zero.
    norm_products = np.linalg.norm(embs, axis=1) * np.linalg.norm(query_vec) + 1e-10
    similarities = dot_products / norm_products
    # Negate so that argsort's ascending order yields highest similarity first.
    ranked = np.argsort(-similarities)
    return [tool_cards[i] for i in ranked[:k]]
# 3.b) Find missing fields according to schema
def _missing_required_in(schema: Dict[str, Any], args: Dict[str, Any], prefix: str = "") -> List[str]:
    """Collect missing required keys of ``args`` under ``schema``, recursively.

    Nested keys are reported as dotted paths (e.g. ``user_profile.age``).
    """
    properties = schema.get("properties", {})
    missing: List[str] = []
    for key in schema.get("required", []):
        path = f"{prefix}{key}"
        if key not in args or args[key] in ("", None):
            missing.append(path)
            continue
        # Descend only into object-typed properties that actually hold a dict;
        # type validation of other values is out of scope here.
        sub_schema = properties.get(key)
        value = args[key]
        if (
            isinstance(sub_schema, dict)
            and sub_schema.get("type") == "object"
            and isinstance(value, dict)
        ):
            missing.extend(_missing_required_in(sub_schema, value, prefix=f"{path}."))
    return missing


def find_missing_required(tool_name: str, args: Dict[str, Any]) -> List[str]:
    """Return the required arguments of ``tool_name`` that ``args`` lacks.

    A field counts as missing when it is absent, ``None`` or an empty string.
    Required fields of nested object schemas (such as the diet tools'
    ``user_profile``) are checked too and reported as dotted paths, so an
    empty ``user_profile`` no longer passes the check.

    Args:
        tool_name: Key into ``tool_param_schemas``.
        args: Arguments proposed by the model for that tool.

    Returns:
        Missing field names in schema order; an empty list when nothing is
        missing or when the tool has no schema.
    """
    schema = tool_param_schemas.get(tool_name)
    if not schema:
        return []
    return _missing_required_in(schema, args)
# 3.c) Diet tool functions
def _write_json_output(path: str, payload: Dict[str, Any]) -> None:
    """Write ``payload`` to ``path`` as indented UTF-8 JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)


def _is_plain_filename(name: Any) -> bool:
    """Return True when ``name`` is a bare file name with no directory part."""
    if not isinstance(name, str) or name in ("", ".", ".."):
        return False
    # Reject both separator styles regardless of the host platform.
    if "/" in name or "\\" in name:
        return False
    return os.path.basename(name) == name


def execute_diet_tool(action: str, args: Dict[str, Any]) -> Dict[str, Any]:
    """Run the diet tool named ``action`` and save its output under ``outputs/``.

    Args:
        action: One of ``generate_diet_plan``, ``get_diet_recommendations``,
            ``calculate_nutrition_info`` or ``expand_diet_database``.
        args: Tool arguments, as produced by the router.

    Returns:
        ``{"success": True, "result": ..., "message": ...}`` on success, or
        ``{"success": False, "error": ...}`` for an unknown action, an invalid
        ``output_filename``, or any exception raised while running the tool.
        Exceptions are reported in the returned dict rather than propagated.
    """
    try:
        # Create outputs folder (if it doesn't exist)
        os.makedirs("outputs", exist_ok=True)
        # Timestamp used both in file names and in the saved payload
        from datetime import datetime
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        if action == "generate_diet_plan":
            user_profile = args["user_profile"]
            plan_type = args.get("plan_type", "daily")
            result = generate_diet_plan(user_profile, plan_type)
            filename = f"outputs/diet_plan_{plan_type}_{timestamp}.json"
            _write_json_output(filename, {
                "action": action,
                "user_profile": user_profile,
                "plan_type": plan_type,
                "result": result,
                "timestamp": timestamp,
            })
            return {
                "success": True,
                "result": result,
                "message": f"{plan_type.capitalize()} diet plan successfully created and saved to {filename} file.",
            }

        if action == "get_diet_recommendations":
            user_profile = args["user_profile"]
            result = get_diet_recommendations(user_profile)
            filename = f"outputs/diet_recommendations_{timestamp}.json"
            _write_json_output(filename, {
                "action": action,
                "user_profile": user_profile,
                "result": result,
                "timestamp": timestamp,
            })
            return {
                "success": True,
                "result": result,
                "message": f"Diet recommendations successfully created and saved to {filename} file.",
            }

        if action == "calculate_nutrition_info":
            food_name = args["food_name"]
            portion_grams = args["portion_grams"]
            result = calculate_nutrition_info(food_name, portion_grams)
            filename = f"outputs/nutrition_info_{timestamp}.json"
            _write_json_output(filename, {
                "action": action,
                "food_name": food_name,
                "portion_grams": portion_grams,
                "result": result,
                "timestamp": timestamp,
            })
            return {
                "success": True,
                "result": result,
                "message": f"Nutrition calculation completed and saved to {filename} file.",
            }

        if action == "expand_diet_database":
            max_per_category = args.get("max_per_category", 5)
            output_filename = args.get("output_filename", "usda_foods_database.json")
            # The file name comes from model output, so refuse anything that
            # could escape the outputs/ folder (e.g. "../x" or absolute paths).
            if not _is_plain_filename(output_filename):
                return {
                    "success": False,
                    "error": (
                        f"Invalid output_filename {output_filename!r}: "
                        "must be a plain file name without directory components"
                    ),
                }
            # Import and run expand_diet_data.py
            try:
                from expand_diet_data import expand_diet_data_from_api
                output_path = os.path.join("outputs", output_filename)
                expand_diet_data_from_api(output_path, max_per_category)
                return {
                    "success": True,
                    "result": {
                        "max_per_category": max_per_category,
                        "output_path": output_path,
                    },
                    "message": f"USDA nutrition database successfully created: {output_path}",
                }
            except Exception as e:
                return {
                    "success": False,
                    "error": f"Error while expanding database: {str(e)}",
                }

        return {
            "success": False,
            "error": f"Unknown diet tool: {action}",
        }
    except Exception as e:
        # Boundary handler: report the failure to the caller instead of
        # crashing the chat loop.
        return {
            "success": False,
            "error": f"Error while running diet tool: {str(e)}",
        }
# 4) RAG + Groq tool-call router (single turn)
def _build_messages(user_query: str, cards: List[Dict[str, str]], chat_history: List[Dict[str, str]]) -> List[Dict[str, str]]:
# Explicitly provide tool + schema information to the model
tools_lines = []
for c in cards:
name = c["name"]
schema = tool_param_schemas.get(name, {})
tools_lines.append(
f"{name}: {c['description']}\n"
f"PARAMETERS(JSON Schema): {json.dumps(schema, ensure_ascii=False)}"
)
tools_block = "\n\n".join(tools_lines)
system_prompt = (
"You are an assistant that only returns JSON. Do not write any other explanations.\n"
"IF TOOL IS NEEDED: produce exactly this schema → {\"action\":\"<tool_name>\",\"args\":{...}}\n"
"IF TOOL IS NOT NEEDED: {\"final\":\"...\"}\n"
"IF ARGUMENTS ARE MISSING: never make up; {\"final\":\"<ask user for required fields in ENGLISH, short and clear>\"}\n"
"Do not go outside JSON, do not add text before/after.\n"
"Tool name must be one from the list and argument names must exactly match PARAMETER schema."
)
# Clean chat history to remove unsupported fields like 'metadata'
cleaned_history = []
for msg in chat_history:
if isinstance(msg, dict) and "role" in msg and "content" in msg:
cleaned_msg = {"role": msg["role"], "content": msg["content"]}
cleaned_history.append(cleaned_msg)
messages = [
{"role": "system", "content": system_prompt},
{"role": "system", "content": tools_block},
*cleaned_history, # previous turns (cleaned)
{"role": "user", "content": user_query},
]
return messages
def run_rag_tool_router(user_query: str, chat_history: List[Dict[str, str]]):
    """Route one user turn: retrieve tools, query the model, validate its reply.

    Args:
        user_query: The latest user message.
        chat_history: Previous turns as ``{"role", "content"}`` dicts.

    Returns:
        One of the following dicts:
        - ``{"final": ...}``: a direct answer or a question asking for the
          missing fields.
        - ``{"action": <tool>, "args": {...}}``: a tool call whose required
          fields are present.
        - ``{"error": <code>, "raw": ...}`` with code ``invalid_json``,
          ``invalid_shape``, ``invalid_types`` or ``unknown_tool``.
    """
    cards = retrieve_tools(user_query, k=5)
    messages = _build_messages(user_query, cards, chat_history)
    resp = client.chat.completions.create(
        model="openai/gpt-oss-120b",
        messages=messages,
        max_tokens=240,
        temperature=0.0,
    )
    # Guard against a missing (None) content so an empty reply is reported as
    # invalid_json instead of raising AttributeError on .strip().
    text = (resp.choices[0].message.content or "").strip()
    try:
        obj = json.loads(text)
    except json.JSONDecodeError:
        return {"error": "invalid_json", "raw": text}

    # If it's final (question/message)
    if isinstance(obj, dict) and "final" in obj:
        return obj

    # Otherwise we expect action+args
    if not isinstance(obj, dict) or "action" not in obj or "args" not in obj:
        return {"error": "invalid_shape", "raw": obj}
    action = obj["action"]
    args = obj["args"]
    if not isinstance(action, str) or not isinstance(args, dict):
        return {"error": "invalid_types", "raw": obj}

    # Enforce the prompt contract ("Tool name must be one from the list"): a
    # name without a schema would otherwise skip the missing-field check.
    if action not in tool_param_schemas:
        return {"error": "unknown_tool", "raw": obj}

    # Missing field check (additional safety - model should have already asked)
    missing = find_missing_required(action, args)
    if missing:
        need = ", ".join(missing)
        return {"final": f"Please provide the following fields: {need}"}

    # Type validation/conversion of the arguments could be added here.
    return {"action": action, "args": args}
# 5) Multi-turn chat loop
if __name__ == "__main__":
    chat_history: List[Dict[str, str]] = []
    # Actions that execute_diet_tool can run; other tools are not wired up yet.
    diet_actions = (
        "generate_diet_plan",
        "get_diet_recommendations",
        "calculate_nutrition_info",
        "expand_diet_database",
    )
    print("Multi-turn RAG+Tool Router. Leave empty line to exit.\n")
    while True:
        user_msg = input("Enter your prompt: ").strip()
        if user_msg == "":
            break
        result = run_rag_tool_router(user_msg, chat_history)

        # Update history
        chat_history.append({"role": "user", "content": user_msg})

        # Every outcome (question, tool call or error) is echoed and recorded
        # as the assistant turn.
        print(json.dumps(result, ensure_ascii=False, indent=2))
        chat_history.append({"role": "assistant", "content": json.dumps(result, ensure_ascii=False)})

        # Only a tool call needs further work; "final" takes precedence.
        is_tool_call = isinstance(result, dict) and "final" not in result and "action" in result
        if not is_tool_call:
            continue

        # Diet tool check and execution
        action = result["action"]
        args = result.get("args", {})
        if action not in diet_actions:
            # Placeholder for other tools
            print("(Implementation needed for other tools)")
            continue
        diet_result = execute_diet_tool(action, args)
        print("\nDiet Tool Result:")
        print(json.dumps(diet_result, ensure_ascii=False, indent=2))
        chat_history.append({"role": "assistant", "content": json.dumps(diet_result, ensure_ascii=False)})