SallySims committed on
Commit
07fb4b6
·
verified ·
1 Parent(s): 62d3ad2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -159
app.py CHANGED
@@ -15,172 +15,46 @@ login(token=os.getenv("HUGGINGFACEHUB_TOKEN"))
15
 
16
  st.set_page_config(page_title="AnthroBot", page_icon="🤖", layout="centered")
17
 
18
- # Load model & tokenizer
19
  @st.cache_resource
20
  def load_model():
21
- try:
22
- peft_config = PeftConfig.from_pretrained("SallySims/AnthroBot_Model_Lora")
23
- base_model = AutoModelForCausalLM.from_pretrained(
24
- peft_config.base_model_name_or_path,
25
- torch_dtype=torch.float16,
26
- device_map="auto"
27
- )
28
- model = PeftModel.from_pretrained(base_model, "SallySims/AnthroBot_Model_Lora")
29
- model.eval()
30
-
31
- tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
32
- tokenizer.pad_token = tokenizer.eos_token
33
- tokenizer.pad_token_id = tokenizer.eos_token_id # Explicitly set pad_token_id
34
-
35
- st.write("✅ Model and tokenizer loaded successfully.")
36
- return model, tokenizer
37
- except Exception as e:
38
- st.error(f"Error loading model: {str(e)}")
39
- raise e
40
 
41
  model, tokenizer = load_model()
42
 
43
- # Prediction function
44
- device = "cuda" if torch.cuda.is_available() else "cpu"
45
-
46
- def get_prediction(age, sex, height_cm, weight_kg, wc_cm):
47
- # Create prompt matching test code
48
- prompt = f"Age: {age}, Sex: {sex}, Height: {height_cm} cm, Weight: {weight_kg} kg, WC: {wc_cm} cm"
49
- st.write(f"Received prompt: {prompt}")
50
-
51
- # Create message structure
52
- messages = [{"role": "user", "content": prompt}]
53
-
54
- # Tokenize the input
55
- try:
56
- inputs = tokenizer.apply_chat_template(
57
- messages,
58
- tokenize=True,
59
- add_generation_prompt=True,
60
- return_tensors="pt",
61
- max_length=512,
62
- truncation=True,
63
- return_dict=True
64
- )
65
- except Exception as e:
66
- st.warning(f"apply_chat_template failed: {str(e)}. Falling back to manual tokenization.")
67
- inputs = tokenizer(
68
- prompt,
69
- return_tensors="pt",
70
- max_length=512,
71
- truncation=True,
72
- padding=False,
73
- return_attention_mask=True
74
- )
75
-
76
- # Debug: Log inputs structure
77
- st.write(f"Inputs type: {type(inputs)}")
78
- st.write(f"Inputs content: {inputs}")
79
 
80
- # Handle inputs (tensor, dict, or BatchEncoding)
81
- if isinstance(inputs, torch.Tensor):
82
- input_ids = inputs
83
- attention_mask = torch.ones_like(input_ids)
84
- if len(input_ids.shape) == 1:
85
- input_ids = input_ids.unsqueeze(0)
86
- attention_mask = attention_mask.unsqueeze(0)
87
- elif isinstance(inputs, (dict, BatchEncoding)):
88
- input_ids = inputs['input_ids']
89
- attention_mask = inputs.get('attention_mask', torch.ones_like(input_ids))
90
- if len(input_ids.shape) == 3 and input_ids.shape[0] == 1:
91
- input_ids = input_ids.squeeze(0)
92
- attention_mask = attention_mask.squeeze(0)
93
- elif len(input_ids.shape) == 1:
94
- input_ids = input_ids.unsqueeze(0)
95
- attention_mask = attention_mask.unsqueeze(0)
96
- else:
97
- st.error(f"Unexpected inputs format: {type(inputs)}")
98
- return None
99
 
100
- st.write(f"Input IDs shape: {input_ids.shape}")
101
- st.write(f"Attention mask shape: {attention_mask.shape}")
 
102
 
103
- # Move to device
104
- input_ids = input_ids.to(device)
105
- attention_mask = attention_mask.to(device)
106
 
