# AnthroBot — app.py (Hugging Face Spaces deployment)
## Deploying on HuggingFace
import streamlit as st
import pandas as pd
import torch
import os
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig
import io
# Login using Hugging Face token
# NOTE: HUGGINGFACEHUB_TOKEN must be provided as an environment variable
# (a Space secret on Hugging Face); os.getenv returns None if it is missing.
login(token=os.getenv("HUGGINGFACEHUB_TOKEN"))
# Page config must be the first Streamlit call in the script.
st.set_page_config(page_title="AnthroBot", page_icon="๐Ÿค–", layout="centered")
# Load model & tokenizer
@st.cache_resource
def load_model():
    """Load the LoRA-adapted causal LM and its tokenizer (cached across reruns).

    Reads the PEFT adapter config for "SallySims/AnthroBot_Model_Lora",
    loads its base model in float16 with automatic device placement,
    applies the LoRA weights, and returns ``(model, tokenizer)`` with the
    pad token aliased to EOS.

    Returns:
        tuple: (PeftModel in eval mode, AutoTokenizer with pad_token set).

    Raises:
        Exception: any load failure is surfaced in the UI, then re-raised.
    """
    try:
        peft_config = PeftConfig.from_pretrained("SallySims/AnthroBot_Model_Lora")
        base_model = AutoModelForCausalLM.from_pretrained(
            peft_config.base_model_name_or_path,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        model = PeftModel.from_pretrained(base_model, "SallySims/AnthroBot_Model_Lora")
        model.eval()  # inference mode: disables dropout etc.
        tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
        # Base tokenizer has no pad token; alias it to EOS for generation.
        tokenizer.pad_token = tokenizer.eos_token
        st.write("โœ… Model and tokenizer loaded successfully.")
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        # FIX: bare `raise` re-raises with the original traceback intact
        # (`raise e` resets the traceback origin to this line).
        raise
model, tokenizer = load_model()
# Calculate anthropometric metrics
def calculate_metrics(age, sex, height_cm, weight_kg, wc_cm):
    """Compute anthropometric indices with WHO/ACE categories.

    Args:
        age: Age in years.
        sex: "male" or "female" (case-insensitive). Any other value falls
            through to the female BFP formula and the male ACE table,
            mirroring the original else-branches.
        height_cm: Height in centimetres.
        weight_kg: Weight in kilograms.
        wc_cm: Waist circumference in centimetres.

    Returns:
        dict: rounded "BMI", "WHTR", "BFP", "LBM" plus "BMI_Category",
        "BFP_Category" and a plain-language "Interpretation".
    """
    # Convert height and WC to meters
    height_m = height_cm / 100
    wc_m = wc_cm / 100
    # BMI (kg/m^2)
    bmi = weight_kg / (height_m ** 2)
    # WHTR: waist-to-height ratio (unitless)
    whtr = wc_m / height_m
    # BFP via the BMI-based approximation (labeled "Boer's formula" upstream)
    if sex.lower() == "male":
        bfp = (1.20 * bmi) + (0.23 * age) - 16.2
    else:
        bfp = (1.20 * bmi) + (0.23 * age) - 5.4
    # LBM: lean body mass (kg)
    lbm = weight_kg * (1 - (bfp / 100))
    # BMI Category (WHO)
    if bmi < 18.5:
        bmi_category = "Underweight"
    elif bmi <= 24.9:
        bmi_category = "Normal"
    elif bmi <= 29.9:
        bmi_category = "Overweight"
    else:
        bmi_category = "Obese"
    # BFP Category (ACE table; thresholds differ by sex)
    if sex.lower() == "female":
        if bfp <= 13:
            bfp_category = "Essential fat"
        elif bfp <= 20:
            bfp_category = "Athlete"
        elif bfp <= 24:
            bfp_category = "Fit"
        elif bfp <= 31:
            bfp_category = "Average"
        else:
            bfp_category = "Obese"
    else:
        if bfp <= 5:
            bfp_category = "Essential fat"
        elif bfp <= 13:
            bfp_category = "Athlete"
        elif bfp <= 17:
            bfp_category = "Fit"
        elif bfp <= 24:
            bfp_category = "Average"
        else:
            bfp_category = "Obese"
    # Interpretation.  FIX: the original chain compared against misspelled
    # labels ("Essential", "Athletes"), omitted the colon on three `elif`
    # lines, and contained a stray `DataAll` token — the module could not
    # parse, and once repaired `interpretation` would still be unbound for
    # two of the five categories.  A dict keyed on the exact category names
    # guarantees a value for every branch.
    interpretations = {
        "Essential fat": 'Minimum fat required for basic physiological functions (e.g., hormone production, insulation). Females require higher essential fat due to reproductive functions.',
        "Athlete": 'Typical for competitive athletes with high muscle mass and low fat (e.g., runners, bodybuilders).',
        "Fit": 'Healthy range for active individuals who exercise regularly but arenโ€™t competitive athletes.',
        "Average": 'Common for the general population, still within healthy limits',
        "Obese": 'Associated with increased health risks (e.g., diabetes, heart disease and other CVDs)',
    }
    interpretation = interpretations[bfp_category]
    return {
        "BMI": round(bmi, 2),
        "WHTR": round(whtr, 2),
        "BFP": round(bfp, 2),
        "LBM": round(lbm, 2),
        "BMI_Category": bmi_category,
        "BFP_Category": bfp_category,
        "Interpretation": interpretation
    }
# Prediction function
# Run generation on GPU when one is available; fall back to CPU otherwise.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
def get_prediction(age, sex, height_cm, weight_kg, wc_cm):
    """Generate a model interpretation for one set of anthropometric inputs.

    Computes derived metrics, builds a prompt, tokenizes it (preferring the
    tokenizer's chat template, with a plain-tokenization fallback), runs
    generation on the LoRA model, and returns the decoded text.

    Args:
        age, sex, height_cm, weight_kg, wc_cm: raw user inputs, forwarded
            to ``calculate_metrics``.

    Returns:
        str | None: decoded model output (special tokens preserved), or
        None if tokenization produced an unexpected format or generation/
        decoding failed (the error is shown via st.error).
    """
    # Calculate metrics
    metrics = calculate_metrics(age, sex, height_cm, weight_kg, wc_cm)
    # Create prompt with metrics
    prompt = (
        f"Age: {age}, Sex: {sex}, Height: {height_cm} cm, Weight: {weight_kg} kg, WC: {wc_cm} cm\n"
        f"BMI: {metrics['BMI']} kg/m2, WHTR: {metrics['WHTR']} m, BFP: {metrics['BFP']}%, "
        f"LBM: {metrics['LBM']} kg, BMI Category: {metrics['BMI_Category']}, "
        f"BFP Category: {metrics['BFP_Category']}\n"
        f"Provide an interpretation of these anthropometric metrics.\n###"
    )
    st.write(f"Received prompt: {prompt}")
    # Create message structure for the chat template
    messages = [{"role": "user", "content": prompt}]
    # Tokenize: prefer the chat template; fall back to manual tokenization
    try:
        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            max_length=512,
            truncation=True
        )
    except Exception as e:
        st.warning(f"apply_chat_template failed: {str(e)}. Falling back to manual tokenization.")
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True,
            padding=False
        )
    # Debug: Log inputs structure
    st.write(f"Inputs type: {type(inputs)}")
    # Normalize `inputs` (raw tensor or dict-like) to 2-D [batch, seq] ids
    if isinstance(inputs, torch.Tensor):
        input_ids = inputs
        if len(input_ids.shape) == 1:
            input_ids = input_ids.unsqueeze(0)  # [seq] -> [1, seq]
        elif len(input_ids.shape) > 2:
            input_ids = input_ids.squeeze()
            if len(input_ids.shape) == 1:
                input_ids = input_ids.unsqueeze(0)
    elif isinstance(inputs, dict) and 'input_ids' in inputs:
        input_ids = inputs['input_ids']
        if len(input_ids.shape) == 3 and input_ids.shape[0] == 1:
            input_ids = input_ids.squeeze(0)
        elif len(input_ids.shape) == 1:
            input_ids = input_ids.unsqueeze(0)
    else:
        st.error(f"Unexpected inputs format: {type(inputs)}")
        return None
    st.write(f"Input IDs shape: {input_ids.shape}")
    # Ensure input_ids is on the correct device
    input_ids = input_ids.to(device)
    # FIX: pass an explicit attention mask.  The batch is unpadded, so a
    # mask of ones is exact; with pad_token aliased to eos_token, generate()
    # cannot reliably infer the mask itself and warns/misbehaves without it.
    attention_mask = torch.ones_like(input_ids)
    # Generate output
    try:
        with torch.no_grad():  # inference only — skip autograd bookkeeping
            output = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=150,
                temperature=0.7,
                top_p=0.95,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id
            )
    except Exception as e:
        st.error(f"Error during generation: {str(e)}")
        return None
    # Decode the output
    try:
        decoded = tokenizer.decode(output[0], skip_special_tokens=False)  # Preserve special tokens
        st.write(f"Decoded output: {decoded}")
        return decoded
    except Exception as e:
        st.error(f"Error decoding output: {str(e)}")
        return None
