File size: 12,165 Bytes
ce33186
 
32ecd7f
 
 
ce33186
 
 
 
 
 
 
 
2a9d839
a39e621
ce33186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c8ccaaf
ce33186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7caeae4
ce33186
7caeae4
 
 
 
 
 
ce33186
7caeae4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce33186
7caeae4
ce33186
7caeae4
 
 
 
 
 
 
 
 
 
 
ce33186
7caeae4
 
 
 
 
ce33186
 
7caeae4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f88061b
7caeae4
 
 
 
 
 
 
 
ce33186
 
7caeae4
 
 
 
 
 
 
 
ce33186
 
 
7caeae4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce33186
970f6b2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
# app.py
from langchain_groq import ChatGroq
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import pipeline
import gradio as gr
import os

# Groq-hosted chat model shared by every LLM call in this app. The API key is
# read from the "Groq_API_Key" environment variable (set as a Hugging Face
# Spaces secret); os.environ.get yields None when that variable is absent.
_groq_config = dict(
    temperature=0.7,
    groq_api_key=os.environ.get("Groq_API_Key"),
    model_name="meta-llama/llama-4-scout-17b-16e-instruct",
)
llm = ChatGroq(**_groq_config)

# Configure paths for Hugging Face Space
VECTOR_DB_PATH = "./chroma_db"  # persisted Chroma index directory
PDF_DIR = "./Pregnancy"  # PDFs indexed when no usable index exists yet

# One embeddings object serves both the build and the load path.
# all-MiniLM-L6-v2 is a plain sentence-transformers model, not a BGE model, so the
# BGE-specific query instruction that HuggingFaceBgeEmbeddings prepends to queries
# by default is switched off. Document embeddings are produced exactly as before,
# so an index persisted by earlier runs remains compatible.
embeddings = HuggingFaceBgeEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    query_instruction="",
)

# Reuse the persisted index only when the directory exists AND has content;
# an existing-but-empty directory would otherwise be loaded as an empty index
# and the retriever would never return anything.
if os.path.isdir(VECTOR_DB_PATH) and os.listdir(VECTOR_DB_PATH):
    vector_db = Chroma(persist_directory=VECTOR_DB_PATH, embedding_function=embeddings)
else:
    # Build the index: load every PDF, split into overlapping chunks, embed, persist.
    loader = DirectoryLoader(PDF_DIR, glob="*.pdf", loader_cls=PyPDFLoader)
    documents = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    texts = text_splitter.split_documents(documents)
    vector_db = Chroma.from_documents(texts, embeddings, persist_directory=VECTOR_DB_PATH)

retriever = vector_db.as_retriever()

# Local fine-tuned image-classification model; classify_food() calls it.
FOOD_MODEL_DIR = "./indian_food_finetuned_model"
food_classifier = pipeline(
    task="image-classification",
    model=FOOD_MODEL_DIR,
    device_map="auto",  # let transformers choose device placement
)

def classify_food(image, threshold=0.3, classifier=None):
    """Classify a food image, rejecting low-confidence or non-food predictions.

    Args:
        image: The image to classify (the UI passes a PIL image), or None.
        threshold: Minimum score for the top prediction to be accepted.
        classifier: Callable returning a list of ``{"label", "score"}`` dicts.
            Defaults to the module-level ``food_classifier`` pipeline.

    Returns:
        ``(label, score)`` for an accepted prediction. Otherwise ``(None, score)``
        when the top prediction is below *threshold* or labelled non-food, and
        ``(None, 0.0)`` when there is no image or no prediction.
    """
    if image is None:
        return None, 0.0
    if classifier is None:
        classifier = food_classifier
    results = classifier(image)
    if not results:
        return None, 0.0
    # Pick the best-scoring entry explicitly rather than relying on result order.
    top_result = max(results, key=lambda result: result["score"])
    label = top_result["label"]
    score = top_result["score"]
    if score < threshold or "non-food" in label.lower():
        return None, score
    return label, score

def format_history(chat_history, max_exchanges=5):
    """Render the most recent exchanges as "User: ..." / "Assistant: ..." lines.

    Args:
        chat_history: Sequence of ``(user, assistant)`` pairs; either side may be
            None (e.g. the assistant-only welcome entry) and is then omitted.
        max_exchanges: How many trailing pairs to include; 0 or less means none.

    Returns:
        The formatted lines joined by newlines, or "" when nothing is included.
    """
    # Guard explicitly: chat_history[-0:] would return the WHOLE history.
    if not chat_history or max_exchanges <= 0:
        return ""
    lines = []
    for user, assistant in chat_history[-max_exchanges:]:
        if user is not None:
            lines.append(f"User: {user}")
        if assistant is not None:
            lines.append(f"Assistant: {assistant}")
    return "\n".join(lines)

