SallySims commited on
Commit
efec93f
·
verified ·
1 Parent(s): c8450fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -24
app.py CHANGED
@@ -8,7 +8,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
8
  from peft import PeftModel, PeftConfig
9
  import io
10
 
11
- # Login using Hugging Face token stored in Space secrets
12
  login(token=os.getenv("HUGGINGFACEHUB_TOKEN"))
13
 
14
  st.set_page_config(page_title="AnthroBot", page_icon="🤖", layout="centered")
@@ -37,13 +37,88 @@ def load_model():
37
 
38
  model, tokenizer = load_model()
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  # Prediction function
41
  device = "cuda" if torch.cuda.is_available() else "cpu"
42
 
43
- def get_prediction(prompt):
 
 
 
 
 
 
 
 
 
 
 
44
  st.write(f"Received prompt: {prompt}")
45
 
46
- # Create a message structure
47
  messages = [{"role": "user", "content": prompt}]
48
 
49
  # Tokenize the input
@@ -58,7 +133,6 @@ def get_prediction(prompt):
58
  )
59
  except Exception as e:
60
  st.warning(f"apply_chat_template failed: {str(e)}. Falling back to manual tokenization.")
61
- # Fallback: Manual tokenization
62
  inputs = tokenizer(
63
  prompt,
64
  return_tensors="pt",
@@ -72,12 +146,11 @@ def get_prediction(prompt):
72
 
73
  # Handle inputs (tensor or dict)
74
  if isinstance(inputs, torch.Tensor):
75
- # Direct tensor (likely input_ids)
76
  input_ids = inputs
77
  if len(input_ids.shape) == 1:
78
- input_ids = input_ids.unsqueeze(0) # Add batch dimension: [sequence_length] -> [1, sequence_length]
79
  elif len(input_ids.shape) > 2:
80
- input_ids = input_ids.squeeze() # Remove extra dimensions if any
81
  if len(input_ids.shape) == 1:
82
  input_ids = input_ids.unsqueeze(0)
83
  elif isinstance(inputs, dict) and 'input_ids' in inputs:
@@ -111,7 +184,7 @@ def get_prediction(prompt):
111
 
112
  # Decode the output
113
  try:
114
- decoded = tokenizer.decode(output[0], skip_special_tokens=True)
115
  st.write(f"Decoded output: {decoded}")
116
  return decoded
117
  except Exception as e:
@@ -127,15 +200,14 @@ tab1, tab2 = st.tabs(["🧍 Manual Input", "📄 CSV Upload"])
127
 
128
  with tab1:
129
  st.subheader("Manual Entry")
130
- age = st.number_input("Age", 0, 100, 30)
131
- sex = st.selectbox("Sex", ["male", "female"])
132
- height = st.number_input("Height (cm)", 100.0, 250.0, 150.5)
133
- weight = st.number_input("Weight (kg)", 30.0, 200.0, 75.3)
134
- wc = st.number_input("Waist Circumference (cm)", 30.0, 150.0, 68.0)
135
 
136
  if st.button("Get Prediction"):
137
- prompt = f"Age: {age}, Sex: {sex}, Height: {height} cm, Weight: {weight} kg, WC: {wc} cm\n\n###"
138
- prediction = get_prediction(prompt)
139
  if prediction:
140
  st.success("Prediction:")
141
  st.write(prediction)
@@ -143,11 +215,11 @@ with tab1:
143
  with tab2:
144
  st.subheader("Batch Upload via CSV")
145
  sample_csv = pd.DataFrame({
146
- "Age": [30],
147
  "Sex": ["female"],
148
- "Height": [150.5],
149
- "Weight": [75.3],
150
- "WC": [68.0]
151
  })
152
 
153
  st.download_button("📥 Download Sample CSV", sample_csv.to_csv(index=False), file_name="sample_input.csv")
@@ -162,11 +234,7 @@ with tab2:
162
  outputs = []
163
  with st.spinner("Generating predictions..."):
164
  for _, row in df.iterrows():
165
- prompt = (
166
- f"Age: {row['Age']}, Sex: {row['Sex']}, Height: {row['Height']} cm, "
167
- f"Weight: {row['Weight']} kg, WC: {row['WC']} cm\n\n###"
168
- )
169
- prediction = get_prediction(prompt)
170
  outputs.append(prediction if prediction else "Error")
171
 
172
  df["Prediction"] = outputs
 
8
  from peft import PeftModel, PeftConfig
9
  import io
10
 
11
+ # Login using Hugging Face token
12
  login(token=os.getenv("HUGGINGFACEHUB_TOKEN"))
13
 
14
  st.set_page_config(page_title="AnthroBot", page_icon="🤖", layout="centered")
 
37
 
38
  model, tokenizer = load_model()
39
 
40
+ # Calculate anthropometric metrics
41
+ def calculate_metrics(age, sex, height_cm, weight_kg, wc_cm):
42
+ # Convert height and WC to meters
43
+ height_m = height_cm / 100
44
+ wc_m = wc_cm / 100
45
+
46
+ # BMI
47
+ bmi = weight_kg / (height_m ** 2)
48
+
49
+ # WHTR
50
+ whtr = wc_m / height_m
51
+
52
+ # BFP (Boer's formula approximation)
53
+ if sex.lower() == "male":
54
+ bfp = (1.20 * bmi) + (0.23 * age) - 16.2
55
+ else:
56
+ bfp = (1.20 * bmi) + (0.23 * age) - 5.4
57
+
58
+ # LBM
59
+ lbm = weight_kg * (1 - (bfp / 100))
60
+
61
+ # BMI Category (WHO)
62
+ if bmi < 18.5:
63
+ bmi_category = "Underweight"
64
+ elif bmi <= 24.9:
65
+ bmi_category = "Normal"
66
+ elif bmi <= 29.9:
67
+ bmi_category = "Overweight"
68
+ else:
69
+ bmi_category = "Obese"
70
+
71
+ # BFP Category (ACE)
72
+ if sex.lower() == "female":
73
+ if bfp <= 13:
74
+ bfp_category = "Essential"
75
+ elif bfp <= 20:
76
+ bfp_category = "Athlete"
77
+ elif bfp <= 24:
78
+ bfp_category = "Fitness"
79
+ elif bfp <= 31:
80
+ bfp_category = "Average"
81
+ else:
82
+ bfp_category = "Obese"
83
+ else:
84
+ if bfp <= 5:
85
+ bfp_category = "Essential"
86
+ elif bfp <= 13:
87
+ bfp_category = "Athlete"
88
+ elif bfp <= 17:
89
+ bfp_category = "Fitness"
90
+ elif bfp <= 24:
91
+ bfp_category = "Average"
92
+ else:
93
+ bfp_category = "Obese"
94
+
95
+ return {
96
+ "BMI": round(bmi, 2),
97
+ "WHTR": round(whtr, 2),
98
+ "BFP": round(bfp, 2),
99
+ "LBM": round(lbm, 2),
100
+ "BMI_Category": bmi_category,
101
+ "BFP_Category": bfp_category
102
+ }
103
+
104
  # Prediction function
105
  device = "cuda" if torch.cuda.is_available() else "cpu"
106
 
107
+ def get_prediction(age, sex, height_cm, weight_kg, wc_cm):
108
+ # Calculate metrics
109
+ metrics = calculate_metrics(age, sex, height_cm, weight_kg, wc_cm)
110
+
111
+ # Create prompt with metrics
112
+ prompt = (
113
+ f"Age: {age}, Sex: {sex}, Height: {height_cm} cm, Weight: {weight_kg} kg, WC: {wc_cm} cm\n"
114
+ f"BMI: {metrics['BMI']} kg/m2, WHTR: {metrics['WHTR']} m, BFP: {metrics['BFP']}%, "
115
+ f"LBM: {metrics['LBM']} kg, BMI Category: {metrics['BMI_Category']}, "
116
+ f"BFP Category: {metrics['BFP_Category']}\n"
117
+ f"Provide an interpretation of these anthropometric metrics.\n###"
118
+ )
119
  st.write(f"Received prompt: {prompt}")
120
 
121
+ # Create message structure
122
  messages = [{"role": "user", "content": prompt}]
123
 
124
  # Tokenize the input
 
133
  )
