Update app.py
app.py CHANGED
@@ -4,7 +4,6 @@ import streamlit as st
 import pandas as pd
 import torch
 import os
-import re
 from huggingface_hub import login
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
 from peft import PeftModel, PeftConfig
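(With the regex-based response cleanup removed further down, re is no longer used anywhere in app.py, so its import goes too.)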
@@ -55,18 +54,6 @@ model, tokenizer = load_model()
 if 'history' not in st.session_state:
     st.session_state.history = []
 
-# Placeholder MET prediction function (replace with actual model)
-def predict_met(age, sex, weight_kg, bfp):
-    try:
-        # Simulated MET prediction (replace with your model)
-        # Example: Linear regression based on AntDatpromt.csv features
-        lbm = weight_kg * (1 - (bfp / 100))
-        met = 3.5 + 0.1 * lbm - 0.05 * age + (0.3 if sex == "male" else 0.0)  # Placeholder formula
-        return round(met, 2)
-    except Exception as e:
-        st.warning(f"MET prediction error: {str(e)}")
-        return None
-
 # Prediction function
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
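For reference, the deleted placeholder derived MET from lean body mass. A worked example with illustrative numbers, not taken from this commit: for a 40-year-old male at 80 kg and 20% body fat,

lbm = 80 * (1 - 20 / 100)                  # 64.0 kg lean body mass
met = 3.5 + 0.1 * 64.0 - 0.05 * 40 + 0.3   # 3.5 + 6.4 - 2.0 + 0.3 = 8.2

so predict_met(40, "male", 80, 20) would have returned 8.2.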
@@ -136,44 +123,27 @@ def generate_response(age, sex, height_cm, weight_kg, wc_cm):
 
         # Generate output
         st.write("🤖 Model response:")
-
-
-
-
-
-
-
-
-
-
-
-
-
+        with st.empty():
+            text_streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
+            output = model.generate(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                max_new_tokens=250,
+                temperature=0.7,
+                top_p=0.95,
+                do_sample=True,
+                pad_token_id=tokenizer.eos_token_id,
+                use_cache=True,
+                streamer=text_streamer
+            )
 
         # Decode the output
         decoded = tokenizer.decode(output[0], skip_special_tokens=False)
         st.write(f"Decoded output: {decoded}")
 
-        # Extract assistant response (remove system header and tokens)
-        assistant_response = re.search(r"<|start_header_id|>assistant<|end_header_id|>\n(.*)<|eot_id|>", decoded, re.DOTALL)
-        clean_response = assistant_response.group(1).strip() if assistant_response else decoded
-
-        # MET prediction
-        try:
-            bfp_match = re.search(r"BFP:\s*(\d+\.\d+)%", clean_response)
-            bfp = float(bfp_match.group(1)) if bfp_match else 18.36  # Fallback
-            met_pred = predict_met(age, sex, weight_kg, bfp)
-            if met_pred is not None:
-                clean_response += f", Predicted MET: {met_pred}"
-        except Exception as e:
-            st.warning(f"MET prediction skipped: {str(e)}")
-
         # Update history
-        st.session_state.history.append((prompt,
-
-        # Display clean response
-        output_container.write(clean_response)
-        return clean_response
+        st.session_state.history.append((prompt, decoded))
+        return decoded
 
     except Exception as e:
         st.error(f"Error during generation: {str(e)}")
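A caveat on the new streaming block: TextStreamer prints decoded tokens to stdout, so in a deployed Streamlit app the stream lands in the server logs, and the st.empty() container stays blank until the final decoded text is written afterwards. A minimal sketch of streaming into the page instead, by subclassing TextStreamer and overriding its on_finalized_text hook (this class is hypothetical, not part of the commit):

import streamlit as st
from transformers import TextStreamer

class StreamlitStreamer(TextStreamer):
    # Accumulates finalized text and redraws it inside a st.empty() slot.
    def __init__(self, tokenizer, placeholder, **kwargs):
        super().__init__(tokenizer, **kwargs)
        self.placeholder = placeholder
        self.text = ""

    def on_finalized_text(self, text: str, stream_end: bool = False):
        self.text += text
        self.placeholder.write(self.text)  # live update in the browser

Passing streamer=StreamlitStreamer(tokenizer, st.empty(), skip_prompt=True) to model.generate would then render tokens in the UI as they arrive.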
@@ -241,4 +211,5 @@ with tab2:
     # Clear history button
     if st.button("Clear History"):
         st.session_state.history = []
-        st.rerun()
+        st.rerun()
+
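st.rerun() restarts the script immediately, so the cleared history vanishes from the page on the same click rather than after the next widget interaction.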