Spaces:
Sleeping
Sleeping
| # main.py | |
| # Takes multi-turn chat from console → Selects most relevant tool cards with RAG → | |
| # Directs Groq model with single "JSON-only" contract → | |
| # If model args are missing, generates QUESTION in format { "final": "<which fields are needed?>" } → | |
| # When args are complete, returns { "action": "<tool>", "args": {...} }. | |
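# NOTE: Existing routing logic is preserved; the generate_workout_plan tool and the diet tools (generate_diet_plan, get_diet_recommendations, calculate_nutrition_info, expand_diet_database) are integrated. Only the diet tools are executed locally.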
| import os | |
| import json | |
| import numpy as np | |
| from typing import Dict, Any, List | |
| from sentence_transformers import SentenceTransformer | |
| from groq import Client # Groq Python SDK | |
| from diet_tool import generate_diet_plan, get_diet_recommendations, calculate_nutrition_info | |
# 1) Groq API key. On Hugging Face Spaces, provide it as a Space secret.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if GROQ_API_KEY is None or GROQ_API_KEY == "":
    # Fail fast at import time: the router cannot work without credentials.
    raise ValueError(
        "GROQ_API_KEY environment variable is required. Please set it in your Hugging Face Space secrets."
    )
# Shared Groq client used for every chat-completion request.
client = Client(api_key=GROQ_API_KEY)
# 2) Embedding model (RAG): turns tool-card descriptions and user queries
#    into vectors so tools can be ranked by similarity.
_EMBED_MODEL_NAME = "all-MiniLM-L6-v2"
embed_model = SentenceTransformer(_EMBED_MODEL_NAME)
# Tool cards used for retrieval: each "description" is embedded and compared
# against the user query; each "name" must match a key of tool_param_schemas
# so the model sees the tool's parameter schema in the prompt.
tool_cards = [
    {
        "name": "one_rm_calculator",
        "description": (
            "Calculate 1-rep-max (1RM) from a lifted weight (kg) and reps using the "
            "Wathan equation, and return JSON with the 1RM plus predicted weights for 1-10 reps."
        )
    },
    {
        "name": "get_user_profile",
        "description": "Fetch the current user's profile summary."
    },
    {
        "name": "analyze_exercise_video",
        "description": "Analyze a given exercise video and return coaching feedback."
    },
    {
        "name": "calculate_body_fat",
        "description": "Calculates user's body fat percentage using U.S. Navy formula. Arguments: user_id (string), sex, height_cm, weight_kg, neck_cm, waist_cm, hip_cm."
    },
    {
        "name": "upsert_profile",
        "description": "Creates or updates user profile in database. Arguments: user_id, sex, height_cm, weight_kg, neck_cm, waist_cm, hip_cm."
    },
    {
        "name": "generate_workout_plan",
        "description": "Creates block-periodized, auto-regulated, VBT/HRV compatible, multi-disciplinary training plan (with user profile + goal + weekly days count)."
    },
    {
        "name": "generate_diet_plan",
        "description": "Creates personalized diet plan based on user profile and goals (daily/weekly plan, calorie calculation, macronutrient distribution)."
    },
    {
        "name": "get_diet_recommendations",
        "description": "Creates diet recommendations and nutrition advice for user profile (BMR, TDEE, macro targets)."
    },
    {
        "name": "calculate_nutrition_info",
        "description": "Calculates nutritional values for a specific food (calories, protein, carbs, fat, fiber)."
    },
    # expand_diet_database has a parameter schema and is executed by the main
    # loop; without a card here RAG retrieval could never surface it.
    {
        "name": "expand_diet_database",
        "description": (
            "Builds or expands the USDA food nutrition database from the API and saves it as a JSON file "
            "in the outputs/ folder. Arguments: max_per_category, output_filename."
        )
    },
]
# 2.b) Tool parameter schemas (JSON Schema per tool).
# Each schema is serialized into the prompt so the model uses exact argument
# names, and its "required" list drives the missing-field check.

# Numeric body measurements shared by the body-fat and profile tools
# (kept in this order because dict order is preserved in the prompt).
_BODY_MEASUREMENT_FIELDS = ("height_cm", "weight_kg", "neck_cm", "waist_cm", "hip_cm")


def _diet_user_profile_schema() -> Dict[str, Any]:
    """Return a fresh copy of the user-profile sub-schema shared by the diet tools."""
    return {
        "type": "object",
        "properties": {
            "age": {"type": "integer", "minimum": 15, "maximum": 100},
            "sex": {"type": "string", "enum": ["male", "female"]},
            "height_cm": {"type": "number", "minimum": 100, "maximum": 250},
            "weight_kg": {"type": "number", "minimum": 30, "maximum": 300},
            "activity_level": {
                "type": "string",
                "enum": ["sedentary", "light", "moderate", "active", "very_active"],
            },
            "goal": {
                "type": "string",
                "enum": ["weight_loss", "muscle_gain", "maintenance", "keto", "mediterranean"],
            },
            "dietary_restrictions": {"type": "array", "items": {"type": "string"}},
            "allergies": {"type": "array", "items": {"type": "string"}},
            "preferences": {"type": "array", "items": {"type": "string"}},
        },
        "required": ["age", "sex", "height_cm", "weight_kg", "activity_level", "goal"],
    }


tool_param_schemas: Dict[str, Dict[str, Any]] = {
    "one_rm_calculator": {
        "type": "object",
        "properties": {
            "weight_kg": {
                "type": "number",
                "description": "Weight lifted in kilograms.",
            },
            "reps": {
                "type": "integer",
                "minimum": 1,
                "maximum": 10,
                "description": "Number of repetitions (1-10).",
            },
            # Nullable: either one of the supported lifts or null.
            "exercise": {
                "oneOf": [
                    {
                        "type": "string",
                        "enum": [
                            "Bench Press", "Squat", "Deadlift", "Overhead Press",
                            "Barbell Row", "Weighted Dips", "Weighted Pull Ups",
                        ],
                    },
                    {"type": "null"},
                ],
                "description": (
                    "Exercise name. Leave null or omit if the user did not specify. Only Bench Press, Squat, "
                    "Deadlift, Overhead Press, Barbell Row, Weighted Dips, and Weighted Pull Ups are supported. "
                    "You can translate the user-provided exercise name if needed, e.g. Barfiks -> Weighted Pull Ups, Bench -> Bench Press."
                ),
            },
        },
        "required": ["weight_kg", "reps"],
    },
    "get_user_profile": {
        "type": "object",
        "properties": {},
        "required": [],
    },
    "analyze_exercise_video": {
        "type": "object",
        "properties": {
            "videoUrl": {"type": "string", "description": "Publicly accessible video URL"},
            "exercise": {"type": "string", "description": "Exercise name, e.g., 'Squat'"},
        },
        "required": ["videoUrl", "exercise"],
    },
    "calculate_body_fat": {
        "type": "object",
        "properties": {
            "user_id": {"type": "string", "description": "User identifier"},
            "sex": {"type": "string"},
            **{field: {"type": "number"} for field in _BODY_MEASUREMENT_FIELDS},
        },
        "required": ["user_id", "height_cm", "weight_kg", "neck_cm", "waist_cm", "hip_cm", "sex"],
    },
    "upsert_profile": {
        "type": "object",
        "properties": {
            "user_id": {"type": "string"},
            "sex": {"type": "string"},
            **{field: {"type": "number"} for field in _BODY_MEASUREMENT_FIELDS},
        },
        # hip_cm is intentionally optional here.
        "required": ["user_id", "sex", "height_cm", "weight_kg", "neck_cm", "waist_cm"],
    },
    "generate_workout_plan": {
        "type": "object",
        "properties": {
            "user_profile": {
                "type": "object",
                "description": "User profile/summary; may include injuries, recent_1RM etc.",
            },
            "goal": {"type": "string", "enum": ["hypertrophy", "strength", "fat_loss", "general_fitness"]},
            "days_per_week": {"type": "integer", "minimum": 1, "maximum": 7},
            "sport": {
                "type": "string",
                "enum": [
                    "general", "powerlifting", "olympic_weightlifting", "crossfit",
                    "bodybuilding", "endurance", "strongman", "calisthenics", "combat_sports",
                ],
            },
            "training_level": {"type": "string", "enum": ["novice", "intermediate", "advanced"]},
            "sex": {"type": "string", "enum": ["male", "female"]},
            "cycle_phase": {"type": "string", "enum": ["follicular", "luteal", "na"]},
            "weekly_volume_pref": {"type": "string", "enum": ["low", "moderate", "high"]},
            "block_type": {"type": "string", "enum": ["accumulation", "intensification", "peaking", "deload"]},
            "mesocycle_length": {"type": "integer"},
            "equipment": {"type": "array", "items": {"type": "string"}},
            "weak_points": {"type": "array", "items": {"type": "string"}},
            "sticking_points": {"type": "object"},
            "auto_accessories": {"type": "boolean"},
            "vbt_available": {"type": "boolean"},
            "readiness_score": {"type": "number", "minimum": 0, "maximum": 10},
            "constraints": {"type": "object"},
            "swap_exercise_if_unavailable": {"type": "boolean"},
        },
        "required": ["user_profile", "goal", "days_per_week"],
    },
    "generate_diet_plan": {
        "type": "object",
        "properties": {
            "user_profile": _diet_user_profile_schema(),
            "plan_type": {"type": "string", "enum": ["daily", "weekly"], "default": "daily"},
        },
        "required": ["user_profile"],
    },
    "get_diet_recommendations": {
        "type": "object",
        "properties": {
            "user_profile": _diet_user_profile_schema(),
        },
        "required": ["user_profile"],
    },
    "calculate_nutrition_info": {
        "type": "object",
        "properties": {
            "food_name": {"type": "string", "description": "Food name"},
            "portion_grams": {
                "type": "number",
                "minimum": 1,
                "maximum": 1000,
                "description": "Portion amount (grams)",
            },
        },
        "required": ["food_name", "portion_grams"],
    },
    "expand_diet_database": {
        "type": "object",
        "properties": {
            "max_per_category": {
                "type": "integer",
                "minimum": 1,
                "maximum": 20,
                "description": "Maximum number of foods per category",
                "default": 5,
            },
            "output_filename": {
                "type": "string",
                "description": "Output filename (saved to outputs/ folder)",
                "default": "usda_foods_database.json",
            },
        },
        "required": [],
    },
}
# One embedding row per tool card, in tool_cards order, so a row index from
# the similarity ranking maps straight back to its card.
descs = [c["description"] for c in tool_cards]
embs = embed_model.encode(descs)


# 3) Simple cosine-similarity top-K retrieval (RAG)
def retrieve_tools(query: str, k: int = 5):
    """Return up to ``k`` tool cards whose descriptions best match ``query``.

    Similarity is the cosine similarity between the query embedding and each
    pre-computed description embedding in ``embs``.

    Args:
        query: Free-text user message.
        k: Maximum number of cards to return. Values <= 0 yield an empty list.

    Returns:
        List of tool-card dicts, most similar first.
    """
    if k <= 0:
        # Without this guard a negative k slices as [:-n] and returns nearly
        # every card instead of none.
        return []
    q_emb = embed_model.encode([query])[0]
    # The epsilon avoids division by zero for a degenerate all-zero vector.
    sims = (embs @ q_emb) / (np.linalg.norm(embs, axis=1) * np.linalg.norm(q_emb) + 1e-10)
    idxs = np.argsort(-sims)[:k]
    return [tool_cards[i] for i in idxs]
# 3.b) Find missing fields according to schema
def find_missing_required(
    tool_name: str,
    args: Dict[str, Any],
    *,
    schemas: "Dict[str, Dict[str, Any]] | None" = None,
) -> List[str]:
    """Return the required argument names that are absent or empty in ``args``.

    Required fields of nested object parameters (e.g. the ``user_profile`` of
    the diet tools) are checked too and reported with a dotted path such as
    ``"user_profile.age"``.

    A value counts as missing when the key is absent, the value is ``None``,
    or it is a string that is empty or whitespace-only. Values such as ``0``,
    ``False`` or ``[]`` count as provided.

    Args:
        tool_name: Tool whose schema is used. Unknown tools yield ``[]``.
        args: Arguments proposed by the model.
        schemas: Optional schema registry; defaults to ``tool_param_schemas``.

    Returns:
        Missing field paths, in the order of the schemas' ``required`` lists.
    """
    registry = tool_param_schemas if schemas is None else schemas
    schema = registry.get(tool_name)
    if not schema:
        return []

    missing: List[str] = []

    def _walk(obj_schema: Dict[str, Any], values: Dict[str, Any], prefix: str) -> None:
        properties = obj_schema.get("properties", {})
        for key in obj_schema.get("required", []):
            path = f"{prefix}{key}"
            value = values.get(key)
            if value is None or (isinstance(value, str) and not value.strip()):
                missing.append(path)
                continue
            sub_schema = properties.get(key, {})
            # Descend into nested objects so e.g. an empty user_profile is
            # reported field by field instead of passing validation.
            if isinstance(value, dict) and sub_schema.get("type") == "object":
                _walk(sub_schema, value, f"{path}.")

    _walk(schema, args, "")
    return missing
# 3.c) Diet tool functions
def _save_diet_output(filename: str, payload: Dict[str, Any]) -> None:
    """Write ``payload`` to ``filename`` as indented UTF-8 JSON (non-ASCII kept as-is)."""
    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)


def _safe_filename_part(value: Any) -> str:
    """Return ``value`` as a string that is safe to embed in a file name.

    Every character other than letters, digits, '-' and '_' becomes '_', so a
    model-supplied value cannot inject path separators such as '../'.
    """
    return "".join(ch if ch.isalnum() or ch in "-_" else "_" for ch in str(value))


def execute_diet_tool(action: str, args: Dict[str, Any]) -> Dict[str, Any]:
    """Execute a diet tool and save its output as JSON under the ``outputs`` folder.

    Args:
        action: One of "generate_diet_plan", "get_diet_recommendations",
            "calculate_nutrition_info" or "expand_diet_database".
        args: Tool arguments as produced by the router.

    Returns:
        ``{"success": True, "result": ..., "message": ...}`` on success, or
        ``{"success": False, "error": ...}``. Any ``Exception`` raised while
        running the tool (including a missing required key) is caught and
        reported this way instead of propagating.
    """
    try:
        # Create outputs folder (if it doesn't exist)
        os.makedirs("outputs", exist_ok=True)
        from datetime import datetime
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        if action == "generate_diet_plan":
            user_profile = args["user_profile"]
            # `or` also covers an explicit null/empty value from the model.
            plan_type = args.get("plan_type") or "daily"
            result = generate_diet_plan(user_profile, plan_type)
            # plan_type comes from the model, so sanitize it before using it in a path.
            filename = f"outputs/diet_plan_{_safe_filename_part(plan_type)}_{timestamp}.json"
            _save_diet_output(filename, {
                "action": action,
                "user_profile": user_profile,
                "plan_type": plan_type,
                "result": result,
                "timestamp": timestamp,
            })
            return {
                "success": True,
                "result": result,
                "message": f"{str(plan_type).capitalize()} diet plan successfully created and saved to {filename} file."
            }

        if action == "get_diet_recommendations":
            user_profile = args["user_profile"]
            result = get_diet_recommendations(user_profile)
            filename = f"outputs/diet_recommendations_{timestamp}.json"
            _save_diet_output(filename, {
                "action": action,
                "user_profile": user_profile,
                "result": result,
                "timestamp": timestamp,
            })
            return {
                "success": True,
                "result": result,
                "message": f"Diet recommendations successfully created and saved to {filename} file."
            }

        if action == "calculate_nutrition_info":
            food_name = args["food_name"]
            portion_grams = args["portion_grams"]
            result = calculate_nutrition_info(food_name, portion_grams)
            filename = f"outputs/nutrition_info_{timestamp}.json"
            _save_diet_output(filename, {
                "action": action,
                "food_name": food_name,
                "portion_grams": portion_grams,
                "result": result,
                "timestamp": timestamp,
            })
            return {
                "success": True,
                "result": result,
                "message": f"Nutrition calculation completed and saved to {filename} file."
            }

        if action == "expand_diet_database":
            max_per_category = args.get("max_per_category", 5)
            default_name = "usda_foods_database.json"
            # Keep only the final path component: the name is model/user
            # supplied, and os.path.join would otherwise follow "../" or drop
            # "outputs" entirely for an absolute path.
            output_filename = os.path.basename(str(args.get("output_filename") or default_name))
            if output_filename in ("", ".", ".."):
                output_filename = default_name
            # Import and run expand_diet_data.py
            try:
                from expand_diet_data import expand_diet_data_from_api
                output_path = os.path.join("outputs", output_filename)
                expand_diet_data_from_api(output_path, max_per_category)
                return {
                    "success": True,
                    "result": {
                        "max_per_category": max_per_category,
                        "output_path": output_path
                    },
                    "message": f"USDA nutrition database successfully created: {output_path}"
                }
            except Exception as e:
                return {
                    "success": False,
                    "error": f"Error while expanding database: {str(e)}"
                }

        return {
            "success": False,
            "error": f"Unknown diet tool: {action}"
        }
    except Exception as e:
        return {
            "success": False,
            "error": f"Error while running diet tool: {str(e)}"
        }
# 4) RAG + Groq tool-call router (single turn)
def _build_messages(user_query: str, cards: List[Dict[str, str]], chat_history: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Assemble the chat-completion message list for one routing turn.

    Order: the JSON-only system contract, a second system message listing the
    retrieved tools with their JSON parameter schemas, the prior conversation
    (reduced to role/content), and finally the new user query.
    """
    # One entry per retrieved tool: its description plus its parameter schema,
    # so the model knows the exact argument names it must emit.
    tools_block = "\n\n".join(
        f"{card['name']}: {card['description']}\n"
        f"PARAMETERS(JSON Schema): {json.dumps(tool_param_schemas.get(card['name'], {}), ensure_ascii=False)}"
        for card in cards
    )

    contract_lines = [
        "You are an assistant that only returns JSON. Do not write any other explanations.\n",
        "IF TOOL IS NEEDED: produce exactly this schema → {\"action\":\"<tool_name>\",\"args\":{...}}\n",
        "IF TOOL IS NOT NEEDED: {\"final\":\"...\"}\n",
        "IF ARGUMENTS ARE MISSING: never make up; {\"final\":\"<ask user for required fields in ENGLISH, short and clear>\"}\n",
        "Do not go outside JSON, do not add text before/after.\n",
        "Tool name must be one from the list and argument names must exactly match PARAMETER schema.",
    ]
    system_prompt = "".join(contract_lines)

    # Keep only role/content so unsupported fields such as 'metadata' are dropped;
    # entries lacking either key (or not dicts at all) are skipped.
    history = [
        {"role": entry["role"], "content": entry["content"]}
        for entry in chat_history
        if isinstance(entry, dict) and "role" in entry and "content" in entry
    ]

    return [
        {"role": "system", "content": system_prompt},
        {"role": "system", "content": tools_block},
        *history,
        {"role": "user", "content": user_query},
    ]
def _strip_code_fence(text: str) -> str:
    """Remove a surrounding Markdown code fence (e.g. ```json ... ```) if present."""
    lines = text.strip().splitlines()
    if len(lines) >= 2 and lines[0].startswith("```") and lines[-1].strip() == "```":
        return "\n".join(lines[1:-1]).strip()
    return text


def run_rag_tool_router(user_query: str, chat_history: List[Dict[str, str]]):
    """Route one user turn to either a tool call or a final message.

    Retrieves the most relevant tool cards, asks the Groq model for a JSON
    reply under the JSON-only contract, then validates it.

    Returns one of:
        {"final": str}
            An answer, or a question asking for missing required fields.
        {"action": str, "args": dict}
            A tool call whose name has a schema and whose required fields are present.
        {"error": code, "raw": ...}
            code is one of "invalid_json", "invalid_shape", "invalid_types" or
            "unknown_tool"; "raw" holds the offending model output.
    """
    cards = retrieve_tools(user_query, k=5)
    messages = _build_messages(user_query, cards, chat_history)
    resp = client.chat.completions.create(
        model="openai/gpt-oss-120b",
        messages=messages,
        max_tokens=240,
        temperature=0.0
    )
    # message.content may be None (e.g. an empty completion); treat it as ""
    # so it is reported as invalid_json instead of raising AttributeError.
    text = (resp.choices[0].message.content or "").strip()
    try:
        # The model may wrap its JSON in a Markdown fence despite the contract.
        obj = json.loads(_strip_code_fence(text))
    except json.JSONDecodeError:
        return {"error": "invalid_json", "raw": text}
    # If it's final (question/message)
    if isinstance(obj, dict) and "final" in obj:
        return obj
    # Otherwise we expect action+args
    if not isinstance(obj, dict) or "action" not in obj or "args" not in obj:
        return {"error": "invalid_shape", "raw": obj}
    action = obj["action"]
    args = obj.get("args", {})
    if not isinstance(action, str) or not isinstance(args, dict):
        return {"error": "invalid_types", "raw": obj}
    # Reject hallucinated tool names: find_missing_required() returns [] for
    # tools without a schema, so they would otherwise pass as valid calls.
    if action not in tool_param_schemas:
        return {"error": "unknown_tool", "raw": obj}
    # Missing field check (additional security – model should have already asked)
    missing = find_missing_required(action, args)
    if missing:
        need = ", ".join(missing)
        return {"final": f"Please provide the following fields: {need}"}
    return {"action": action, "args": args}
# 5) Multi-turn chat loop
if __name__ == "__main__":
    # Tools that execute_diet_tool() can run locally; others are only routed.
    diet_actions = {"generate_diet_plan", "get_diet_recommendations", "calculate_nutrition_info", "expand_diet_database"}
    chat_history: List[Dict[str, str]] = []
    print("Multi-turn RAG+Tool Router. Leave empty line to exit.\n")
    while True:
        try:
            user_msg = input("Enter your prompt: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin (Ctrl-D / piped input) or Ctrl-C: exit like an empty line.
            print()
            break
        if not user_msg:
            break
        try:
            result = run_rag_tool_router(user_msg, chat_history)
        except Exception as e:
            # Top-level boundary: report API/network failures and keep the
            # session alive; history is left untouched so the user can retry.
            print(json.dumps({"error": "router_failed", "detail": str(e)}, ensure_ascii=False, indent=2))
            continue
        # Update history
        chat_history.append({"role": "user", "content": user_msg})
        # Every router result (final, action or error) is echoed and recorded.
        print(json.dumps(result, ensure_ascii=False, indent=2))
        chat_history.append({"role": "assistant", "content": json.dumps(result, ensure_ascii=False)})
        if not (isinstance(result, dict) and "action" in result) or "final" in result:
            continue
        # Diet tool check and execution
        action = result["action"]
        args = result.get("args", {})
        if action in diet_actions:
            diet_result = execute_diet_tool(action, args)
            print("\nDiet Tool Result:")
            print(json.dumps(diet_result, ensure_ascii=False, indent=2))
            chat_history.append({"role": "assistant", "content": json.dumps(diet_result, ensure_ascii=False)})
        else:
            # Placeholder for other tools
            print("(Implementation needed for other tools)")