SallySims commited on
Commit
e891368
Β·
verified Β·
1 Parent(s): 5d84d7a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -0
app.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Deploying on HuggingFace
2
+ import streamlit as st
3
+ import pandas as pd
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
6
+ from peft import PeftModel, PeftConfig
7
+ import io
8
+
9
+ st.set_page_config(page_title="AnthroBot", page_icon="πŸ€–", layout="centered")
10
+
11
+ # Load model & tokenizer
12
+ @st.cache_resource
13
+ def load_model():
14
+ peft_config = PeftConfig.from_pretrained("SallySims/AnthroBot_Model_Lora")
15
+ base_model = AutoModelForCausalLM.from_pretrained(
16
+ peft_config.base_model_name_or_path,
17
+ torch_dtype=torch.float16,
18
+ device_map="auto"
19
+ )
20
+ model = PeftModel.from_pretrained(base_model, "SallySims/AnthroBot_Model_Lora")
21
+ model.eval()
22
+
23
+ tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
24
+ tokenizer.pad_token = tokenizer.eos_token
25
+ return model, tokenizer
26
+
27
+ model, tokenizer = load_model()
28
+
29
+ # Prediction function
30
+ def get_prediction(prompt):
31
+ messages = [{"role": "user", "content": prompt}]
32
+ inputs = tokenizer.apply_chat_template(
33
+ messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
34
+ ).to("cuda")
35
+ output = model.generate(inputs, max_new_tokens=150, temperature=0.7, top_p=0.95)
36
+ decoded = tokenizer.decode(output[0], skip_special_tokens=True)
37
+ return decoded.split("###")[-1].strip()
38
+
39
+ # UI Header
40
+ st.title("🧠 AnthroBot")
41
+ st.write("Enter your anthropometric estimates to receive an interpreted summary inputs β€” manually or via CSV upload.")
42
+
43
+ # Tabs for input method
44
+ tab1, tab2 = st.tabs(["🧍 Manual Input", "πŸ“„ CSV Upload"])
45
+
46
+ with tab1:
47
+ st.subheader("Manual Entry")
48
+ age = st.number_input("Age", 0, 100, 30)
49
+ sex = st.selectbox("Sex", ["male", "female"])
50
+ height = st.number_input("Height (cm)", 100.0, 250.0, 150.5)
51
+ weight = st.number_input("Weight (kg)", 30.0, 200.0, 75.3)
52
+ wc = st.number_input("Waist Circumference (cm)", 30.0, 150.0, 68.0)
53
+
54
+ if st.button("Get Prediction"):
55
+ prompt = f"Age: {age}, Sex: {sex}, Height: {height} cm, Weight: {weight} kg, WC: {wc} cm\n\n###"
56
+ prediction = get_prediction(prompt)
57
+ st.success("Prediction:")
58
+ st.write(prediction)
59
+
60
+ with tab2:
61
+ st.subheader("Batch Upload via CSV")
62
+ sample_csv = pd.DataFrame({
63
+ "Age": [30],
64
+ "Sex": ["female"],
65
+ "Height": [150.5],
66
+ "Weight": [75.3],
67
+ "WC": [68.0]
68
+ })
69
+
70
+ st.download_button("πŸ“₯ Download Sample CSV", sample_csv.to_csv(index=False), file_name="sample_input.csv")
71
+
72
+ uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
73
+
74
+ if uploaded_file:
75
+ df = pd.read_csv(uploaded_file)
76
+ if not all(col in df.columns for col in ["Age", "Sex", "Height", "Weight", "WC"]):
77
+ st.error("CSV must contain columns: Age, Sex, Height, Weight, WC")
78
+ else:
79
+ outputs = []
80
+ with st.spinner("Generating predictions..."):
81
+ for _, row in df.iterrows():
82
+ prompt = (
83
+ f"Age: {row['Age']}, Sex: {row['Sex']}, Height: {row['Height']} cm, "
84
+ f"Weight: {row['Weight']} kg, WC: {row['WC']} cm\n\n###"
85
+ )
86
+ prediction = get_prediction(prompt)
87
+ outputs.append(prediction)
88
+
89
+ df["Prediction"] = outputs
90
+ st.success("Here are your predictions:")
91
+ st.dataframe(df)
92
+
93
+ csv_output = df.to_csv(index=False).encode("utf-8")
94
+ st.download_button("πŸ“€ Download Predictions", data=csv_output, file_name="predictions.csv")
95
+
96
+