# Total pregnancy weight-gain targets keyed by pre-pregnancy BMI, as
# (exclusive upper BMI bound, min total gain in kg, max total gain in kg).
# Rows are checked in order; the final row catches every remaining BMI.
_GAIN_RANGES_BY_BMI = (
    (18.5, 12.5, 18),
    (25, 11.5, 16),
    (30, 7, 11.5),
    (float("inf"), 5, 9),
)


def _pregnancy_metrics(pre_weight, current_weight, height_m, gest_age):
    """Compare weight gained so far with the BMI-based target for *gest_age* weeks."""
    # `is None` rather than falsiness: week 0 is inside the advertised 0-40 range.
    if gest_age is None or not (0 <= gest_age <= 40):
        return "Invalid gestational age (0-40 weeks)"

    pre_bmi = pre_weight / (height_m ** 2)
    min_gain, max_gain = next(
        ((low, high) for max_bmi, low, high in _GAIN_RANGES_BY_BMI if pre_bmi < max_bmi),
        _GAIN_RANGES_BY_BMI[-1][1:],
    )

    # Total-gain targets are prorated linearly across a 40-week pregnancy.
    current_gain = current_weight - pre_weight
    expected_min = (min_gain / 40) * gest_age
    expected_max = (max_gain / 40) * gest_age

    if current_gain < expected_min:
        advice = "Consider nutritional counseling"
    elif current_gain > expected_max:
        advice = "Consult your healthcare provider"
    else:
        advice = "Good progress! Maintain balanced diet"

    return (f"Pre-BMI: {pre_bmi:.1f}\nWeek {gest_age} recommendation: "
            f"{expected_min:.1f}-{expected_max:.1f} kg\n"
            f"Your gain: {current_gain:.1f} kg\n{advice}")


def _postpartum_metrics(current_weight, height_m, time_since_delivery, breastfeeding):
    """Report current BMI and weight-loss advice for a postpartum user."""
    if None in [time_since_delivery, breastfeeding]:
        return "Missing postpartum details"
    if time_since_delivery < 0:
        return "Invalid time since delivery (must be 0 or more weeks)"

    current_bmi = current_weight / (height_m ** 2)
    if breastfeeding == "Yes":
        advice = ("Aim for 0.5-1 kg/month loss while breastfeeding\n"
                  "Focus on nutrient-dense foods")
    else:
        advice = "Gradual weight loss through diet and exercise"

    return (f"Current BMI: {current_bmi:.1f}\n"
            f"{time_since_delivery} weeks postpartum\n{advice}")


def calculate_metrics(status, pre_weight, current_weight, height,
                      gest_age=None, time_since_delivery=None, breastfeeding=None):
    """Summarise weight metrics for a pregnant or postpartum user.

    Args:
        status: "Pregnant" or "Postpartum"; anything else asks the user to choose.
        pre_weight: Pre-pregnancy weight in kg.
        current_weight: Current weight in kg.
        height: Height in cm.
        gest_age: Gestational age in weeks, 0-40 (used when status is "Pregnant").
        time_since_delivery: Weeks since delivery (used when status is "Postpartum").
        breastfeeding: "Yes" or "No" (used when status is "Postpartum").

    Returns:
        A multi-line, human-readable summary. Validation problems are reported
        as a message string rather than raised.
    """
    if None in [pre_weight, current_weight, height]:
        return "Missing required fields: weight and height"
    # Non-positive height would divide by zero in the BMI formula.
    if height <= 0 or pre_weight <= 0 or current_weight <= 0:
        return "Weight and height must be positive numbers"

    height_m = height / 100

    if status == "Pregnant":
        return _pregnancy_metrics(pre_weight, current_weight, height_m, gest_age)
    if status == "Postpartum":
        return _postpartum_metrics(current_weight, height_m, time_since_delivery, breastfeeding)
    return "Select pregnancy status"

# Phrases that route a message to the crisis-support reply. They are written
# with straight apostrophes; _mentions_any folds typographic ones to match.
_CRISIS_KEYWORDS = (
    "suicide", "self-harm", "kill myself", "cutting", "hurt myself", "end my life",
    "hopeless", "worthless", "can't go on", "panic attack", "feel like dying",
)
# Phrases that route a message to general newborn-care guidance.
_NEWBORN_KEYWORDS = ("newborn", "baby", "infant", "feeding", "sleep", "colic")

_CRISIS_RESPONSE = """
I'm really sorry you're feeling this way. You’re not alone, and help is available.
Please reach out to someone you trust or contact a helpline:

- 🇮🇳 India: Vandrevala Foundation - 1860 266 2345
- 🇺🇸 USA: National Suicide Prevention Lifeline - 988
- 🇬🇧 UK: Samaritans - 116 123
- 🌍 International: https://findahelpline.com/

If you’re in immediate danger, call emergency services (911/112).
"""


