Update app.py
Browse files
app.py
CHANGED
|
@@ -33,15 +33,18 @@ def load_model():
|
|
| 33 |
model, tokenizer = load_model()


# Prediction function
def get_prediction(prompt):
    """Run the chat model on *prompt* and return the interpreted summary.

    Builds a single-turn chat message, generates up to 150 new tokens with
    sampling, and returns the text after the last '###' delimiter (the
    model's answer section, per the prompt template used at fine-tuning).
    """
    messages = [{"role": "user", "content": prompt}]
    inputs = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)  # fix: source line was truncated at ').to(' — inputs must live on the model's device
    # do_sample=True is required for temperature/top_p to take effect;
    # without it transformers uses greedy decoding and ignores both.
    output = model.generate(
        inputs, max_new_tokens=150, temperature=0.7, top_p=0.95, do_sample=True
    )
    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    # The prompt template delimits the answer with '###'; keep only the final section.
    return decoded.split("###")[-1].strip()


# UI Header
st.title("🧠 AnthroBot")
# fix: original string was garbled ("interpreted summary inputs — manually or via CSV upload")
st.write("Enter your anthropometric inputs to receive an interpreted summary — manually or via CSV upload.")
|
|
|
| 33 |
model, tokenizer = load_model()

# Prediction function
# Pick the inference device once at module load; presumably matches where
# load_model() placed the weights — TODO confirm against load_model().
device = "cuda" if torch.cuda.is_available() else "cpu"


def get_prediction(prompt):
    """Run the chat model on *prompt* and return the interpreted summary.

    Builds a single-turn chat message, generates up to 150 new tokens with
    sampling, and returns the text after the last '###' delimiter (the
    model's answer section, per the prompt template used at fine-tuning).
    """
    messages = [{"role": "user", "content": prompt}]
    inputs = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to(device)
    # do_sample=True is required for temperature/top_p to take effect;
    # without it transformers uses greedy decoding and ignores both.
    output = model.generate(
        inputs, max_new_tokens=150, temperature=0.7, top_p=0.95, do_sample=True
    )
    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    # The prompt template delimits the answer with '###'; keep only the final section.
    return decoded.split("###")[-1].strip()


# UI Header
st.title("🧠 AnthroBot")
# fix: original string was garbled ("interpreted summary inputs — manually or via CSV upload")
st.write("Enter your anthropometric inputs to receive an interpreted summary — manually or via CSV upload.")
|