SallySims committed on
Commit
7a0a8fa
·
verified ·
1 Parent(s): 405f70a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -106
app.py CHANGED
@@ -4,7 +4,7 @@ import pandas as pd
4
  import torch
5
  import os
6
  from huggingface_hub import login
7
- from transformers import AutoTokenizer, AutoModelForCausalLM
8
  from peft import PeftModel, PeftConfig
9
  import io
10
 
@@ -28,6 +28,7 @@ def load_model():
28
 
29
  tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
30
  tokenizer.pad_token = tokenizer.eos_token
 
31
 
32
  st.write("✅ Model and tokenizer loaded successfully.")
33
  return model, tokenizer
@@ -37,105 +38,17 @@ def load_model():
37
 
38
  model, tokenizer = load_model()
39
 
40
- # Calculate anthropometric metrics
41
# Upper bounds (inclusive) of each body-fat-percentage band, per sex.
# Anything above the last bound is classified as "Obese".
_BFP_BANDS = {
    "female": ((13, "Essential fat"), (20, "Athlete"), (24, "Fit"), (31, "Average")),
    "male": ((5, "Essential fat"), (13, "Athlete"), (17, "Fit"), (24, "Average")),
}

# Human-readable explanation for every body-fat category the function can emit.
_BFP_INTERPRETATIONS = {
    "Essential fat": 'Minimum fat required for basic physiological functions (e.g., hormone production, insulation). Females require higher essential fat due to reproductive functions.',
    "Athlete": 'Typical for competitive athletes with high muscle mass and low fat (e.g., runners, bodybuilders).',
    "Fit": 'Healthy range for active individuals who exercise regularly but aren’t competitive athletes.',
    "Average": 'Common for the general population, still within healthy limits',
    "Obese": 'Associated with increased health risks (e.g., diabetes, heart disease and other CVDs)',
}


def calculate_metrics(age, sex, height_cm, weight_kg, wc_cm):
    """Compute anthropometric metrics from basic body measurements.

    Args:
        age: Age in years.
        sex: "male" or "female" (case-insensitive).
        height_cm: Height in centimetres.
        weight_kg: Weight in kilograms.
        wc_cm: Waist circumference in centimetres.

    Returns:
        dict with keys:
            "BMI": body-mass index (kg/m^2), rounded to 2 decimals.
            "WHTR": waist-to-height ratio, rounded to 2 decimals.
            "BFP": estimated body-fat percentage, rounded to 2 decimals.
            "LBM": lean body mass in kg, rounded to 2 decimals.
            "BMI_Category": "Underweight", "Normal", "Overweight" or "Obese".
            "BFP_Category": "Essential fat", "Athlete", "Fit", "Average"
                or "Obese".
            "Interpretation": explanatory sentence for the BFP category.

    Raises:
        ZeroDivisionError: if ``height_cm`` is 0.
    """
    sex_key = sex.lower()

    # Height and waist circumference are converted to metres.
    height_m = height_cm / 100
    wc_m = wc_cm / 100

    bmi = weight_kg / (height_m ** 2)
    whtr = wc_m / height_m

    # BMI-based body-fat estimate (Deurenberg adult equation):
    #   BFP = 1.20*BMI + 0.23*age - 10.8*[male] - 5.4
    # Any value other than "male" uses the female constant, as before.
    sex_offset = 16.2 if sex_key == "male" else 5.4
    bfp = (1.20 * bmi) + (0.23 * age) - sex_offset

    lbm = weight_kg * (1 - (bfp / 100))

    # WHO BMI classes. Strict upper bounds are used so that values between
    # the printed limits (e.g. 24.95) are not pushed into the next class.
    if bmi < 18.5:
        bmi_category = "Underweight"
    elif bmi < 25:
        bmi_category = "Normal"
    elif bmi < 30:
        bmi_category = "Overweight"
    else:
        bmi_category = "Obese"

    # Body-fat category. Any value other than "female" uses the male bands.
    # NOTE(review): an unrecognized sex therefore gets the female BFP
    # constant above but the male bands here -- confirm whether such input
    # should be rejected instead.
    bands = _BFP_BANDS["female" if sex_key == "female" else "male"]
    bfp_category = "Obese"
    for upper_bound, label in bands:
        if bfp <= upper_bound:
            bfp_category = label
            break

    # Keyed by the exact category labels above, so every category has an
    # interpretation (the old chain compared against 'Athletes' and left the
    # variable unset for the "Athlete" category).
    interpretation = _BFP_INTERPRETATIONS[bfp_category]

    return {
        "BMI": round(bmi, 2),
        "WHTR": round(whtr, 2),
        "BFP": round(bfp, 2),
        "LBM": round(lbm, 2),
        "BMI_Category": bmi_category,
        "BFP_Category": bfp_category,
        "Interpretation": interpretation,
    }