134
  except Exception as e:
135
  st.warning(f"apply_chat_template failed: {str(e)}. Falling back to manual tokenization.")
 
136
  inputs = tokenizer(
137
  prompt,
138
  return_tensors="pt",
 
146
 
147
  # Handle inputs (tensor or dict)
148
  if isinstance(inputs, torch.Tensor):
 
149
  input_ids = inputs
150
  if len(input_ids.shape) == 1:
151
+ input_ids = input_ids.unsqueeze(0) # [sequence_length] -> [1, sequence_length]
152
  elif len(input_ids.shape) > 2:
153
+ input_ids = input_ids.squeeze()
154
  if len(input_ids.shape) == 1:
155
  input_ids = input_ids.unsqueeze(0)
156
  elif isinstance(inputs, dict) and 'input_ids' in inputs:
 
184
 
185
  # Decode the output
186
  try:
187
+ decoded = tokenizer.decode(output[0], skip_special_tokens=False) # Preserve special tokens
188
  st.write(f"Decoded output: {decoded}")
189
  return decoded
190
  except Exception as e:
 
200
 
201
  with tab1:
202
  st.subheader("Manual Entry")
203
+ age = st.number_input("Age", 0, 100, 16)
204
+ sex = st.selectbox("Sex", ["male", "female"], index=1)
205
+ height = st.number_input("Height (cm)", 100.0, 250.0, 153.0)
206
+ weight = st.number_input("Weight (kg)", 30.0, 200.0, 51.1)
207
+ wc = st.number_input("Waist Circumference (cm)", 30.0, 150.0, 64.0)
208
 
209
  if st.button("Get Prediction"):
210
+ prediction = get_prediction(age, sex, height, weight, wc)
 
211
  if prediction:
212
  st.success("Prediction:")
213
  st.write(prediction)
 
215
  with tab2:
216
  st.subheader("Batch Upload via CSV")
217
  sample_csv = pd.DataFrame({
218
+ "Age": [16],
219
  "Sex": ["female"],
220
+ "Height": [153.0],
221
+ "Weight": [51.1],
222
+ "WC": [64.0]
223
  })
224
 
225
  st.download_button("📥 Download Sample CSV", sample_csv.to_csv(index=False), file_name="sample_input.csv")
 
234
  outputs = []
235
  with st.spinner("Generating predictions..."):
236
  for _, row in df.iterrows():
237
+ prediction = get_prediction(row['Age'], row['Sex'], row['Height'], row['Weight'], row['WC'])
 
 
 
 
238
  outputs.append(prediction if prediction else "Error")
239
 
240
  df["Prediction"] = outputs