safiaa02 commited on
Commit
0b8937c
·
verified ·
1 Parent(s): 65f404d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -21
app.py CHANGED
@@ -4,7 +4,7 @@ import streamlit as st
4
  import faiss
5
  import numpy as np
6
  import torch
7
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
8
  from sentence_transformers import SentenceTransformer
9
  from reportlab.lib.pagesizes import A4
10
  from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer
@@ -45,12 +45,20 @@ def retrieve_milestone(user_input):
45
  _, indices = index.search(user_embedding, 1)
46
  return descriptions[indices[0][0]] if indices[0][0] < len(descriptions) else "No relevant milestone found."
47
 
48
- # Load IBM Granite 3.1 model and tokenizer
49
- model_name = "ibm-granite/granite-3.1-8b-instruct"
50
- tokenizer = AutoTokenizer.from_pretrained(model_name)
51
- granite_model = AutoModelForSeq2SeqLM.from_pretrained(
52
- model_name, torch_dtype=torch.float16, device_map="auto"
53
- )
 
 
 
 
 
 
 
 
54
 
55
  def generate_response(user_input, child_age):
56
  relevant_milestone = retrieve_milestone(user_input)
@@ -61,31 +69,22 @@ def generate_response(user_input, child_age):
61
  "If there are any concerns, suggest steps the parents can take."
62
  )
63
 
64
- inputs = tokenizer(prompt, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
65
  output = granite_model.generate(**inputs, max_length=512)
66
  return tokenizer.decode(output[0], skip_special_tokens=True)
67
 
68
  # Streamlit UI Styling
69
  st.set_page_config(page_title="Tiny Triumphs Tracker", page_icon="👶", layout="wide")
70
 
71
- st.markdown("""
72
- <style>
73
- .stApp { background-color: #1e1e2e; color: #ffffff; }
74
- .stTitle { text-align: center; color: #ffcc00; font-size: 36px; font-weight: bold; }
75
- .stButton > button { background-color: #ffcc00; color: #000; border-radius: 5px; font-weight: bold; }
76
- .stSelectbox, .stTextArea { background-color: #2e2e42; color: #ffffff; border-radius: 5px; }
77
- </style>
78
- """, unsafe_allow_html=True)
79
-
80
- st.markdown("<h1 class='stTitle'>👶 Tiny Triumphs Tracker</h1>", unsafe_allow_html=True)
81
- st.markdown("Track your child's key growth milestones from birth to 5 years and detect early developmental concerns.", unsafe_allow_html=True)
82
 
83
  # User selects child's age
84
  selected_age = st.selectbox("📅 Select child's age:", list(age_categories.keys()))
85
  child_age = age_categories[selected_age]
86
 
87
  # User input for traits and skills
88
- placeholder_text = "For example, your child might say simple words like 'mama' and 'dada' and smile when spoken to. They may grasp small objects with their fingers and show excitement during playtime."
89
  user_input = st.text_area("✍️ Enter child's behavioral traits and skills:", placeholder=placeholder_text)
90
 
91
  def generate_pdf_report(ai_response):
@@ -125,4 +124,4 @@ if st.button("🔍 Analyze", help="Click to analyze the child's development mile
125
  with open(pdf_file, "rb") as f:
126
  st.download_button(label="📥 Download Progress Report", data=f, file_name="progress_report.pdf", mime="application/pdf")
127
 
128
- st.warning("⚠️ The results provided are generated by AI and should be interpreted with caution. Please consult a pediatrician for professional advice.")
 
4
  import faiss
5
  import numpy as np
6
  import torch
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
  from sentence_transformers import SentenceTransformer
9
  from reportlab.lib.pagesizes import A4
10
  from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer
 
45
  _, indices = index.search(user_embedding, 1)
46
  return descriptions[indices[0][0]] if indices[0][0] < len(descriptions) else "No relevant milestone found."
47
 
48
+ # Load IBM Granite 3.1 model and tokenizer from Hugging Face
49
+ MODEL_NAME = "ibm-granite/granite-3.1-8b-instruct"
50
+
51
+ @st.cache_resource # Cache model to avoid reloading on every interaction
52
+ def load_model():
53
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
54
+ model = AutoModelForCausalLM.from_pretrained(
55
+ MODEL_NAME,
56
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
57
+ device_map="auto" # Auto-select GPU/CPU
58
+ )
59
+ return tokenizer, model
60
+
61
+ tokenizer, granite_model = load_model()
62
 
63
  def generate_response(user_input, child_age):
64
  relevant_milestone = retrieve_milestone(user_input)
 
69
  "If there are any concerns, suggest steps the parents can take."
70
  )
71
 
72
+ inputs = tokenizer(prompt, return_tensors="pt").to(granite_model.device)
73
  output = granite_model.generate(**inputs, max_length=512)
74
  return tokenizer.decode(output[0], skip_special_tokens=True)
75
 
76
  # Streamlit UI Styling
77
  st.set_page_config(page_title="Tiny Triumphs Tracker", page_icon="👶", layout="wide")
78
 
79
+ st.markdown("<h1 style='text-align:center; color:#ffcc00;'>👶 Tiny Triumphs Tracker</h1>", unsafe_allow_html=True)
80
+ st.markdown("Track your child's key growth milestones from birth to 5 years and detect early developmental concerns.")
 
 
 
 
 
 
 
 
 
81
 
82
  # User selects child's age
83
  selected_age = st.selectbox("📅 Select child's age:", list(age_categories.keys()))
84
  child_age = age_categories[selected_age]
85
 
86
  # User input for traits and skills
87
+ placeholder_text = "For example, your child might say simple words like 'mama' and 'dada' and smile when spoken to."
88
  user_input = st.text_area("✍️ Enter child's behavioral traits and skills:", placeholder=placeholder_text)
89
 
90
  def generate_pdf_report(ai_response):
 
124
  with open(pdf_file, "rb") as f:
125
  st.download_button(label="📥 Download Progress Report", data=f, file_name="progress_report.pdf", mime="application/pdf")
126
 
127
+ st.warning("⚠️ The results provided are generated by AI and should be interpreted with caution. Please consult a pediatrician for professional advice.")