Update app.py
Browse files
app.py
CHANGED
|
@@ -4,7 +4,7 @@ import streamlit as st
|
|
| 4 |
import faiss
|
| 5 |
import numpy as np
|
| 6 |
import torch
|
| 7 |
-
from transformers import
|
| 8 |
from sentence_transformers import SentenceTransformer
|
| 9 |
from reportlab.lib.pagesizes import A4
|
| 10 |
from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer
|
|
@@ -45,12 +45,20 @@ def retrieve_milestone(user_input):
|
|
| 45 |
_, indices = index.search(user_embedding, 1)
|
| 46 |
return descriptions[indices[0][0]] if indices[0][0] < len(descriptions) else "No relevant milestone found."
|
| 47 |
|
| 48 |
-
# Load IBM Granite 3.1 model and tokenizer
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
def generate_response(user_input, child_age):
|
| 56 |
relevant_milestone = retrieve_milestone(user_input)
|
|
@@ -61,31 +69,22 @@ def generate_response(user_input, child_age):
|
|
| 61 |
"If there are any concerns, suggest steps the parents can take."
|
| 62 |
)
|
| 63 |
|
| 64 |
-
inputs = tokenizer(prompt, return_tensors="pt").to(
|
| 65 |
output = granite_model.generate(**inputs, max_length=512)
|
| 66 |
return tokenizer.decode(output[0], skip_special_tokens=True)
|
| 67 |
|
| 68 |
# Streamlit UI Styling
|
| 69 |
st.set_page_config(page_title="Tiny Triumphs Tracker", page_icon="👶", layout="wide")
|
| 70 |
|
| 71 |
-
st.markdown(""
|
| 72 |
-
|
| 73 |
-
.stApp { background-color: #1e1e2e; color: #ffffff; }
|
| 74 |
-
.stTitle { text-align: center; color: #ffcc00; font-size: 36px; font-weight: bold; }
|
| 75 |
-
.stButton > button { background-color: #ffcc00; color: #000; border-radius: 5px; font-weight: bold; }
|
| 76 |
-
.stSelectbox, .stTextArea { background-color: #2e2e42; color: #ffffff; border-radius: 5px; }
|
| 77 |
-
</style>
|
| 78 |
-
""", unsafe_allow_html=True)
|
| 79 |
-
|
| 80 |
-
st.markdown("<h1 class='stTitle'>👶 Tiny Triumphs Tracker</h1>", unsafe_allow_html=True)
|
| 81 |
-
st.markdown("Track your child's key growth milestones from birth to 5 years and detect early developmental concerns.", unsafe_allow_html=True)
|
| 82 |
|
| 83 |
# User selects child's age
|
| 84 |
selected_age = st.selectbox("📅 Select child's age:", list(age_categories.keys()))
|
| 85 |
child_age = age_categories[selected_age]
|
| 86 |
|
| 87 |
# User input for traits and skills
|
| 88 |
-
placeholder_text = "For example, your child might say simple words like 'mama' and 'dada' and smile when spoken to.
|
| 89 |
user_input = st.text_area("✍️ Enter child's behavioral traits and skills:", placeholder=placeholder_text)
|
| 90 |
|
| 91 |
def generate_pdf_report(ai_response):
|
|
@@ -125,4 +124,4 @@ if st.button("🔍 Analyze", help="Click to analyze the child's development mile
|
|
| 125 |
with open(pdf_file, "rb") as f:
|
| 126 |
st.download_button(label="📥 Download Progress Report", data=f, file_name="progress_report.pdf", mime="application/pdf")
|
| 127 |
|
| 128 |
-
st.warning("⚠️ The results provided are generated by AI and should be interpreted with caution. Please consult a pediatrician for professional advice.")
|
|
|
|
| 4 |
import faiss
|
| 5 |
import numpy as np
|
| 6 |
import torch
|
| 7 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 8 |
from sentence_transformers import SentenceTransformer
|
| 9 |
from reportlab.lib.pagesizes import A4
|
| 10 |
from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer
|
|
|
|
| 45 |
_, indices = index.search(user_embedding, 1)
|
| 46 |
return descriptions[indices[0][0]] if indices[0][0] < len(descriptions) else "No relevant milestone found."
|
| 47 |
|
| 48 |
+
# Load IBM Granite 3.1 model and tokenizer from Hugging Face
|
| 49 |
+
MODEL_NAME = "ibm-granite/granite-3.1-8b-instruct"
|
| 50 |
+
|
| 51 |
+
@st.cache_resource # Cache model to avoid reloading on every interaction
|
| 52 |
+
def load_model():
|
| 53 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 54 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 55 |
+
MODEL_NAME,
|
| 56 |
+
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
| 57 |
+
device_map="auto" # Auto-select GPU/CPU
|
| 58 |
+
)
|
| 59 |
+
return tokenizer, model
|
| 60 |
+
|
| 61 |
+
tokenizer, granite_model = load_model()
|
| 62 |
|
| 63 |
def generate_response(user_input, child_age):
|
| 64 |
relevant_milestone = retrieve_milestone(user_input)
|
|
|
|
| 69 |
"If there are any concerns, suggest steps the parents can take."
|
| 70 |
)
|
| 71 |
|
| 72 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(granite_model.device)
|
| 73 |
output = granite_model.generate(**inputs, max_length=512)
|
| 74 |
return tokenizer.decode(output[0], skip_special_tokens=True)
|
| 75 |
|
| 76 |
# Streamlit UI Styling
|
| 77 |
st.set_page_config(page_title="Tiny Triumphs Tracker", page_icon="👶", layout="wide")
|
| 78 |
|
| 79 |
+
st.markdown("<h1 style='text-align:center; color:#ffcc00;'>👶 Tiny Triumphs Tracker</h1>", unsafe_allow_html=True)
|
| 80 |
+
st.markdown("Track your child's key growth milestones from birth to 5 years and detect early developmental concerns.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
# User selects child's age
|
| 83 |
selected_age = st.selectbox("📅 Select child's age:", list(age_categories.keys()))
|
| 84 |
child_age = age_categories[selected_age]
|
| 85 |
|
| 86 |
# User input for traits and skills
|
| 87 |
+
placeholder_text = "For example, your child might say simple words like 'mama' and 'dada' and smile when spoken to."
|
| 88 |
user_input = st.text_area("✍️ Enter child's behavioral traits and skills:", placeholder=placeholder_text)
|
| 89 |
|
| 90 |
def generate_pdf_report(ai_response):
|
|
|
|
| 124 |
with open(pdf_file, "rb") as f:
|
| 125 |
st.download_button(label="📥 Download Progress Report", data=f, file_name="progress_report.pdf", mime="application/pdf")
|
| 126 |
|
| 127 |
+
st.warning("⚠️ The results provided are generated by AI and should be interpreted with caution. Please consult a pediatrician for professional advice.")
|