# UI Header
st.title("๐Ÿง  AnthroBot")
st.write("Enter your anthropometric estimates to receive an interpreted summary โ€” manually or via CSV upload.")
# Tabs for input method
tab1, tab2 = st.tabs(["๐Ÿง Manual Input", "๐Ÿ“„ CSV Upload"])
with tab1:
    st.subheader("Manual Entry")
    # number_input args are (label, min, max, default); defaults mirror the
    # one-row sample CSV offered in the upload tab.
    age = st.number_input("Age", 0, 100, 16)
    sex = st.selectbox("Sex", ["male", "female"], index=1)  # default: female
    height = st.number_input("Height (cm)", 100.0, 250.0, 153.0)
    weight = st.number_input("Weight (kg)", 30.0, 200.0, 51.1)
    wc = st.number_input("Waist Circumference (cm)", 30.0, 150.0, 64.0)
    if st.button("Get Prediction"):
        prediction = get_prediction(age, sex, height, weight, wc)
        # get_prediction returns None on failure (error already shown there)
        if prediction:
            st.success("Prediction:")
            st.write(prediction)
with tab2:
    st.subheader("Batch Upload via CSV")
    # One-row template advertising the required column names/types
    sample_csv = pd.DataFrame({
        "Age": [16],
        "Sex": ["female"],
        "Height": [153.0],
        "Weight": [51.1],
        "WC": [64.0]
    })
    st.download_button("๐Ÿ“ฅ Download Sample CSV", sample_csv.to_csv(index=False), file_name="sample_input.csv")
    uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
    if uploaded_file:
        df = pd.read_csv(uploaded_file)
        # Validate the schema before running any predictions
        if not all(col in df.columns for col in ["Age", "Sex", "Height", "Weight", "WC"]):
            st.error("CSV must contain columns: Age, Sex, Height, Weight, WC")
        else:
            outputs = []
            with st.spinner("Generating predictions..."):
                # One model generation per row; failures become "Error" cells
                for _, row in df.iterrows():
                    prediction = get_prediction(row['Age'], row['Sex'], row['Height'], row['Weight'], row['WC'])
                    outputs.append(prediction if prediction else "Error")
            df["Prediction"] = outputs
            st.success("Here are your predictions:")
            st.dataframe(df)
            csv_output = df.to_csv(index=False).encode("utf-8")
            st.download_button("๐Ÿ“ค Download Predictions", data=csv_output, file_name="predictions.csv")