SallySims committed on
Commit
760ef09
Β·
verified Β·
1 Parent(s): 22ecc08

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +194 -41
app.py CHANGED
@@ -11,52 +11,205 @@ import io
11
  from transformers.tokenization_utils_base import BatchEncoding
12
 
13
  # Login using Hugging Face token
14
- login(token=os.getenv("HUGGINGFACEHUB_TOKEN"))
 
 
 
 
15
 
16
  st.set_page_config(page_title="AnthroBot", page_icon="πŸ€–", layout="centered")
17
 
18
- device = "cuda" if torch.cuda.is_available() else "cpu"
19
-
20
- # Load model and tokenizer
21
  @st.cache_resource
22
  def load_model():
23
- model = AutoModelForCausalLM.from_pretrained("SallySims/AnthroBot_Model_Lora").to(device)
24
- tokenizer = AutoTokenizer.from_pretrained("SallySims/AnthroBot_Model_Lora")
25
- return model, tokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  model, tokenizer = load_model()
28
 
29
- st.title("🧠 Health Metric Estimator")
30
- st.markdown("Enter your details below to get an AI-generated estimation.")
31
-
32
- # Input fields
33
- age = st.number_input("Age", min_value=1, max_value=120, value=30)
34
- sex = st.selectbox("Sex", options=["male", "female"])
35
- height = st.number_input("Height (cm)", min_value=50.0, max_value=250.0, value=170.0)
36
- weight = st.number_input("Weight (kg)", min_value=10.0, max_value=300.0, value=70.0)
37
- wc = st.number_input("Waist Circumference (cm)", min_value=20.0, max_value=200.0, value=80.0)
38
-
39
- if st.button("Estimate Metrics"):
40
- prompt = f"Age: {age}, Sex: {sex}, Height: {height} cm, Weight: {weight} kg, WC: {wc} cm"
41
- st.write(f"πŸ“ Prompt Sent to Model: `{prompt}`")
42
-
43
- messages = [{"role": "user", "content": prompt}]
44
-
45
- # Tokenize
46
- inputs = tokenizer.apply_chat_template(
47
- messages,
48
- tokenize=True,
49
- add_generation_prompt=True,
50
- return_tensors="pt"
51
- ).to("cuda")
52
-
53
- # Generate response with streaming
54
- st.write("πŸ€– Model response:")
55
- with st.empty():
56
- text_streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
57
- _ = model.generate(
58
- inputs,
59
- streamer=text_streamer,
60
- max_new_tokens=250,
61
- use_cache=True
62
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  from transformers.tokenization_utils_base import BatchEncoding
12
 
13
# Configure the page FIRST: Streamlit requires st.set_page_config to be the
# first st.* command in the script, so it must precede any st.error/st.stop
# emitted by the login error handling below.
st.set_page_config(page_title="AnthroBot", page_icon="πŸ€–", layout="centered")

# Login using Hugging Face token
hf_token = os.getenv("HUGGINGFACEHUB_TOKEN")
if not hf_token:
    # Fail fast with a clear message instead of passing None to login().
    st.error("HUGGINGFACEHUB_TOKEN environment variable is not set.")
    st.stop()
try:
    login(token=hf_token)
except Exception as e:
    st.error(f"Error logging in to Hugging Face: {str(e)}")
    st.stop()
21
 
22
# Load model & tokenizer
@st.cache_resource
def load_model():
    """Load the LoRA adapter on top of its base model, plus the tokenizer.

    The base checkpoint name is resolved from the adapter's PEFT config.
    Cached by Streamlit so the expensive download/load runs once per server
    process. Returns (model, tokenizer); surfaces errors in the UI and
    re-raises on failure.
    """
    try:
        # Resolve which base model the LoRA adapter was trained against.
        peft_config = PeftConfig.from_pretrained("SallySims/AnthroBot_Model_Lora")
        base_model = AutoModelForCausalLM.from_pretrained(
            peft_config.base_model_name_or_path,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
            token=True,  # reuse the token stored by login() above
        )
        model = PeftModel.from_pretrained(base_model, "SallySims/AnthroBot_Model_Lora")
        model.eval()

        tokenizer = AutoTokenizer.from_pretrained(
            peft_config.base_model_name_or_path,
            trust_remote_code=True,
            token=True,
        )
        # The base model defines no pad token; reuse EOS so generation can pad.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

        # NOTE: visible only on the first (uncached) call of the session.
        st.write("βœ… Model and tokenizer loaded successfully.")
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        # Bare raise preserves the original traceback (`raise e` resets it).
        raise

model, tokenizer = load_model()
52
 
53
# Per-session history of (prompt, response) pairs shown under "Prediction History".
if "history" not in st.session_state:
    st.session_state["history"] = []

# Device used when placing generation inputs.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
59
+
60
def generate_response(age, sex, height_cm, weight_kg, wc_cm):
    """Build a prompt from the anthropometric inputs and stream the model reply.

    Args:
        age, sex, height_cm, weight_kg, wc_cm: values interpolated verbatim
            into the prompt string sent to the model.

    Returns:
        The decoded model output (str) on success, or None when tokenization
        or generation fails (the error is shown in the UI instead).

    Side effects: writes progress/debug lines to the Streamlit page and
    appends (prompt, decoded) to st.session_state.history on success.
    """
    try:
        # Create prompt
        prompt = f"Age: {age}, Sex: {sex}, Height: {height_cm} cm, Weight: {weight_kg} kg, WC: {wc_cm} cm"
        st.write(f"πŸ“ Prompt Sent to Model: `{prompt}`")

        # Create message structure
        messages = [{"role": "user", "content": prompt}]

        # Tokenize, preferring the model's chat template; fall back to plain
        # tokenization if the tokenizer has no usable template.
        try:
            inputs = tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
                max_length=512,
                truncation=True,
                return_dict=True,
            )
        except Exception as e:
            st.warning(f"apply_chat_template failed: {str(e)}. Falling back to manual tokenization.")
            inputs = tokenizer(
                prompt,
                return_tensors="pt",
                max_length=512,
                truncation=True,
                padding=False,
                return_attention_mask=True,
            )

        # Debug: Log inputs structure
        st.write(f"Inputs type: {type(inputs)}")
        st.write(f"Inputs keys: {list(inputs.keys()) if isinstance(inputs, (dict, BatchEncoding)) else 'N/A'}")

        # Normalize the tokenizer output to (input_ids, attention_mask).
        if isinstance(inputs, (dict, BatchEncoding)):
            input_ids = inputs['input_ids']
            attention_mask = inputs.get('attention_mask', torch.ones_like(input_ids))
        elif isinstance(inputs, torch.Tensor):
            input_ids = inputs
            attention_mask = torch.ones_like(input_ids)
        else:
            st.error(f"Unexpected inputs format: {type(inputs)}")
            return None

        # Force both tensors to exactly 2 dims (batch, seq): squeeze away any
        # extra singleton dims, then restore the batch dim if needed.
        input_ids = input_ids.squeeze()
        attention_mask = attention_mask.squeeze()
        if input_ids.dim() < 2:
            input_ids = input_ids.reshape(1, -1)
            attention_mask = attention_mask.reshape(1, -1)

        st.write(f"Input IDs shape: {input_ids.shape}")
        st.write(f"Attention mask shape: {attention_mask.shape}")

        # Move inputs to the model's own device: the model was loaded with
        # device_map="auto", so model.device (not the module-level `device`)
        # is the reliable placement target for its first layer.
        target_device = getattr(model, "device", device)
        input_ids = input_ids.to(target_device)
        attention_mask = attention_mask.to(target_device)

        # Generate output, streaming tokens into the placeholder as produced.
        st.write("πŸ€– Model response:")
        with st.empty():
            text_streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
            output = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=250,
                temperature=0.7,
                top_p=0.95,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                use_cache=True,
                streamer=text_streamer,
            )

        # Decode the output (special tokens kept visible for debugging).
        decoded = tokenizer.decode(output[0], skip_special_tokens=False)
        st.write(f"Decoded output: {decoded}")

        # Update history
        st.session_state.history.append((prompt, decoded))
        return decoded

    except Exception as e:
        st.error(f"Error during generation: {str(e)}")
        return None