107
- # Generate output
108
- try:
109
- text_streamer = TextStreamer(tokenizer)
110
- output = model.generate(
111
- input_ids=input_ids,
112
- attention_mask=attention_mask,
 
 
 
 
 
 
 
 
 
113
  max_new_tokens=250,
114
- temperature=0.7,
115
- top_p=0.95,
116
- do_sample=True,
117
- pad_token_id=tokenizer.eos_token_id,
118
- use_cache=True,
119
- streamer=text_streamer
120
- )
121
- except Exception as e:
122
- st.error(f"Error during generation: {str(e)}")
123
- return None
124
-
125
- # Decode the output
126
- try:
127
- decoded = tokenizer.decode(output[0], skip_special_tokens=False)
128
- st.write(f"Decoded output: {decoded}")
129
- return decoded
130
- except Exception as e:
131
- st.error(f"Error decoding output: {str(e)}")
132
- return None
133
-
134
- # UI Header
135
- st.title("🧠 AnthroBot")
136
- st.write("Enter your anthropometric estimates to receive an interpreted summary — manually or via CSV upload.")
137
-
138
- # Tabs for input method
139
- tab1, tab2 = st.tabs(["🧍 Manual Input", "📄 CSV Upload"])
140
-
141
- with tab1:
142
- st.subheader("Manual Entry")
143
- age = st.number_input("Age", 0, 100, 16)
144
- sex = st.selectbox("Sex", ["male", "female"], index=1)
145
- height = st.number_input("Height (cm)", 100.0, 250.0, 153.0)
146
- weight = st.number_input("Weight (kg)", 30.0, 200.0, 51.1)
147
- wc = st.number_input("Waist Circumference (cm)", 30.0, 150.0, 64.0)
148
-
149
- if st.button("Get Prediction"):
150
- prediction = get_prediction(age, sex, height, weight, wc)
151
- if prediction:
152
- st.success("Prediction:")
153
- st.write(prediction)
154
-
155
- with tab2:
156
- st.subheader("Batch Upload via CSV")
157
- sample_csv = pd.DataFrame({
158
- "Age": [16],
159
- "Sex": ["female"],
160
- "Height": [153.0],
161
- "Weight": [51.1],
162
- "WC": [64.0]
163
- })
164
-
165
- st.download_button("📥 Download Sample CSV", sample_csv.to_csv(index=False), file_name="sample_input.csv")
166
-
167
- uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
168
-
169
- if uploaded_file:
170
- df = pd.read_csv(uploaded_file)
171
- if not all(col in df.columns for col in ["Age", "Sex", "Height", "Weight", "WC"]):
172
- st.error("CSV must contain columns: Age, Sex, Height, Weight, WC")
173
- else:
174
- outputs = []
175
- with st.spinner("Generating predictions..."):
176
- for _, row in df.iterrows():
177
- prediction = get_prediction(row['Age'], row['Sex'], row['Height'], row['Weight'], row['WC'])
178
- outputs.append(prediction if prediction else "Error")
179
-
180
- df["Prediction"] = outputs
181
- st.success("Here are your predictions:")
182
- st.dataframe(df)
183
-
184
- csv_output = df.to_csv(index=False).encode("utf-8")
185
- st.download_button("📤 Download Predictions", data=csv_output, file_name="predictions.csv")
186
-
 
15
 
16
  st.set_page_config(page_title="AnthroBot", page_icon="🤖", layout="centered")
17
 
18
# Load model and tokenizer once per session (st.cache_resource keeps the
# heavyweight objects alive across Streamlit reruns instead of reloading
# them on every widget interaction).
@st.cache_resource
def load_model():
    """Load the AnthroBot causal LM and its tokenizer from the Hugging Face Hub.

    Returns:
        tuple: (model, tokenizer). The model is moved to the GPU when one is
        available; otherwise it stays on CPU so the app still starts.
    """
    # Hard-coding .to("cuda") crashes on CPU-only hosts — pick the device at runtime.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForCausalLM.from_pretrained("SallySims/AnthroBot_Model_Lora").to(device)
    tokenizer = AutoTokenizer.from_pretrained("SallySims/AnthroBot_Model_Lora")
    return model, tokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Materialize the cached model/tokenizer pair for this session.
model, tokenizer = load_model()

# Page header shown above the input form.
st.title("🧠 Health Metric Estimator")
st.markdown("Enter your details below to get an AI-generated estimation.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
# Input fields — min/max bounds guard against nonsensical anthropometric
# values; the defaults describe a plausible adult profile.
age = st.number_input("Age", min_value=1, max_value=120, value=30)
sex = st.selectbox("Sex", options=["male", "female"])
height = st.number_input("Height (cm)", min_value=50.0, max_value=250.0, value=170.0)
weight = st.number_input("Weight (kg)", min_value=10.0, max_value=300.0, value=70.0)
wc = st.number_input("Waist Circumference (cm)", min_value=20.0, max_value=200.0, value=80.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ if st.button("Estimate Metrics"):
38
+ prompt = f"Age: {age}, Sex: {sex}, Height: {height} cm, Weight: {weight} kg, WC: {wc} cm"
39
+ st.write(f"📝 Prompt Sent to Model: `{prompt}`")
40
 
41
+ messages = [{"role": "user", "content": prompt}]
 
 
42
 
43
+ # Tokenize
44
+ inputs = tokenizer.apply_chat_template(
45
+ messages,
46
+ tokenize=True,
47
+ add_generation_prompt=True,
48
+ return_tensors="pt"
49
+ ).to("cuda")
50
+
51
+ # Generate response with streaming
52
+ st.write("🤖 Model response:")
53
+ with st.empty():
54
+ text_streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
55
+ _ = model.generate(
56
+ inputs,
57
+ streamer=text_streamer,
58
  max_new_tokens=250,
59
+ use_cache=True
60
+ )