118
-
119
  # Prediction function
120
  device = "cuda" if torch.cuda.is_available() else "cpu"
121
 
122
  def get_prediction(age, sex, height_cm, weight_kg, wc_cm):
123
- # Calculate metrics
124
- metrics = calculate_metrics(age, sex, height_cm, weight_kg, wc_cm)
125
-
126
- # Create prompt with metrics
127
- prompt = (
128
- f"Age: {age}, Sex: {sex}, Height: {height_cm} cm, Weight: {weight_kg} kg, WC: {wc_cm} cm\n"
129
- f"BMI: {metrics['BMI']} kg/m2, WHTR: {metrics['WHTR']} m, BFP: {metrics['BFP']}%, "
130
- f"LBM: {metrics['LBM']} kg, BMI Category: {metrics['BMI_Category']}, "
131
- f"BFP Category: {metrics['BFP_Category']}\n"
132
- f"Interpretation: {metrics['Interpretation']}.\n###"
133
- )
134
  st.write(f"Received prompt: {prompt}")
135
 
136
  # Create message structure
137
  messages = [{"role": "user", "content": prompt}]
138
-
139
  # Tokenize the input
140
  try:
141
  inputs = tokenizer.apply_chat_template(
@@ -144,7 +57,8 @@ def get_prediction(age, sex, height_cm, weight_kg, wc_cm):
144
  add_generation_prompt=True,
145
  return_tensors="pt",
146
  max_length=512,
147
- truncation=True
 
148
  )
149
  except Exception as e:
150
  st.warning(f"apply_chat_template failed: {str(e)}. Falling back to manual tokenization.")
@@ -153,45 +67,53 @@ def get_prediction(age, sex, height_cm, weight_kg, wc_cm):
153
  return_tensors="pt",
154
  max_length=512,
155
  truncation=True,
156
- padding=False
 
157
  )
158
-
159
  # Debug: Log inputs structure
160
  st.write(f"Inputs type: {type(inputs)}")
161
-
162
  # Handle inputs (tensor or dict)
163
  if isinstance(inputs, torch.Tensor):
 
164
  input_ids = inputs
165
  if len(input_ids.shape) == 1:
166
- input_ids = input_ids.unsqueeze(0) # [sequence_length] -> [1, sequence_length]
167
- elif len(input_ids.shape) > 2:
168
- input_ids = input_ids.squeeze()
169
- if len(input_ids.shape) == 1:
170
- input_ids = input_ids.unsqueeze(0)
171
  elif isinstance(inputs, dict) and 'input_ids' in inputs:
172
  input_ids = inputs['input_ids']
 
173
  if len(input_ids.shape) == 3 and input_ids.shape[0] == 1:
174
  input_ids = input_ids.squeeze(0)
 
175
  elif len(input_ids.shape) == 1:
176
  input_ids = input_ids.unsqueeze(0)
 
177
  else:
178
  st.error(f"Unexpected inputs format: {type(inputs)}")
179
  return None
180
 
181
  st.write(f"Input IDs shape: {input_ids.shape}")
 
182
 
183
- # Ensure input_ids is on the correct device
184
  input_ids = input_ids.to(device)
 
185
 
186
  # Generate output
187
  try:
 
188
  output = model.generate(
189
  input_ids=input_ids,
190
- max_new_tokens=150,
 
191
  temperature=0.7,
192
  top_p=0.95,
193
  do_sample=True,
194
- pad_token_id=tokenizer.pad_token_id
 
 
195
  )
196
  except Exception as e:
197
  st.error(f"Error during generation: {str(e)}")
@@ -199,7 +121,7 @@ def get_prediction(age, sex, height_cm, weight_kg, wc_cm):
199
 
200
  # Decode the output
201
  try:
202
- decoded = tokenizer.decode(output[0], skip_special_tokens=False) # Preserve special tokens
203
  st.write(f"Decoded output: {decoded}")
204
  return decoded
205
  except Exception as e:
@@ -257,4 +179,5 @@ with tab2:
257
  st.dataframe(df)
