Siddu2004-2006 committed on
Commit
ce33186
·
verified ·
1 Parent(s): 8023316

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +216 -0
app.py CHANGED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ from langchain_groq import ChatGroq
3
+ from langchain.vectorstores import Chroma
4
+ from langchain.document_loaders import PyPDFLoader, DirectoryLoader
5
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
6
+ from langchain.embeddings import HuggingFaceBgeEmbeddings
7
+ from transformers import pipeline
8
+ import gradio as gr
9
+ import os
10
+
11
# Groq-hosted LLM client. The API key is read from the "Groq_API_KEY"
# environment variable (configured as a Hugging Face Space secret).
llm = ChatGroq(
    temperature=0.7,
    groq_api_key=os.environ.get("Groq_API_KEY"),  # Set in HF Secrets
    model_name="llama-3.3-70b-versatile"
)

# Filesystem layout inside the Hugging Face Space.
VECTOR_DB_PATH = "./chroma_db"       # persisted Chroma vector index
PDF_DIR = "./Pregnancy"              # source PDFs for the knowledge base
MODEL_PATH = "./indian_food_model"   # local food image-classification model
22
+
23
# Build the Chroma index from the PDFs on first launch; reuse the persisted
# copy afterwards. The same embedding model is needed in both cases, so it
# is constructed once up front.
embeddings = HuggingFaceBgeEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)
if os.path.exists(VECTOR_DB_PATH):
    # Reopen the previously persisted index.
    vector_db = Chroma(persist_directory=VECTOR_DB_PATH, embedding_function=embeddings)
else:
    # First launch: load every PDF, chunk it, embed, and persist to disk.
    docs = DirectoryLoader(PDF_DIR, glob="*.pdf", loader_cls=PyPDFLoader).load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    chunks = splitter.split_documents(docs)
    vector_db = Chroma.from_documents(chunks, embeddings, persist_directory=VECTOR_DB_PATH)

retriever = vector_db.as_retriever()
38
+
39
# Image classifier for uploaded food photos; device placement (CPU/GPU) is
# chosen automatically by accelerate via device_map="auto".
food_classifier = pipeline(
    task="image-classification",
    model=MODEL_PATH,
    device_map="auto",
)
45
+
46
def classify_food(image):
    """Classify a food image and return ``(label, score)``.

    Returns ``(None, score)`` when the image is missing, the classifier
    yields no predictions, the top score is under the 0.3 confidence
    threshold, or the top label is a non-food class.
    """
    if image is None:
        return None, 0.0
    predictions = food_classifier(image)
    if not predictions:
        return None, 0.0
    best = predictions[0]
    label, score = best["label"], best["score"]
    too_uncertain = score < 0.3
    looks_non_food = "non-food" in label.lower()
    if too_uncertain or looks_non_food:
        return None, score
    return label, score
59
+
60
def format_history(chat_history, max_exchanges=5):
    """Render the last *max_exchanges* (user, assistant) pairs as plain text,
    one "User: .../Assistant: ..." pair per exchange, for prompt context."""
    rendered = []
    for user_msg, bot_msg in chat_history[-max_exchanges:]:
        rendered.append(f"User: {user_msg}\nAssistant: {bot_msg}")
    return "\n".join(rendered)
67
+
68
def calculate_metrics(status, pre_weight, current_weight, height,
                      gest_age=None, time_since_delivery=None, breastfeeding=None):
    """Compute BMI-based weight guidance for pregnancy or postpartum.

    Args:
        status: "Pregnant" or "Postpartum"; anything else yields a prompt
            to select a status.
        pre_weight: pre-pregnancy weight in kg.
        current_weight: current weight in kg.
        height: height in cm.
        gest_age: gestational age in weeks (0-40), required when Pregnant.
        time_since_delivery: weeks postpartum, required when Postpartum.
        breastfeeding: "Yes"/"No", required when Postpartum.

    Returns:
        A human-readable summary string; validation failures are also
        reported as strings (this feeds a chat UI, so nothing raises).
    """
    if None in [pre_weight, current_weight, height]:
        return "Missing required fields: weight and height"
    # BUG FIX: height == 0 previously raised ZeroDivisionError below;
    # report it like every other validation failure instead.
    if height <= 0:
        return "Invalid height"

    height_m = height / 100
    pre_bmi = pre_weight / (height_m ** 2)

    if status == "Pregnant":
        # BUG FIX: `not gest_age` also rejected week 0, contradicting the
        # advertised 0-40 range; test for None explicitly instead.
        if gest_age is None or not (0 <= gest_age <= 40):
            return "Invalid gestational age (0-40 weeks)"

        # Total-gain ranges (kg) keyed by pre-pregnancy BMI category:
        # (exclusive upper BMI bound, min gain, max gain).
        bmi_ranges = [
            (18.5, 12.5, 18),
            (25, 11.5, 16),
            (30, 7, 11.5),
            (float('inf'), 5, 9)
        ]
        for max_bmi, min_gain, max_gain in bmi_ranges:
            if pre_bmi < max_bmi:
                break

        current_gain = current_weight - pre_weight
        # Pro-rate the full-term (40-week) range by gestational week.
        expected_min = (min_gain / 40) * gest_age
        expected_max = (max_gain / 40) * gest_age

        if current_gain < expected_min:
            advice = "Consider nutritional counseling"
        elif current_gain > expected_max:
            advice = "Consult your healthcare provider"
        else:
            advice = "Good progress! Maintain balanced diet"

        return (f"Pre-BMI: {pre_bmi:.1f}\nWeek {gest_age} recommendation: "
                f"{expected_min:.1f}-{expected_max:.1f} kg\n"
                f"Your gain: {current_gain:.1f} kg\n{advice}")

    elif status == "Postpartum":
        if None in [time_since_delivery, breastfeeding]:
            return "Missing postpartum details"

        current_bmi = current_weight / (height_m ** 2)
        if breastfeeding == "Yes":
            advice = ("Aim for 0.5-1 kg/month loss while breastfeeding\n"
                      "Focus on nutrient-dense foods")
        else:
            advice = "Gradual weight loss through diet and exercise"

        return (f"Current BMI: {current_bmi:.1f}\n"
                f"{time_since_delivery} weeks postpartum\n{advice}")

    return "Select pregnancy status"
