| |
| import streamlit as st |
| import pandas as pd |
| import torch |
| import os |
| from huggingface_hub import login |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| from peft import PeftModel, PeftConfig |
| import io |
|
|
| |
# Authenticate with the Hugging Face Hub so the (possibly gated) adapter repo
# can be downloaded.  BUGFIX: login(token=None) raises at import time when the
# env var is unset; only attempt the login when a token is actually configured.
_hf_token = os.getenv("HUGGINGFACEHUB_TOKEN")
if _hf_token:
    login(token=_hf_token)


# NOTE(review): page_icon appears mojibake'd (likely an emoji such as "🤖") —
# confirm the intended character; bytes kept unchanged here.
st.set_page_config(page_title="AnthroBot", page_icon="๐ค", layout="centered")
|
|
| |
@st.cache_resource
def load_model():
    """Load the LoRA-adapted causal LM and its tokenizer (cached per session).

    Reads the PEFT adapter config to discover the base model, loads the base
    weights in fp16 with automatic device placement, attaches the LoRA adapter,
    and switches the model to eval mode.

    Returns:
        (model, tokenizer) tuple.

    Raises:
        Re-raises any loading error after surfacing it in the Streamlit UI.
    """
    try:
        peft_config = PeftConfig.from_pretrained("SallySims/AnthroBot_Model_Lora")
        base_model = AutoModelForCausalLM.from_pretrained(
            peft_config.base_model_name_or_path,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        model = PeftModel.from_pretrained(base_model, "SallySims/AnthroBot_Model_Lora")
        model.eval()

        tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
        # The base tokenizer ships without a pad token; reuse EOS so
        # generate() and any padded batches don't fail.
        tokenizer.pad_token = tokenizer.eos_token

        # BUGFIX: success message was mojibake ("โ…"); restored the check mark.
        st.write("✅ Model and tokenizer loaded successfully.")
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        # Bare raise preserves the original traceback (vs. `raise e`).
        raise


model, tokenizer = load_model()
|
|
| |
def calculate_metrics(age, sex, height_cm, weight_kg, wc_cm):
    """Compute derived anthropometric metrics from raw measurements.

    Args:
        age: Age in years.
        sex: "male" or "female" (case-insensitive); any value other than
            "male" falls through to the female formulas, matching the
            original branching.
        height_cm: Height in centimetres (must be > 0).
        weight_kg: Weight in kilograms.
        wc_cm: Waist circumference in centimetres.

    Returns:
        dict with BMI, WHTR, BFP, LBM (each rounded to 2 dp), BMI_Category,
        BFP_Category, and a textual Interpretation of the BFP category.

    Raises:
        ValueError: if height_cm is not positive (would divide by zero).
    """
    if height_cm <= 0:
        raise ValueError("height_cm must be positive")

    height_m = height_cm / 100
    wc_m = wc_cm / 100

    # Body Mass Index (kg/m^2) and waist-to-height ratio.
    bmi = weight_kg / (height_m ** 2)
    whtr = wc_m / height_m

    # Deurenberg-style body-fat-percentage estimate; constant differs by sex.
    if sex.lower() == "male":
        bfp = (1.20 * bmi) + (0.23 * age) - 16.2
    else:
        bfp = (1.20 * bmi) + (0.23 * age) - 5.4

    # Lean body mass = total weight minus estimated fat mass.
    lbm = weight_kg * (1 - (bfp / 100))

    # WHO-style BMI bands.
    if bmi < 18.5:
        bmi_category = "Underweight"
    elif bmi <= 24.9:
        bmi_category = "Normal"
    elif bmi <= 29.9:
        bmi_category = "Overweight"
    else:
        bmi_category = "Obese"

    # Body-fat bands; thresholds differ by sex.
    if sex.lower() == "female":
        if bfp <= 13:
            bfp_category = "Essential fat"
        elif bfp <= 20:
            bfp_category = "Athlete"
        elif bfp <= 24:
            bfp_category = "Fit"
        elif bfp <= 31:
            bfp_category = "Average"
        else:
            bfp_category = "Obese"
    else:
        if bfp <= 5:
            bfp_category = "Essential fat"
        elif bfp <= 13:
            bfp_category = "Athlete"
        elif bfp <= 17:
            bfp_category = "Fit"
        elif bfp <= 24:
            bfp_category = "Average"
        else:
            bfp_category = "Obese"

    # BUGFIX: the original interpretation chain compared against labels that
    # are never produced ("Essential", "Athletes"), omitted the colons on
    # three elif branches (a SyntaxError), referenced an undefined name
    # `DataAll`, and could leave `interpretation` unbound.  The dict keys
    # below match bfp_category values exactly, so a lookup always succeeds.
    interpretations = {
        "Essential fat": (
            "Minimum fat required for basic physiological functions "
            "(e.g., hormone production, insulation). Females require higher "
            "essential fat due to reproductive functions."
        ),
        "Athlete": (
            "Typical for competitive athletes with high muscle mass and low "
            "fat (e.g., runners, bodybuilders)."
        ),
        "Fit": (
            "Healthy range for active individuals who exercise regularly "
            "but aren't competitive athletes."
        ),
        "Average": "Common for the general population, still within healthy limits",
        "Obese": (
            "Associated with increased health risks (e.g., diabetes, heart "
            "disease and other CVDs)"
        ),
    }
    interpretation = interpretations[bfp_category]

    return {
        "BMI": round(bmi, 2),
        "WHTR": round(whtr, 2),
        "BFP": round(bfp, 2),
        "LBM": round(lbm, 2),
        "BMI_Category": bmi_category,
        "BFP_Category": bfp_category,
        "Interpretation": interpretation,
    }
|
|
| |
# Target device for generation inputs (model layers themselves are placed by
# device_map="auto" in load_model()).
device = "cuda" if torch.cuda.is_available() else "cpu"


def get_prediction(age, sex, height_cm, weight_kg, wc_cm):
    """Generate a model-written interpretation for one set of measurements.

    Computes the derived metrics, formats them into the training-style prompt,
    tokenizes (via the chat template when available), and samples up to 150
    new tokens from the LoRA model.

    Returns:
        The decoded model output string, or None if any stage failed
        (errors are surfaced in the Streamlit UI rather than raised).
    """
    metrics = calculate_metrics(age, sex, height_cm, weight_kg, wc_cm)

    # Prompt mirrors the fine-tuning format; "###" marks the answer boundary.
    prompt = (
        f"Age: {age}, Sex: {sex}, Height: {height_cm} cm, Weight: {weight_kg} kg, WC: {wc_cm} cm\n"
        f"BMI: {metrics['BMI']} kg/m2, WHTR: {metrics['WHTR']} m, BFP: {metrics['BFP']}%, "
        f"LBM: {metrics['LBM']} kg, BMI Category: {metrics['BMI_Category']}, "
        f"BFP Category: {metrics['BFP_Category']}\n"
        f"Provide an interpretation of these anthropometric metrics.\n###"
    )
    st.write(f"Received prompt: {prompt}")

    messages = [{"role": "user", "content": prompt}]

    # Prefer the tokenizer's chat template; fall back to plain encoding for
    # tokenizers that don't define one.
    try:
        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        )
    except Exception as e:
        st.warning(f"apply_chat_template failed: {str(e)}. Falling back to manual tokenization.")
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True,
            padding=False,
        )

    st.write(f"Inputs type: {type(inputs)}")

    # Normalise to a 2-D (batch, seq_len) tensor.  BUGFIX: transformers'
    # BatchEncoding subclasses UserDict, not dict, so the original
    # `isinstance(inputs, dict)` check sent the manual-tokenization fallback
    # into the error branch; a keys-based check accepts both.
    if isinstance(inputs, torch.Tensor):
        input_ids = inputs
        if len(input_ids.shape) == 1:
            input_ids = input_ids.unsqueeze(0)
        elif len(input_ids.shape) > 2:
            input_ids = input_ids.squeeze()
            if len(input_ids.shape) == 1:
                input_ids = input_ids.unsqueeze(0)
    elif hasattr(inputs, "keys") and "input_ids" in inputs:
        input_ids = inputs["input_ids"]
        if len(input_ids.shape) == 3 and input_ids.shape[0] == 1:
            input_ids = input_ids.squeeze(0)
        elif len(input_ids.shape) == 1:
            input_ids = input_ids.unsqueeze(0)
    else:
        st.error(f"Unexpected inputs format: {type(inputs)}")
        return None

    st.write(f"Input IDs shape: {input_ids.shape}")

    input_ids = input_ids.to(device)
    # No padding is used, so an all-ones mask is correct; passing it avoids
    # the transformers warning about pad==eos mask ambiguity.
    attention_mask = torch.ones_like(input_ids)

    try:
        # BUGFIX: generation previously ran with autograd enabled, wasting
        # memory on a pure inference path.
        with torch.no_grad():
            output = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=150,
                temperature=0.7,
                top_p=0.95,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
            )
    except Exception as e:
        st.error(f"Error during generation: {str(e)}")
        return None

    try:
        # skip_special_tokens=False kept deliberately: the raw output (with
        # template tokens) is shown for debugging; flip to True to hide them.
        decoded = tokenizer.decode(output[0], skip_special_tokens=False)
        st.write(f"Decoded output: {decoded}")
        return decoded
    except Exception as e:
        st.error(f"Error decoding output: {str(e)}")
        return None
