SallySims commited on
Commit
6bc2197
·
verified ·
1 Parent(s): 7918f00

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -46
app.py CHANGED
@@ -4,7 +4,6 @@ import streamlit as st
4
  import pandas as pd
5
  import torch
6
  import os
7
- import re
8
  from huggingface_hub import login
9
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
10
  from peft import PeftModel, PeftConfig
@@ -55,18 +54,6 @@ model, tokenizer = load_model()
55
  if 'history' not in st.session_state:
56
  st.session_state.history = []
57
 
58
- # Placeholder MET prediction function (replace with actual model)
59
- def predict_met(age, sex, weight_kg, bfp):
60
- try:
61
- # Simulated MET prediction (replace with your model)
62
- # Example: Linear regression based on AntDatpromt.csv features
63
- lbm = weight_kg * (1 - (bfp / 100))
64
- met = 3.5 + 0.1 * lbm - 0.05 * age + (0.3 if sex == "male" else 0.0) # Placeholder formula
65
- return round(met, 2)
66
- except Exception as e:
67
- st.warning(f"MET prediction error: {str(e)}")
68
- return None
69
-
70
  # Prediction function
71
  device = "cuda" if torch.cuda.is_available() else "cpu"
72
 
@@ -136,44 +123,27 @@ def generate_response(age, sex, height_cm, weight_kg, wc_cm):
136
 
137
  # Generate output
138
  st.write("🤖 Model response:")
139
- output_container = st.empty()
140
- text_streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
141
- output = model.generate(
142
- input_ids=input_ids,
143
- attention_mask=attention_mask,
144
- max_new_tokens=250,
145
- temperature=0.7,
146
- top_p=0.95,
147
- do_sample=True,
148
- pad_token_id=tokenizer.eos_token_id,
149
- use_cache=True,
150
- streamer=text_streamer
151
- )
152
 
153
  # Decode the output
154
  decoded = tokenizer.decode(output[0], skip_special_tokens=False)
155
  st.write(f"Decoded output: {decoded}")
156
 
157
- # Extract assistant response (remove system header and tokens)
158
- assistant_response = re.search(r"<|start_header_id|>assistant<|end_header_id>\n(.*)<|eot_id>", decoded, re.DOTALL)
159
- clean_response = assistant_response.group(1).strip() if assistant_response else decoded
160
-
161
- # MET prediction
162
- try:
163
- bfp_match = re.search(r"BFP:\s*(\d+\.\d+)%", clean_response)
164
- bfp = float(bfp_match.group(1)) if bfp_match else 18.36 # Fallback
165
- met_pred = predict_met(age, sex, weight_kg, bfp)
166
- if met_pred is not None:
167
- clean_response += f", Predicted MET: {met_pred}"
168
- except Exception as e:
169
- st.warning(f"MET prediction skipped: {str(e)}")
170
-
171
  # Update history
172
- st.session_state.history.append((prompt, clean_response))
173
-
174
- # Display clean response
175
- output_container.write(clean_response)
176
- return clean_response
177
 
178
  except Exception as e:
179
  st.error(f"Error during generation: {str(e)}")
@@ -241,4 +211,5 @@ with tab2:
241
  # Clear history button
242
  if st.button("Clear History"):
243
  st.session_state.history = []
244
- st.rerun()
 
 
4
  import pandas as pd
5
  import torch
6
  import os
 
7
  from huggingface_hub import login
8
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
9
  from peft import PeftModel, PeftConfig
 
54
  if 'history' not in st.session_state:
55
  st.session_state.history = []
56
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  # Prediction function
58
  device = "cuda" if torch.cuda.is_available() else "cpu"
59
 
 
123
 
124
  # Generate output
125
  st.write("🤖 Model response:")
126
+ with st.empty():
127
+ text_streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
128
+ output = model.generate(
129
+ input_ids=input_ids,
130
+ attention_mask=attention_mask,
131
+ max_new_tokens=250,
132
+ temperature=0.7,
133
+ top_p=0.95,
134
+ do_sample=True,
135
+ pad_token_id=tokenizer.eos_token_id,
136
+ use_cache=True,
137
+ streamer=text_streamer
138
+ )
139
 
140
  # Decode the output
141
  decoded = tokenizer.decode(output[0], skip_special_tokens=False)
142
  st.write(f"Decoded output: {decoded}")
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  # Update history
145
+ st.session_state.history.append((prompt, decoded))
146
+ return decoded
 
 
 
147
 
148
  except Exception as e:
149
  st.error(f"Error during generation: {str(e)}")
 
211
  # Clear history button
212
  if st.button("Clear History"):
213
  st.session_state.history = []
214
+ st.rerun()
215
+