151
+
152
# UI Header
st.title("🧠 AnthroBot")
st.markdown("Enter your anthropometric details to receive an AI-generated summary of health metrics.")

# Tabs for input method
tab1, tab2 = st.tabs(["🧍 Manual Input", "πŸ“„ CSV Upload"])

# Columns every uploaded CSV must provide, in prompt order.
REQUIRED_COLUMNS = ["Age", "Sex", "Height", "Weight", "WC"]

with tab1:
    st.subheader("Manual Entry")
    age = st.number_input("Age", min_value=1, max_value=120, value=30)
    sex = st.selectbox("Sex", options=["male", "female"])
    height = st.number_input("Height (cm)", min_value=50.0, max_value=250.0, value=170.0)
    weight = st.number_input("Weight (kg)", min_value=10.0, max_value=300.0, value=70.0)
    wc = st.number_input("Waist Circumference (cm)", min_value=20.0, max_value=200.0, value=80.0)

    if st.button("Estimate Metrics"):
        prediction = generate_response(age, sex, height, weight, wc)
        if prediction:
            st.success("Prediction:")
            st.write(prediction)

    # Display history
    st.subheader("Prediction History")
    for prompt, response in st.session_state.history:
        st.markdown(f"**Input**: {prompt}")
        st.markdown(f"**Output**: {response}")

with tab2:
    st.subheader("Batch Upload via CSV")
    sample_csv = pd.DataFrame({
        "Age": [30],
        "Sex": ["male"],
        "Height": [170.0],
        "Weight": [70.0],
        "WC": [80.0],
    })

    st.download_button("πŸ“₯ Download Sample CSV", sample_csv.to_csv(index=False), file_name="sample_input.csv")

    uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])

    if uploaded_file:
        # Guard against malformed uploads: pd.read_csv raises on bad CSV,
        # which would otherwise crash the whole script run.
        try:
            df = pd.read_csv(uploaded_file)
        except Exception as e:
            st.error(f"Could not parse the uploaded file as CSV: {str(e)}")
            df = None

        if df is not None:
            if not all(col in df.columns for col in REQUIRED_COLUMNS):
                st.error("CSV must contain columns: Age, Sex, Height, Weight, WC")
            else:
                outputs = []
                with st.spinner("Generating predictions..."):
                    for _, row in df.iterrows():
                        prediction = generate_response(row['Age'], row['Sex'], row['Height'], row['Weight'], row['WC'])
                        # Keep one output per row so the column aligns with df.
                        outputs.append(prediction if prediction else "Error")

                df["Prediction"] = outputs
                st.success("Here are your predictions:")
                st.dataframe(df)

                csv_output = df.to_csv(index=False).encode("utf-8")
                st.download_button("πŸ“€ Download Predictions", data=csv_output, file_name="predictions.csv")

# Clear history button
if st.button("Clear History"):
    st.session_state.history = []
    st.rerun()