# FitPlan-AI / src/streamlit_app.py
# Last change: "Update src/streamlit_app.py" by ArumugaSelvi (commit 66364b2, verified)
import streamlit as st
import torch
import random
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# ---------------- PAGE CONFIG ----------------
# Keep this as the very first Streamlit command of the script: older
# Streamlit releases reject set_page_config once another st.* call has run.
st.set_page_config(page_title="FitPlan AI", page_icon="💪", layout="centered")

# ---------------- SESSION STATE ----------------
# A brand-new browser session starts on the landing page; on later reruns the
# key already exists and the page the user navigated to is left untouched.
st.session_state.setdefault("page", "landing")
# ---------------- LANDING PAGE ----------------
if st.session_state.page == "landing":
    # Page-wide styling: a cover photo behind the whole app container and a
    # large, bold, centred headline class used by the welcome banner below.
    landing_css = (
        "\n"
        "<style>\n"
        '[data-testid="stAppViewContainer"] {\n'
        'background-image: url("https://images.unsplash.com/'
        'photo-1483721310020-03333e577078");\n'
        "background-size: cover;\n"
        "background-position: center;\n"
        "}\n"
        ".center-box {\n"
        "text-align:center;\n"
        "margin-top:200px;\n"
        "font-size:48px;\n"
        "font-weight:bold;\n"
        "color:black;\n"
        "}\n"
        "</style>\n"
    )
    welcome_banner = '<div class="center-box">💪 Welcome to FitPlan AI</div>'

    # Raw HTML/CSS is only rendered when unsafe_allow_html is enabled.
    st.markdown(landing_css, unsafe_allow_html=True)
    st.markdown(welcome_banner, unsafe_allow_html=True)

    # Flip to the profile form and rerun right away so it appears on this click
    # instead of on the user's next interaction.
    if st.button("🚀 Get Started"):
        st.session_state.page = "main"
        st.rerun()
# ---------------- MAIN PAGE ----------------
elif st.session_state.page == "main":
st.title("💪 FitPlan AI")
# ---------------- INPUTS ----------------
name = st.text_input("Name *")
gender = st.selectbox("Gender", ["Male", "Female", "Other"])
height_cm = st.number_input("Height (cm) *", min_value=0.0)
weight_kg = st.number_input("Weight (kg) *", min_value=0.0)
goal = st.selectbox(
"Fitness Goal",
["Build Muscle", "Weight Loss", "Strength Gain", "Abs Building", "Flexible"]
)
equipment = st.multiselect(
"Equipment",
["Dumbbells", "Resistance Band", "Yoga Mat", "No Equipment"]
)
fitness_level = st.radio(
"Fitness Level",
["Beginner", "Intermediate", "Advanced"]
)
# ---------------- BMI FUNCTIONS ----------------
def calculate_bmi(h, w):
return round(w / ((h / 100) ** 2), 2)
def bmi_category(b):
if b < 18.5:
return "Underweight"
elif b < 25:
return "Normal"
elif b < 30:
return "Overweight"
else:
return "Obese"
# ---------------- BMI BUTTON ----------------
if st.button("Generate BMI"):
if height_cm <= 0 or weight_kg <= 0:
st.error("Enter valid height and weight.")
else:
bmi_value = calculate_bmi(height_cm, weight_kg)
st.session_state["bmi"] = bmi_value
st.success(f"BMI: {bmi_value:.2f} ({bmi_category(bmi_value)})")
bmi = st.session_state.get("bmi", None)
# ---------------- MODEL LOADING ----------------
@st.cache_resource
def load_model():
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
return tokenizer, model
tokenizer, model = load_model()
# ---------------- SUBMIT PROFILE ----------------
if st.button("Submit Profile"):
if not name:
st.error("Enter your name.")
elif height_cm <= 0 or weight_kg <= 0:
st.error("Enter valid height & weight.")
elif not equipment:
st.error("Select equipment.")
elif bmi is None:
st.error("Generate BMI first.")
else:
st.success("Profile Submitted Successfully!")
bmi_status = bmi_category(bmi)
equipment_list = ", ".join(equipment)
# Add randomness token to force variation
random_token = random.randint(1, 1000000)
prompt = f"""
You are a certified professional fitness trainer.
Random Seed: {random_token}
Generate a structured 5-day workout plan based on the following user profile.
User Profile:
- Name: {name}
- Gender: {gender}
- BMI: {bmi:.2f} ({bmi_status})
- Goal: {goal}
- Fitness Level: {fitness_level}
- Available Equipment: {equipment_list}
Instructions:
1. Divide the plan clearly into Day 1 to Day 5.
2. Under each day, list 4-6 exercises.
3. For each exercise include:
- Exercise Name
- Sets
- Reps
- Rest Time
4. Keep exercises appropriate for fitness level.
5. Do NOT include explanations outside workout plan.
Only return the workout plan.
"""
with st.spinner("Generating Workout Plan..."):
inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
outputs = model.generate(
**inputs,
max_new_tokens=400,
do_sample=True,
temperature=0.9, # Increased randomness
top_p=0.95,
top_k=50,
repetition_penalty=1.5,
num_beams=1 # Important: disables deterministic beam search
)
result = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
st.subheader("🏋️ Your Personalized Workout Plan")
st.write(result)