122
+
123
def chat_function(user_input, image, chat_history):
    """Handle one chat turn: crisis screening, optional food-image context,
    retrieval from the vector store, and LLM response generation.

    Args:
        user_input: the user's message text.
        image: optional uploaded image (PIL image from the UI) or None.
        chat_history: list of (user, assistant) tuples.

    Returns:
        The chat history extended with one new (user, assistant) tuple.
    """
    history_str = format_history(chat_history)

    # Crisis screening: short-circuit to hotline resources before any
    # retrieval or LLM call.
    crisis_terms = {
        "suicide", "self-harm", "kill myself", "hopeless",
        "panic attack", "worthless", "end it all"
    }
    if any(term in user_input.lower() for term in crisis_terms):
        return chat_history + [(user_input, crisis_response)]

    # BUG FIX: was `if image:` — truthiness is ambiguous for image objects
    # (numpy arrays raise ValueError on bool()); compare to None explicitly.
    if image is not None:
        food_label, confidence = classify_food(image)
        if food_label:
            context = f"User uploaded {food_label} image. {user_input}"
        else:
            return chat_history + [(user_input, "Couldn't identify food in image")]
    else:
        context = user_input

    # Ground the answer in the top retrieved document chunks.
    docs = retriever.get_relevant_documents(context)
    context_str = "\n".join(d.page_content for d in docs[:3])

    prompt = f"""Context from documents:
{context_str}

Recent conversation:
{history_str}

User: {context}
Assistant:"""

    response = llm.invoke(prompt).content
    return chat_history + [(user_input, response)]
161
+
162
# Canned reply shown verbatim whenever the crisis-term screen trips.
crisis_response = """🚨 Immediate Help Resources:
- India: Vandrevala Foundation - 1860 266 2345
- US: 988 Suicide & Crisis Lifeline
- UK: Samaritans - 116 123
- Worldwide: https://findahelpline.com"""
168
+
169
# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------
# BUG FIX: the original passed css=custom_css, but no `custom_css` variable
# is defined anywhere in this file, so importing the app raised NameError.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🤰 Maternal Wellness Companion")

    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(height=500)
            input_txt = gr.Textbox(placeholder="Ask about pregnancy health...")
            input_img = gr.Image(type="pil", label="Upload Food Image")
            btn_send = gr.Button("Send", variant="primary")

        with gr.Column(scale=1):
            gr.Markdown("## Health Metrics")
            status = gr.Radio(["Pregnant", "Postpartum"], label="Status")
            pre_weight = gr.Number(label="Pre-pregnancy Weight (kg)")
            current_weight = gr.Number(label="Current Weight (kg)")
            height = gr.Number(label="Height (cm)")
            # Status-specific fields start hidden; labels added for clarity.
            gest_age = gr.Number(label="Gestational Age (weeks)", visible=False)
            postpartum_time = gr.Number(label="Weeks Since Delivery", visible=False)
            breastfeeding = gr.Radio(["Yes", "No"], label="Breastfeeding?", visible=False)
            btn_calculate = gr.Button("Calculate", variant="secondary")

    # Reveal whichever extra inputs match the selected status.
    status.change(
        lambda s: (
            gr.update(visible=s == "Pregnant"),
            gr.update(visible=s == "Postpartum"),
            gr.update(visible=s == "Postpartum")
        ),
        inputs=status,
        outputs=[gest_age, postpartum_time, breastfeeding]
    )

    btn_send.click(
        chat_function,
        [input_txt, input_img, chatbot],
        [chatbot],
        queue=False
    )

    # BUG FIX: calculate_metrics returns a plain string, but the original
    # wired it directly into the Chatbot output, which expects a list of
    # (user, assistant) pairs; append the result as an assistant message.
    btn_calculate.click(
        lambda history, *metric_args: history + [(None, calculate_metrics(*metric_args))],
        [chatbot, status, pre_weight, current_weight, height,
         gest_age, postpartum_time, breastfeeding],
        chatbot
    )

demo.launch(debug=False)