def _mentions_any(text, keywords):
    """Return True if any keyword occurs in *text* (case-insensitive, apostrophes folded)."""
    normalized = text.lower().replace("\u2019", "'").replace("\u2018", "'")
    return any(keyword in normalized for keyword in keywords)


def _build_prompt(chat_history, body):
    """Prefix *body* with the recent conversation, in the layout all prompts here use."""
    return f"""
Previous conversation:
{format_history(chat_history)}

{body}
"""


def chat_function(user_input, image, chat_history):
    """Produce the assistant's reply to one chat turn and append it to the history.

    Routing, in priority order:
      1. Text containing a crisis keyword -> fixed helpline message, even when an
         image is attached.
      2. An attached image -> classify_food(); a recognised food gets LLM advice,
         otherwise the user is asked for a clearer picture.
      3. Empty text -> a prompt to type a message or upload an image.
      4. Text containing a newborn keyword -> LLM newborn-care guidance.
      5. Anything else -> LLM answer grounded in chunks fetched by ``retriever``.

    Args:
        user_input: The typed message; may be empty or None.
        image: The uploaded image, or None.
        chat_history: List of (user, assistant) pairs; appended to in place.

    Returns:
        The history list with the new (user, assistant) pair appended
        (a new list when None was passed).
    """
    if chat_history is None:
        chat_history = []
    text = (user_input or "").strip()

    # Safety first: the crisis check must not be bypassed by attaching a photo.
    if text and _mentions_any(text, _CRISIS_KEYWORDS):
        response = _CRISIS_RESPONSE
    elif image is not None:
        food_name, _confidence = classify_food(image)
        if not food_name:
            response = "I couldn’t identify a food item in the image. Please upload a clearer picture."
        elif text:
            prompt = _build_prompt(
                chat_history,
                f"The user uploaded an image of {food_name} and asked: '{text}'.\n\n"
                "Provide a response tailored to pregnancy or postpartum needs.",
            )
            response = llm.invoke(prompt).content
        else:
            prompt = _build_prompt(
                chat_history,
                f"The user uploaded an image of {food_name}.\n\n"
                "Provide pregnancy-specific nutritional advice.",
            )
            response = llm.invoke(prompt).content
    elif not text:
        response = "Please type a message or upload an image."
    elif _mentions_any(text, _NEWBORN_KEYWORDS):
        prompt = _build_prompt(
            chat_history,
            f"The user asked: '{text}'.\n\nProvide basic guidance on newborn care.",
        )
        response = llm.invoke(prompt).content
    else:
        docs = retriever.invoke(text)
        context = "\n".join(doc.page_content for doc in docs)
        prompt = _build_prompt(
            chat_history,
            f"Context: {context}\n\nCurrent question: {text}\n\nAssistant:",
        )
        response = llm.invoke(prompt).content

    # Only label the turn "[Image Uploaded]" when an image really was uploaded.
    if user_input:
        user_label = user_input
    elif image is not None:
        user_label = "[Image Uploaded]"
    else:
        user_label = ""
    chat_history.append((user_label, response))
    return chat_history

# Custom CSS passed to gr.Blocks(css=...) for the whole page. It applies a warm
# peach/brown palette (#F1D3B2, #D9B08C, #A43820, #46211A) and the Roboto web
# font. It also rounds the buttons, textareas and inputs, and styles the small
# centred `.disclaimer` footer text.
# NOTE(review): `.chatbot .bubble`, `.number-input` and `.radio` rely on
# Gradio's internal DOM class names; confirm they match the installed Gradio
# version, otherwise those rules silently have no effect.
# NOTE(review): no gr.Button in this file passes `variant=`, so presumably only
# the `button.secondary` rules apply and `button.primary` is unused. Confirm.
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Roboto:wght@400;500&display=swap');

/* General layout */
.gradio-container {
    background: #F1D3B2; /* Light peach background */
    font-family: 'Roboto', sans-serif; /* Modern, readable font */
    padding: 20px; /* Add breathing room */
}

/* Chatbot bubble styling */
.chatbot .bubble {
    border-radius: 15px;
    padding: 10px 15px;
    margin: 8px;
    box-shadow: 0 2px 5px rgba(0,0,0,0.1); /* Subtle shadow */
}
.chatbot .bubble:nth-child(odd) {
    background: #F1D3B2; /* Light peach for assistant */
    color: #46211A; /* Dark brown text for contrast */
}
.chatbot .bubble:nth-child(even) {
    background: #D9B08C; /* Slightly darker peach for user */
    color: #46211A;
}

