# FitPlan-module2 / model_api.py
# Committed by saiganesh2004 — commit 74ece72 "Update model_api.py" (verified)
from huggingface_hub import InferenceClient
import os
import requests
import json
# Hugging Face model ID used for Inference API requests; can be changed at
# runtime with switch_model().
LLAMA_MODEL = "meta-llama/Llama-3.2-3B-Instruct"
# Backend selector: truthy routes requests to a local Ollama server, falsy to
# the Hugging Face Inference API. Can be changed with set_ollama_mode().
USE_LOCAL_OLLAMA = False
def query_model(prompt):
    """
    Generate a workout plan for ``prompt`` with the currently selected backend.

    The module-level ``USE_LOCAL_OLLAMA`` flag picks the backend: a local
    Ollama server when it is truthy, otherwise the Hugging Face Inference API.

    Returns:
        The backend's text reply. Any exception raised while selecting or
        calling the backend is not propagated; it is turned into a string of
        the form "Error generating workout plan: <message>".
    """
    try:
        backend = query_ollama if USE_LOCAL_OLLAMA else query_huggingface
        return backend(prompt)
    except Exception as exc:
        return f"Error generating workout plan: {exc!s}"
# System prompt sent with every Hugging Face chat request.
_HF_SYSTEM_PROMPT = (
    "You are a certified professional fitness trainer with expertise in creating personalized workout plans.\n"
    "Always provide complete, detailed workout plans with:\n"
    "- Clear day-by-day structure\n"
    "- Specific exercises with sets, reps, and rest periods\n"
    "- Warm-up and cool-down recommendations\n"
    "- Safety considerations based on user's profile\n"
    "When asked for a 5-day plan, ensure ALL 5 days are included with clear day headers."
)

# Appended to the user prompt when the first reply is missing day headers.
_HF_RETRY_SUFFIX = (
    "\n\nIMPORTANT: The previous response was incomplete. Please ensure ALL 5 days "
    "(Day 1 through Day 5) are included in the plan. Each day should be clearly "
    "marked with 'Day X' header and include 4-6 exercises."
)

# Number of "Day N" headers a complete plan is expected to contain.
_EXPECTED_DAYS = 5


def _count_day_headers(plan, num_days=_EXPECTED_DAYS):
    """Return how many of the strings "Day 1" .. "Day <num_days>" occur in *plan*."""
    return sum(f"Day {day}" in plan for day in range(1, num_days + 1))