258
 
259
  csv_output = df.to_csv(index=False).encode("utf-8")
260
- st.download_button("📤 Download Predictions", data=csv_output, file_name="predictions.csv")
 
 
4
  import torch
5
  import os
6
  from huggingface_hub import login
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
8
  from peft import PeftModel, PeftConfig
9
  import io
10
 
 
28
 
29
  tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
30
  tokenizer.pad_token = tokenizer.eos_token
31
+ tokenizer.pad_token_id = tokenizer.eos_token_id # Explicitly set pad_token_id
32
 
33
  st.write("✅ Model and tokenizer loaded successfully.")
34
  return model, tokenizer
 
38
 
39
  model, tokenizer = load_model()
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  # Prediction function
42
  device = "cuda" if torch.cuda.is_available() else "cpu"
43
 
44
  def get_prediction(age, sex, height_cm, weight_kg, wc_cm):
45
+ # Create prompt matching test code
46
+ prompt = f"Age: {age}, Sex: {sex}, Height: {height_cm} cm, Weight: {weight_kg} kg, WC: {wc_cm} cm"
 
 
 
 
 
 
 
 
 
47
  st.write(f"Received prompt: {prompt}")
48
 
49
  # Create message structure
50
  messages = [{"role": "user", "content": prompt}]
51
+
52
  # Tokenize the input
53
  try:
54
  inputs = tokenizer.apply_chat_template(
 
57
  add_generation_prompt=True,
58
  return_tensors="pt",
59
  max_length=512,
60
+ truncation=True,
61
+ return_dict=True # Ensure dictionary output
62
  )
63
  except Exception as e:
64
  st.warning(f"apply_chat_template failed: {str(e)}. Falling back to manual tokenization.")
 
67
  return_tensors="pt",
68
  max_length=512,
69
  truncation=True,
70
+ padding=False,
71
+ return_attention_mask=True
72
  )
73
+
74
  # Debug: Log inputs structure
75
  st.write(f"Inputs type: {type(inputs)}")
76
+
77
  # Handle inputs (tensor or dict)
78
  if isinstance(inputs, torch.Tensor):
79
+ # Assume tensor is input_ids, create attention_mask
80
  input_ids = inputs
81
  if len(input_ids.shape) == 1:
82
+ input_ids = input_ids.unsqueeze(0)
83
+ attention_mask = torch.ones_like(input_ids) # Create attention_mask
 
 
 
84
  elif isinstance(inputs, dict) and 'input_ids' in inputs:
85
  input_ids = inputs['input_ids']
86
+ attention_mask = inputs.get('attention_mask', torch.ones_like(input_ids))
87
  if len(input_ids.shape) == 3 and input_ids.shape[0] == 1:
88
  input_ids = input_ids.squeeze(0)
89
+ attention_mask = attention_mask.squeeze(0)
90
  elif len(input_ids.shape) == 1:
91
  input_ids = input_ids.unsqueeze(0)
92
+ attention_mask = attention_mask.unsqueeze(0)
93
  else:
94
  st.error(f"Unexpected inputs format: {type(inputs)}")
95
  return None
96
 
97
  st.write(f"Input IDs shape: {input_ids.shape}")
98
+ st.write(f"Attention mask shape: {attention_mask.shape}")
99
 
100
+ # Move to device
101
  input_ids = input_ids.to(device)
102
+ attention_mask = attention_mask.to(device)
103
 
104
  # Generate output
105
  try:
106
+ text_streamer = TextStreamer(tokenizer)
107
  output = model.generate(
108
  input_ids=input_ids,
109
+ attention_mask=attention_mask,
110
+ max_new_tokens=250,
111
  temperature=0.7,
112
  top_p=0.95,
113
  do_sample=True,
114
+ pad_token_id=tokenizer.eos_token_id,
115
+ use_cache=True,
116
+ streamer=text_streamer
117
  )
118
  except Exception as e:
119
  st.error(f"Error during generation: {str(e)}")
 
121
 
122
  # Decode the output
123
  try:
124
+ decoded = tokenizer.decode(output[0], skip_special_tokens=False)
125
  st.write(f"Decoded output: {decoded}")
126
  return decoded
127
  except Exception as e:
 
179
  st.dataframe(df)
180
 
181
  csv_output = df.to_csv(index=False).encode("utf-8")
182
+ st.download_button("📤 Download Predictions", data=csv_output, file_name="predictions.csv")
183
+