/* Buttons */
button {
    border-radius: 10px !important;
    padding: 10px 20px !important;
    font-size: 16px !important;
    transition: all 0.3s ease !important;
}
button.primary {
    background: #A43820 !important; /* Rusty orange for primary actions */
    color: #F1D3B2 !important; /* Light peach text */
}
button.primary:hover {
    background: #8B2E18 !important; /* Darker orange on hover */
}
button.secondary {
    background: #46211A !important; /* Dark brown for secondary actions */
    color: #F1D3B2 !important;
}
button.secondary:hover {
    background: #301510 !important; /* Darker brown on hover */
}

/* Textbox */
textarea {
    border-radius: 10px !important;
    border: 1px solid #46211A !important; /* Dark brown border */
    padding: 10px !important;
    background: #FFFFFF !important; /* White background for clarity */
    color: #46211A !important;
}

/* Input fields */
.number-input, .radio {
    background: #FFFFFF !important;
    border-radius: 10px !important;
    border: 1px solid #46211A !important;
    padding: 10px !important;
    color: #46211A !important;
}

/* Disclaimer styling */
.disclaimer {
    font-size: 12px;
    color: #46211A; /* Dark brown for readability */
    text-align: center;
}
"""

# Gradio interface
# Opening assistant message: shown when the page loads and restored by "Clear Chat".
WELCOME_MESSAGE = (
    "Welcome! I'm here to support you through pregnancy and postpartum. "
    "Ask me anything or upload a food image for nutritional advice."
)

with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# 🌸FirstSteps-Maternal Wellness Companion 🌸")
    gr.Markdown("""Welcome! I'm here to support you through pregnancy and postpartum with advice on mental health, nutrition, fitness, and newborn care. Ask me anything or upload a food image!""")

    # History entries are [user_message, assistant_message] pairs.
    chatbot = gr.Chatbot(
        height=600,
        label="Conversation",
        value=[[None, WELCOME_MESSAGE]]
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Chat with Me")
            msg = gr.Textbox(label="Your Message", placeholder="Type your question here...")
            img = gr.Image(label="Upload Food Image", type="pil")
            send_btn = gr.Button("Send")

        with gr.Column(scale=1):
            gr.Markdown("## Pregnancy Metrics")
            status = gr.Radio(["Pregnant", "Postpartum"], label="Your Status")
            pre_weight = gr.Number(label="Pre-pregnancy Weight (kg)")
            current_weight = gr.Number(label="Current Weight (kg)")
            height = gr.Number(label="Height (cm)")
            # Status-specific inputs start hidden; update_visibility reveals them.
            gest_age = gr.Number(label="Gestational Age (weeks)", visible=False)
            time_since_delivery = gr.Number(label="Time Since Delivery (weeks)", visible=False)
            breastfeeding = gr.Radio(["Yes", "No"], label="Breastfeeding?", visible=False)
            calc_btn = gr.Button("Calculate Metrics")

    with gr.Row():
        clear_btn = gr.Button("Clear Chat")

    def update_visibility(selected_status):
        """Show only the inputs relevant to the selected status.

        Returns visibility updates for (gest_age, time_since_delivery, breastfeeding).
        """
        if selected_status == "Pregnant":
            return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
        elif selected_status == "Postpartum":
            return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)

    status.change(update_visibility, inputs=status, outputs=[gest_age, time_since_delivery, breastfeeding])

    def handle_send(message, image, chat_history):
        """Answer one chat turn, then clear the message box and the image input."""
        chat_history = chat_function(message, image, chat_history)
        return "", None, chat_history

    send_btn.click(handle_send, inputs=[msg, img, chatbot], outputs=[msg, img, chatbot])
    # Pressing Enter in the message box sends as well.
    msg.submit(handle_send, inputs=[msg, img, chatbot], outputs=[msg, img, chatbot])

    def handle_calc(status, pre_weight, current_weight, height, gest_age, time_since_delivery, breastfeeding, chat_history):
        """Append the calculate_metrics() summary to the conversation."""
        metrics_response = calculate_metrics(status, pre_weight, current_weight, height, gest_age, time_since_delivery, breastfeeding)
        chat_history.append(("Pregnancy Metrics Calculation", metrics_response))
        return chat_history

    calc_btn.click(handle_calc,
                   inputs=[status, pre_weight, current_weight, height, gest_age, time_since_delivery, breastfeeding, chatbot],
                   outputs=chatbot)

    # Reset to the opening greeting (a fresh list each time) instead of a blank chat.
    clear_btn.click(lambda: [[None, WELCOME_MESSAGE]], outputs=chatbot)

    # gr.HTML does not render Markdown, so the label is bolded with <strong>.
    gr.HTML('<div class="disclaimer"><strong>Disclaimer</strong>: This app offers general guidance and is not a substitute for professional medical advice. Consult your healthcare provider for personalized recommendations.</div>')

demo.launch(debug=False, share = True)