Update app.py
Browse files
app.py
CHANGED
|
@@ -4,6 +4,7 @@ import streamlit as st
|
|
| 4 |
import pandas as pd
|
| 5 |
import torch
|
| 6 |
import os
|
|
|
|
| 7 |
from huggingface_hub import login
|
| 8 |
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
|
| 9 |
from peft import PeftModel, PeftConfig
|
|
@@ -54,6 +55,18 @@ model, tokenizer = load_model()
|
|
| 54 |
if 'history' not in st.session_state:
|
| 55 |
st.session_state.history = []
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
# Prediction function
|
| 58 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 59 |
|
|
@@ -123,34 +136,51 @@ def generate_response(age, sex, height_cm, weight_kg, wc_cm):
|
|
| 123 |
|
| 124 |
# Generate output
|
| 125 |
st.write("🤖 Model response:")
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
|
| 140 |
# Decode the output
|
| 141 |
decoded = tokenizer.decode(output[0], skip_special_tokens=False)
|
| 142 |
st.write(f"Decoded output: {decoded}")
|
| 143 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
# Update history
|
| 145 |
-
st.session_state.history.append((prompt,
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
except Exception as e:
|
| 149 |
st.error(f"Error during generation: {str(e)}")
|
| 150 |
return None
|
| 151 |
|
| 152 |
# UI Header
|
| 153 |
-
st.title("🧠
|
| 154 |
st.markdown("Enter your anthropometric details to receive an AI-generated summary of health metrics.")
|
| 155 |
|
| 156 |
# Tabs for input method
|
|
@@ -211,5 +241,4 @@ with tab2:
|
|
| 211 |
# Clear history button
|
| 212 |
if st.button("Clear History"):
|
| 213 |
st.session_state.history = []
|
| 214 |
-
st.rerun()
|
| 215 |
-
|
|
|
|
| 4 |
import pandas as pd
|
| 5 |
import torch
|
| 6 |
import os
|
| 7 |
+
import re
|
| 8 |
from huggingface_hub import login
|
| 9 |
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
|
| 10 |
from peft import PeftModel, PeftConfig
|
|
|
|
| 55 |
if 'history' not in st.session_state:
|
| 56 |
st.session_state.history = []
|
| 57 |
|
| 58 |
+
# Placeholder MET prediction function (replace with actual model)
|
| 59 |
+
def predict_met(age, sex, weight_kg, bfp):
|
| 60 |
+
try:
|
| 61 |
+
# Simulated MET prediction (replace with your model)
|
| 62 |
+
# Example: Linear regression based on AntDatpromt.csv features
|
| 63 |
+
lbm = weight_kg * (1 - (bfp / 100))
|
| 64 |
+
met = 3.5 + 0.1 * lbm - 0.05 * age + (0.3 if sex == "male" else 0.0) # Placeholder formula
|
| 65 |
+
return round(met, 2)
|
| 66 |
+
except Exception as e:
|
| 67 |
+
st.warning(f"MET prediction error: {str(e)}")
|
| 68 |
+
return None
|
| 69 |
+
|
| 70 |
# Prediction function
|
| 71 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 72 |
|
|
|
|
| 136 |
|
| 137 |
# Generate output
|
| 138 |
st.write("🤖 Model response:")
|
| 139 |
+
output_container = st.empty()
|
| 140 |
+
text_streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 141 |
+
output = model.generate(
|
| 142 |
+
input_ids=input_ids,
|
| 143 |
+
attention_mask=attention_mask,
|
| 144 |
+
max_new_tokens=250,
|
| 145 |
+
temperature=0.7,
|
| 146 |
+
top_p=0.95,
|
| 147 |
+
do_sample=True,
|
| 148 |
+
pad_token_id=tokenizer.eos_token_id,
|
| 149 |
+
use_cache=True,
|
| 150 |
+
streamer=text_streamer
|
| 151 |
+
)
|
| 152 |
|
| 153 |
# Decode the output
|
| 154 |
decoded = tokenizer.decode(output[0], skip_special_tokens=False)
|
| 155 |
st.write(f"Decoded output: {decoded}")
|
| 156 |
|
| 157 |
+
# Extract assistant response (remove system header and tokens)
|
| 158 |
+
assistant_response = re.search(r"<|start_header_id|>assistant<|end_header_id>\n(.*)<|eot_id>", decoded, re.DOTALL)
|
| 159 |
+
clean_response = assistant_response.group(1).strip() if assistant_response else decoded
|
| 160 |
+
|
| 161 |
+
# MET prediction
|
| 162 |
+
try:
|
| 163 |
+
bfp_match = re.search(r"BFP:\s*(\d+\.\d+)%", clean_response)
|
| 164 |
+
bfp = float(bfp_match.group(1)) if bfp_match else 18.36 # Fallback
|
| 165 |
+
met_pred = predict_met(age, sex, weight_kg, bfp)
|
| 166 |
+
if met_pred is not None:
|
| 167 |
+
clean_response += f", Predicted MET: {met_pred}"
|
| 168 |
+
except Exception as e:
|
| 169 |
+
st.warning(f"MET prediction skipped: {str(e)}")
|
| 170 |
+
|
| 171 |
# Update history
|
| 172 |
+
st.session_state.history.append((prompt, clean_response))
|
| 173 |
+
|
| 174 |
+
# Display clean response
|
| 175 |
+
output_container.write(clean_response)
|
| 176 |
+
return clean_response
|
| 177 |
|
| 178 |
except Exception as e:
|
| 179 |
st.error(f"Error during generation: {str(e)}")
|
| 180 |
return None
|
| 181 |
|
| 182 |
# UI Header
|
| 183 |
+
st.title("🧠 Health Metric Estimator")
|
| 184 |
st.markdown("Enter your anthropometric details to receive an AI-generated summary of health metrics.")
|
| 185 |
|
| 186 |
# Tabs for input method
|
|
|
|
| 241 |
# Clear history button
|
| 242 |
if st.button("Clear History"):
|
| 243 |
st.session_state.history = []
|
| 244 |
+
st.rerun()
|
|
|