Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
+
from langchain_groq import ChatGroq
|
| 3 |
+
from langchain.vectorstores import Chroma
|
| 4 |
+
from langchain.document_loaders import PyPDFLoader, DirectoryLoader
|
| 5 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 6 |
+
from langchain.embeddings import HuggingFaceBgeEmbeddings
|
| 7 |
+
from transformers import pipeline
|
| 8 |
+
import gradio as gr
|
| 9 |
+
import os
|
| 10 |
+
|
| 11 |
+
# Initialize Groq with environment variable
|
| 12 |
+
llm = ChatGroq(
|
| 13 |
+
temperature=0.7,
|
| 14 |
+
groq_api_key=os.environ.get("Groq_API_KEY"), # Set in HF Secrets
|
| 15 |
+
model_name="llama-3.3-70b-versatile"
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
# Configure paths for Hugging Face Space
|
| 19 |
+
VECTOR_DB_PATH = "./chroma_db"
|
| 20 |
+
PDF_DIR = "./Pregnancy"
|
| 21 |
+
MODEL_PATH = "./indian_food_model"
|
| 22 |
+
|
| 23 |
+
# Initialize or load vector database
|
| 24 |
+
if not os.path.exists(VECTOR_DB_PATH):
|
| 25 |
+
# Create new vector database
|
| 26 |
+
loader = DirectoryLoader(PDF_DIR, glob="*.pdf", loader_cls=PyPDFLoader)
|
| 27 |
+
documents = loader.load()
|
| 28 |
+
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
|
| 29 |
+
texts = text_splitter.split_documents(documents)
|
| 30 |
+
embeddings = HuggingFaceBgeEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
| 31 |
+
vector_db = Chroma.from_documents(texts, embeddings, persist_directory=VECTOR_DB_PATH)
|
| 32 |
+
else:
|
| 33 |
+
# Load existing database
|
| 34 |
+
embeddings = HuggingFaceBgeEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
| 35 |
+
vector_db = Chroma(persist_directory=VECTOR_DB_PATH, embedding_function=embeddings)
|
| 36 |
+
|
| 37 |
+
retriever = vector_db.as_retriever()
|
| 38 |
+
|
| 39 |
+
# Load food classification model
|
| 40 |
+
food_classifier = pipeline(
|
| 41 |
+
"image-classification",
|
| 42 |
+
model=MODEL_PATH,
|
| 43 |
+
device_map="auto"
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
def classify_food(image):
|
| 47 |
+
"""Classify food images with confidence thresholding"""
|
| 48 |
+
if image is None:
|
| 49 |
+
return None, 0.0
|
| 50 |
+
results = food_classifier(image)
|
| 51 |
+
if not results:
|
| 52 |
+
return None, 0.0
|
| 53 |
+
top_result = results[0]
|
| 54 |
+
label = top_result["label"]
|
| 55 |
+
score = top_result["score"]
|
| 56 |
+
if score < 0.3 or "non-food" in label.lower():
|
| 57 |
+
return None, score
|
| 58 |
+
return label, score
|
| 59 |
+
|
| 60 |
+
def format_history(chat_history, max_exchanges=5):
|
| 61 |
+
"""Format conversation history for context"""
|
| 62 |
+
recent_history = chat_history[-max_exchanges:]
|
| 63 |
+
return "\n".join(
|
| 64 |
+
f"User: {user}\nAssistant: {assistant}"
|
| 65 |
+
for user, assistant in recent_history
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
def calculate_metrics(status, pre_weight, current_weight, height,
|
| 69 |
+
gest_age=None, time_since_delivery=None, breastfeeding=None):
|
| 70 |
+
"""Calculate pregnancy/postpartum metrics"""
|
| 71 |
+
if None in [pre_weight, current_weight, height]:
|
| 72 |
+
return "Missing required fields: weight and height"
|
| 73 |
+
|
| 74 |
+
height_m = height / 100
|
| 75 |
+
pre_bmi = pre_weight / (height_m ** 2)
|
| 76 |
+
|
| 77 |
+
if status == "Pregnant":
|
| 78 |
+
if not gest_age or not (0 <= gest_age <= 40):
|
| 79 |
+
return "Invalid gestational age (0-40 weeks)"
|
| 80 |
+
|
| 81 |
+
# BMI-based recommendations
|
| 82 |
+
bmi_ranges = [
|
| 83 |
+
(18.5, 12.5, 18),
|
| 84 |
+
(25, 11.5, 16),
|
| 85 |
+
(30, 7, 11.5),
|
| 86 |
+
(float('inf'), 5, 9)
|
| 87 |
+
]
|
| 88 |
+
for max_bmi, min_gain, max_gain in bmi_ranges:
|
| 89 |
+
if pre_bmi < max_bmi:
|
| 90 |
+
break
|
| 91 |
+
|
| 92 |
+
current_gain = current_weight - pre_weight
|
| 93 |
+
expected_min = (min_gain / 40) * gest_age
|
| 94 |
+
expected_max = (max_gain / 40) * gest_age
|
| 95 |
+
|
| 96 |
+
if current_gain < expected_min:
|
| 97 |
+
advice = "Consider nutritional counseling"
|
| 98 |
+
elif current_gain > expected_max:
|
| 99 |
+
advice = "Consult your healthcare provider"
|
| 100 |
+
else:
|
| 101 |
+
advice = "Good progress! Maintain balanced diet"
|
| 102 |
+
|
| 103 |
+
return (f"Pre-BMI: {pre_bmi:.1f}\nWeek {gest_age} recommendation: "
|
| 104 |
+
f"{expected_min:.1f}-{expected_max:.1f} kg\n"
|
| 105 |
+
f"Your gain: {current_gain:.1f} kg\n{advice}")
|
| 106 |
+
|
| 107 |
+
elif status == "Postpartum":
|
| 108 |
+
if None in [time_since_delivery, breastfeeding]:
|
| 109 |
+
return "Missing postpartum details"
|
| 110 |
+
|
| 111 |
+
current_bmi = current_weight / (height_m ** 2)
|
| 112 |
+
if breastfeeding == "Yes":
|
| 113 |
+
advice = ("Aim for 0.5-1 kg/month loss while breastfeeding\n"
|
| 114 |
+
"Focus on nutrient-dense foods")
|
| 115 |
+
else:
|
| 116 |
+
advice = "Gradual weight loss through diet and exercise"
|
| 117 |
+
|
| 118 |
+
return (f"Current BMI: {current_bmi:.1f}\n"
|
| 119 |
+
f"{time_since_delivery} weeks postpartum\n{advice}")
|
| 120 |
+
|
| 121 |
+
return "Select pregnancy status"
|
| 122 |
+
|
| 123 |
+
def chat_function(user_input, image, chat_history):
|
| 124 |
+
"""Handle chat interactions"""
|
| 125 |
+
history_str = format_history(chat_history)
|
| 126 |
+
|
| 127 |
+
# Crisis detection
|
| 128 |
+
crisis_terms = {
|
| 129 |
+
"suicide", "self-harm", "kill myself", "hopeless",
|
| 130 |
+
"panic attack", "worthless", "end it all"
|
| 131 |
+
}
|
| 132 |
+
if any(term in user_input.lower() for term in crisis_terms):
|
| 133 |
+
return chat_history + [(user_input, crisis_response)]
|
| 134 |
+
|
| 135 |
+
# Process image or text input
|
| 136 |
+
if image:
|
| 137 |
+
food_label, confidence = classify_food(image)
|
| 138 |
+
if food_label:
|
| 139 |
+
context = f"User uploaded {food_label} image. {user_input}"
|
| 140 |
+
else:
|
| 141 |
+
return chat_history + [(user_input, "Couldn't identify food in image")]
|
| 142 |
+
else:
|
| 143 |
+
context = user_input
|
| 144 |
+
|
| 145 |
+
# Retrieve relevant documents
|
| 146 |
+
docs = retriever.get_relevant_documents(context)
|
| 147 |
+
context_str = "\n".join(d.page_content for d in docs[:3])
|
| 148 |
+
|
| 149 |
+
# Generate response
|
| 150 |
+
prompt = f"""Context from documents:
|
| 151 |
+
{context_str}
|
| 152 |
+
|
| 153 |
+
Recent conversation:
|
| 154 |
+
{history_str}
|
| 155 |
+
|
| 156 |
+
User: {context}
|
| 157 |
+
Assistant:"""
|
| 158 |
+
|
| 159 |
+
response = llm.invoke(prompt).content
|
| 160 |
+
return chat_history + [(user_input, response)]
|
| 161 |
+
|
| 162 |
+
# Crisis response template
|
| 163 |
+
crisis_response = """🚨 Immediate Help Resources:
|
| 164 |
+
- India: Vandrevala Foundation - 1860 266 2345
|
| 165 |
+
- US: 988 Suicide & Crisis Lifeline
|
| 166 |
+
- UK: Samaritans - 116 123
|
| 167 |
+
- Worldwide: https://findahelpline.com"""
|
| 168 |
+
|
| 169 |
+
# Gradio Interface
|
| 170 |
+
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
|
| 171 |
+
gr.Markdown("# 🤰 Maternal Wellness Companion")
|
| 172 |
+
|
| 173 |
+
with gr.Row():
|
| 174 |
+
with gr.Column(scale=2):
|
| 175 |
+
chatbot = gr.Chatbot(height=500)
|
| 176 |
+
input_txt = gr.Textbox(placeholder="Ask about pregnancy health...")
|
| 177 |
+
input_img = gr.Image(type="pil", label="Upload Food Image")
|
| 178 |
+
btn_send = gr.Button("Send", variant="primary")
|
| 179 |
+
|
| 180 |
+
with gr.Column(scale=1):
|
| 181 |
+
gr.Markdown("## Health Metrics")
|
| 182 |
+
status = gr.Radio(["Pregnant", "Postpartum"], label="Status")
|
| 183 |
+
pre_weight = gr.Number(label="Pre-pregnancy Weight (kg)")
|
| 184 |
+
current_weight = gr.Number(label="Current Weight (kg)")
|
| 185 |
+
height = gr.Number(label="Height (cm)")
|
| 186 |
+
gest_age = gr.Number(visible=False)
|
| 187 |
+
postpartum_time = gr.Number(visible=False)
|
| 188 |
+
breastfeeding = gr.Radio(["Yes", "No"], visible=False)
|
| 189 |
+
btn_calculate = gr.Button("Calculate", variant="secondary")
|
| 190 |
+
|
| 191 |
+
# Event handling
|
| 192 |
+
status.change(
|
| 193 |
+
lambda s: (
|
| 194 |
+
gr.update(visible=s=="Pregnant"),
|
| 195 |
+
gr.update(visible=s=="Postpartum"),
|
| 196 |
+
gr.update(visible=s=="Postpartum")
|
| 197 |
+
),
|
| 198 |
+
inputs=status,
|
| 199 |
+
outputs=[gest_age, postpartum_time, breastfeeding]
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
btn_send.click(
|
| 203 |
+
chat_function,
|
| 204 |
+
[input_txt, input_img, chatbot],
|
| 205 |
+
[chatbot],
|
| 206 |
+
queue=False
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
btn_calculate.click(
|
| 210 |
+
calculate_metrics,
|
| 211 |
+
[status, pre_weight, current_weight, height,
|
| 212 |
+
gest_age, postpartum_time, breastfeeding],
|
| 213 |
+
chatbot
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
demo.launch(debug=False)
|