def _hf_chat(client, user_content):
    """Send one system+user chat request and return the reply text ("" when the reply has no content)."""
    response = client.chat_completion(
        messages=[
            {"role": "system", "content": _HF_SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ],
        max_tokens=3000,
        temperature=0.7,
        top_p=0.95,
    )
    # message.content is Optional in huggingface_hub; treat a missing reply as
    # empty text so the day-header check cannot fail with a TypeError and a
    # retry is attempted instead.
    return response.choices[0].message.content or ""


def query_huggingface(prompt):
    """
    Query Llama via Hugging Face Inference API.

    Reads the API token from the ``HF_TOKEN`` environment variable and sends
    ``prompt`` to the model named by ``LLAMA_MODEL`` together with a
    fitness-trainer system prompt. If the reply does not contain all of the
    headers "Day 1" .. "Day 5", one retry is made with a more explicit
    instruction. The retry replaces the first reply unless it contains fewer
    day headers.

    Args:
        prompt: The user's request text.

    Returns:
        The generated plan text, or an error message string if ``HF_TOKEN``
        is unset or empty. Exceptions raised by the Hugging Face client are
        not caught here and propagate to the caller.
    """
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        return "Error: HF_TOKEN not found. Please set your Hugging Face token in environment variables."

    client = InferenceClient(model=LLAMA_MODEL, token=hf_token)

    workout_plan = _hf_chat(client, prompt)
    days_found = _count_day_headers(workout_plan)
    if days_found < _EXPECTED_DAYS:
        # One retry with a more explicit instruction. Both calls share the same
        # sampling settings (including top_p) via _hf_chat.
        retry_plan = _hf_chat(client, prompt + _HF_RETRY_SUFFIX)
        # Don't throw away a more complete first answer for a worse retry.
        if _count_day_headers(retry_plan) >= days_found:
            workout_plan = retry_plan
    return workout_plan
def query_ollama(prompt):
    """
    Query Llama via local Ollama (completely free, no API key needed).

    Sends one non-streaming request to the Ollama ``/api/generate`` endpoint
    at ``http://localhost:11434`` using the ``llama3.2:3b`` model, wrapping
    ``prompt`` in a fitness-trainer instruction for a 5-day plan.

    Returns:
        The text of the ``"response"`` field when Ollama answers with HTTP
        200. Otherwise it returns an error message string beginning with
        "Error"; exceptions are reported this way rather than raised.
    """
    try:
        full_prompt = (
            "You are a certified professional fitness trainer. Create a comprehensive 5-day workout plan.\n"
            f"{prompt}\n"
            "Provide a complete, detailed 5-day workout plan with clear day headers, exercises, sets, reps, and rest periods."
        )
        response = requests.post(
            "http://localhost:11434/api/generate",
            json={
                "model": "llama3.2:3b",  # or "llama3.2:1b" for a lighter model
                "prompt": full_prompt,
                "stream": False,
                # Ollama reads generation settings from "options"; top-level
                # "max_tokens"/"temperature" keys are not part of the
                # /api/generate schema and were silently ignored.
                "options": {"num_predict": 3000, "temperature": 0.7},
            },
        )
        if response.status_code == 200:
            return response.json()["response"]
        return f"Error: Ollama returned status code {response.status_code}"
    except requests.exceptions.ConnectionError:
        return "Error: Cannot connect to Ollama. Make sure Ollama is running locally (run 'ollama serve' in terminal)"
    except Exception as e:
        return f"Error with Ollama: {str(e)}"
def test_api_connection():
    """
    Test function to verify API connection.

    Sends a tiny prompt to whichever backend ``USE_LOCAL_OLLAMA`` selects:
    the local Ollama server, or the Hugging Face Inference API for the model
    named by ``LLAMA_MODEL``.

    Returns:
        A ``(ok, message)`` tuple. ``ok`` is True when the request succeeded:
        HTTP 200 from Ollama, or no exception from ``chat_completion`` for
        Hugging Face. Any exception is reported as
        ``(False, "API connection failed: <message>")``.
    """
    try:
        if USE_LOCAL_OLLAMA:
            # Test Ollama connection
            response = requests.post(
                "http://localhost:11434/api/generate",
                json={
                    "model": "llama3.2:3b",
                    "prompt": "Say 'API connection successful' if you can read this.",
                    "stream": False,
                    # The output-length limit must go under "options"
                    # ("num_predict"); a top-level "max_tokens" is ignored.
                    "options": {"num_predict": 50},
                },
            )
            if response.status_code == 200:
                return True, "Ollama connection successful"
            return False, f"Ollama connection failed: {response.status_code}"

        # Test Hugging Face connection
        hf_token = os.getenv("HF_TOKEN")
        if not hf_token:
            return False, "HF_TOKEN not found"
        client = InferenceClient(model=LLAMA_MODEL, token=hf_token)
        # Only whether the call succeeds matters; the reply itself is discarded.
        client.chat_completion(
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Say 'API connection successful' if you can read this."},
            ],
            max_tokens=50,
            temperature=0.1,
        )
        return True, f"API connection successful (using {LLAMA_MODEL})"
    except Exception as e:
        return False, f"API connection failed: {str(e)}"
def switch_model(model_name):
    """
    Make ``model_name`` the Hugging Face model used for later requests.

    Rebinds the module-level ``LLAMA_MODEL`` to ``model_name`` unchanged and
    returns a confirmation string of the form "Switched to <model_name>".
    """
    global LLAMA_MODEL
    LLAMA_MODEL = model_name
    confirmation = "Switched to {}".format(model_name)
    return confirmation
def set_ollama_mode(use_ollama):
    """
    Choose between the local Ollama backend and the Hugging Face API.

    Stores ``use_ollama`` as-is in the module-level ``USE_LOCAL_OLLAMA`` flag
    (truthy selects local Ollama, falsy selects Hugging Face). Returns a
    confirmation string naming the selected mode.
    """
    global USE_LOCAL_OLLAMA
    USE_LOCAL_OLLAMA = use_ollama
    if use_ollama:
        backend_label = "local Ollama"
    else:
        backend_label = "Hugging Face API"
    return "Switched to " + backend_label + " mode"