|
|
| |
# ---- Page header -----------------------------------------------------------
st.title("๐ง AnthroBot")
st.write(
    "Enter your anthropometric estimates to receive an interpreted summary "
    "โ manually or via CSV upload."
)

# Two entry modes: a single-record form and a batch CSV upload.
_tab_labels = ["๐ง Manual Input", "๐ CSV Upload"]
tab1, tab2 = st.tabs(_tab_labels)
|
|
with tab1:
    # Single-record entry: collect the five raw measurements, then run one
    # prediction on demand.
    st.subheader("Manual Entry")
    age = st.number_input("Age", min_value=0, max_value=100, value=16)
    sex = st.selectbox("Sex", ["male", "female"], index=1)
    height = st.number_input("Height (cm)", min_value=100.0, max_value=250.0, value=153.0)
    weight = st.number_input("Weight (kg)", min_value=30.0, max_value=200.0, value=51.1)
    wc = st.number_input("Waist Circumference (cm)", min_value=30.0, max_value=150.0, value=64.0)

    if st.button("Get Prediction"):
        result = get_prediction(age, sex, height, weight, wc)
        if result:
            st.success("Prediction:")
            st.write(result)
|
|
with tab2:
    # Batch mode: validate the uploaded CSV's schema, run one prediction per
    # row, and offer the augmented table for download.
    st.subheader("Batch Upload via CSV")
    sample_csv = pd.DataFrame({
        "Age": [16],
        "Sex": ["female"],
        "Height": [153.0],
        "Weight": [51.1],
        "WC": [64.0],
    })

    # NOTE(review): button labels contain mojibake emoji (likely download /
    # upload arrows) — confirm intended characters; bytes kept unchanged.
    st.download_button("๐ฅ Download Sample CSV", sample_csv.to_csv(index=False), file_name="sample_input.csv")

    uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])

    if uploaded_file:
        # BUGFIX (robustness): a malformed CSV previously crashed the app with
        # an unhandled pandas exception; show a readable error instead.
        try:
            df = pd.read_csv(uploaded_file)
        except Exception as e:
            st.error(f"Could not parse CSV: {str(e)}")
            df = None
        if df is not None:
            required = ["Age", "Sex", "Height", "Weight", "WC"]
            if not all(col in df.columns for col in required):
                st.error("CSV must contain columns: Age, Sex, Height, Weight, WC")
            else:
                outputs = []
                with st.spinner("Generating predictions..."):
                    for _, row in df.iterrows():
                        prediction = get_prediction(row['Age'], row['Sex'], row['Height'], row['Weight'], row['WC'])
                        # Keep the row count aligned with the input frame even
                        # when a single prediction fails.
                        outputs.append(prediction if prediction else "Error")

                df["Prediction"] = outputs
                st.success("Here are your predictions:")
                st.dataframe(df)

                csv_output = df.to_csv(index=False).encode("utf-8")
                st.download_button("๐ค Download Predictions", data=csv_output, file_name="predictions.csv")