SallySims commited on
Commit
75b1eec
·
verified ·
1 Parent(s): 760ef09

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -18
app.py CHANGED
@@ -4,6 +4,7 @@ import streamlit as st
4
  import pandas as pd
5
  import torch
6
  import os
 
7
  from huggingface_hub import login
8
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
9
  from peft import PeftModel, PeftConfig
@@ -54,6 +55,18 @@ model, tokenizer = load_model()
54
  if 'history' not in st.session_state:
55
  st.session_state.history = []
56
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  # Prediction function
58
  device = "cuda" if torch.cuda.is_available() else "cpu"
59
 
@@ -123,34 +136,51 @@ def generate_response(age, sex, height_cm, weight_kg, wc_cm):
123
 
124
  # Generate output
125
  st.write("🤖 Model response:")
126
- with st.empty():
127
- text_streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
128
- output = model.generate(
129
- input_ids=input_ids,
130
- attention_mask=attention_mask,
131
- max_new_tokens=250,
132
- temperature=0.7,
133
- top_p=0.95,
134
- do_sample=True,
135
- pad_token_id=tokenizer.eos_token_id,
136
- use_cache=True,
137
- streamer=text_streamer
138
- )
139
 
140
  # Decode the output
141
  decoded = tokenizer.decode(output[0], skip_special_tokens=False)
142
  st.write(f"Decoded output: {decoded}")
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  # Update history
145
- st.session_state.history.append((prompt, decoded))
146
- return decoded
 
 
 
147
 
148
  except Exception as e:
149
  st.error(f"Error during generation: {str(e)}")
150
  return None
151
 
152
  # UI Header
153
- st.title("🧠 AnthroBot")
154
  st.markdown("Enter your anthropometric details to receive an AI-generated summary of health metrics.")
155
 
156
  # Tabs for input method
@@ -211,5 +241,4 @@ with tab2:
211
  # Clear history button
212
  if st.button("Clear History"):
213
  st.session_state.history = []
214
- st.rerun()
215
-
 
4
  import pandas as pd
5
  import torch
6
  import os
7
+ import re
8
  from huggingface_hub import login
9
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
10
  from peft import PeftModel, PeftConfig
 
55
  if 'history' not in st.session_state:
56
  st.session_state.history = []
57
 
58
+ # Placeholder MET prediction function (replace with actual model)
59
+ def predict_met(age, sex, weight_kg, bfp):
60
+ try:
61
+ # Simulated MET prediction (replace with your model)
62
+ # Example: Linear regression based on AntDatpromt.csv features
63
+ lbm = weight_kg * (1 - (bfp / 100))
64
+ met = 3.5 + 0.1 * lbm - 0.05 * age + (0.3 if sex == "male" else 0.0) # Placeholder formula
65
+ return round(met, 2)
66
+ except Exception as e:
67
+ st.warning(f"MET prediction error: {str(e)}")
68
+ return None
69
+
70
  # Prediction function
71
  device = "cuda" if torch.cuda.is_available() else "cpu"
72
 
 
136
 
137
  # Generate output
138
  st.write("🤖 Model response:")
139
+ output_container = st.empty()
140
+ text_streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
141
+ output = model.generate(
142
+ input_ids=input_ids,
143
+ attention_mask=attention_mask,
144
+ max_new_tokens=250,
145
+ temperature=0.7,
146
+ top_p=0.95,
147
+ do_sample=True,
148
+ pad_token_id=tokenizer.eos_token_id,
149
+ use_cache=True,
150
+ streamer=text_streamer
151
+ )
152
 
153
  # Decode the output
154
  decoded = tokenizer.decode(output[0], skip_special_tokens=False)
155
  st.write(f"Decoded output: {decoded}")
156
 
157
+ # Extract assistant response (remove system header and tokens)
158
+ assistant_response = re.search(r"<|start_header_id|>assistant<|end_header_id>\n(.*)<|eot_id>", decoded, re.DOTALL)
159
+ clean_response = assistant_response.group(1).strip() if assistant_response else decoded
160
+
161
+ # MET prediction
162
+ try:
163
+ bfp_match = re.search(r"BFP:\s*(\d+\.\d+)%", clean_response)
164
+ bfp = float(bfp_match.group(1)) if bfp_match else 18.36 # Fallback
165
+ met_pred = predict_met(age, sex, weight_kg, bfp)
166
+ if met_pred is not None:
167
+ clean_response += f", Predicted MET: {met_pred}"
168
+ except Exception as e:
169
+ st.warning(f"MET prediction skipped: {str(e)}")
170
+
171
  # Update history
172
+ st.session_state.history.append((prompt, clean_response))
173
+
174
+ # Display clean response
175
+ output_container.write(clean_response)
176
+ return clean_response
177
 
178
  except Exception as e:
179
  st.error(f"Error during generation: {str(e)}")
180
  return None
181
 
182
  # UI Header
183
+ st.title("🧠 Health Metric Estimator")
184
  st.markdown("Enter your anthropometric details to receive an AI-generated summary of health metrics.")
185
 
186
  # Tabs for input method
 
241
  # Clear history button
242
  if st.button("Clear History"):
243
  st.session_state.history = []
244
+ st.rerun()