Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import google.generativeai as genai | |
| import os | |
| import json | |
| import re | |
| from duckduckgo_search import DDGS | |
def _parse_disease_input(text):
    """Split a comma-separated disease string into stripped, lower-cased, non-empty names."""
    return [d.strip().lower() for d in (text or "").split(",") if d.strip()]


def _extract_json(text):
    """Parse the JSON value contained in a model reply.

    Accepts bare JSON, JSON wrapped in a ```json ... ``` fence (any case), or
    JSON surrounded by extra prose; in the last case the span from the first
    "{" to the last "}" is parsed.

    Raises:
        json.JSONDecodeError (a ValueError): if no parseable JSON is found.
    """
    fenced = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL | re.IGNORECASE)
    candidate = (fenced.group(1) if fenced else text).strip()
    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        start, end = candidate.find("{"), candidate.rfind("}")
        if start == -1 or end <= start:
            raise
        return json.loads(candidate[start:end + 1])


def _generate_text(model, prompt, **kwargs):
    """Return the stripped text of a Gemini reply, or "" if the reply carries no text."""
    response = model.generate_content(prompt, **kwargs)
    try:
        return response.text.strip()
    except ValueError:
        # `.text` raises ValueError when the reply has no text part (e.g. it was
        # blocked); treat that as an empty reply instead of crashing.
        return ""


def _join_items(value):
    """Join a list of food items with ", "; tolerate a bare string or a missing value."""
    if not value:
        return ""
    if isinstance(value, str):
        # A bare string would otherwise be joined character by character.
        return value
    if isinstance(value, (list, tuple)):
        return ", ".join(str(item) for item in value)
    return str(value)


def _format_disease_info(mapped_diseases, disease_lookup):
    """Render the dietary guidelines of every mapped disease found in disease_lookup.

    mapped_diseases maps user input -> standard disease name (as produced by the
    model). Empty or non-string names, and names absent from disease_lookup,
    are skipped.
    """
    sections = []
    for standard_name in mapped_diseases.values():
        if not isinstance(standard_name, str) or not standard_name:
            continue
        data = disease_lookup.get(standard_name.lower())
        if not data:
            continue
        sections.append(
            f"Disease: {data['disease']}\n"
            f"Must Have: {_join_items(data.get('must_have'))}\n"
            f"Can Have: {_join_items(data.get('can_have'))}\n"
            f"Avoid: {_join_items(data.get('avoid'))}\n\n"
        )
    return "".join(sections)


def _web_search(query, max_results=3):
    """Return up to max_results DuckDuckGo result snippets joined by newlines.

    Search failures are returned as text so the ReAct loop can continue with the
    error as its observation instead of aborting the whole request.
    """
    try:
        with DDGS() as ddgs:
            results = ddgs.text(query) or []
            # Consume inside the context manager in case text() yields lazily.
            bodies = [r.get("body", "") for r in results]
    except Exception as e:  # broad on purpose: any search error becomes an observation
        return f"Search failed: {e}"
    return "\n".join(bodies[:max_results])


def dietary_assistant(user_diseases, patient_preferences, user_question):
    """Answer a food question for a patient, given their diseases and preferences.

    Pipeline (Gemini model "gemini-2.0-flash"):
      1. Map the comma-separated ``user_diseases`` onto the standard disease
         names listed in ``diseases.json`` (read from the working directory).
      2. Collect the must-have / can-have / avoid lists of the mapped diseases.
      3. Ask the model to merge those with ``patient_preferences`` into a single
         must_have / can_have / avoid JSON object.
      4. Run a ReAct-style loop (at most 5 steps, ``web_search`` as the only
         tool) to answer ``user_question``.

    Args:
        user_diseases: Comma-separated disease names as typed by the user.
        patient_preferences: Free-text dietary preferences.
        user_question: The question to answer.

    Returns:
        The final answer with ``*``, ``_`` and backtick characters removed, or a
        human-readable error message (missing API key, unparseable model JSON,
        or no final answer within the step limit).
    """
    # GEMINI_API is this app's variable; also accept the names the
    # google-generativeai client reads on its own, so deployments that relied
    # on that fallback keep working.
    gemini_api_key = (
        os.getenv("GEMINI_API")
        or os.getenv("GEMINI_API_KEY")
        or os.getenv("GOOGLE_API_KEY")
    )
    if not gemini_api_key:
        return "❌ Gemini API key not found. Set the GEMINI_API environment variable."
    genai.configure(api_key=gemini_api_key)
    model = genai.GenerativeModel("gemini-2.0-flash")

    # Each entry is read via the keys "disease", "must_have", "can_have", "avoid".
    with open("diseases.json", "r", encoding="utf-8") as f:
        disease_data = json.load(f)
    standard_diseases = [entry["disease"] for entry in disease_data]
    disease_lookup = {entry["disease"].lower(): entry for entry in disease_data}
    user_diseases = _parse_disease_input(user_diseases)

    # Step 1: map free-text disease names onto the standard list.
    mapping_prompt = f"""
You are a medical assistant. Map the following user-input disease names to the most appropriate standard medical terms from the provided list.
User Input Diseases:
{user_diseases}
Standard Diseases:
{standard_diseases}
Return a JSON object with two keys:
- "mapped": a dictionary where keys are user input and values are matched standard disease names.
- "unmapped": a list of user inputs that cannot be matched confidently.
"""
    raw_text = _generate_text(model, mapping_prompt)
    try:
        mapped_diseases = _extract_json(raw_text)["mapped"]
    except (ValueError, KeyError, TypeError) as e:
        return f"Parsing error: {e}\nRaw: {raw_text}"
    if not isinstance(mapped_diseases, dict):
        return f"Parsing error: 'mapped' is not a JSON object\nRaw: {raw_text}"

    # Step 2: gather the guideline text for the matched diseases.
    disease_text = _format_disease_info(mapped_diseases, disease_lookup)

    # Step 3: ask for a personalised food list as JSON.
    food_prompt = f"""
You are a medical nutritionist assistant. Your task is to generate a personalized dietary recommendation for a patient based on diseases and preferences.
Diseases and their dietary guidelines:
{disease_text}
Patient's personal food preferences:
{patient_preferences}
Please respond with ONLY this valid JSON format:
{{ "must_have": ["item1", "item2"], "can_have": ["item1"], "avoid": ["item1"] }}
"""
    json_text = _generate_text(model, food_prompt)
    try:
        diet_info = _extract_json(json_text)
    except ValueError as e:
        return f"Failed to parse diet JSON: {e}\nRaw: {json_text}"
    if not isinstance(diet_info, dict):
        return f"Failed to parse diet JSON: expected a JSON object\nRaw: {json_text}"

    # Step 4: ReAct-style Q&A.
    history = f"""
You are a ReAct-style dietary agent. Answer the question step-by-step.
Use the tools if needed. The only tool available is web_search.
Patient Info:
MUST HAVE: {_join_items(diet_info.get('must_have'))}
CAN HAVE: {_join_items(diet_info.get('can_have'))}
AVOID: {_join_items(diet_info.get('avoid'))}
Preferences: {patient_preferences}
User Question: {user_question}
Respond with:
Thought: ...
Action: web_search("...") # if needed
Observation: ... # I will fill this in
Final Answer: ... # when ready
"""
    for _ in range(5):
        # Stop generation at "Observation:" so the model cannot invent tool
        # results; the real search output is appended to the history below.
        text = _generate_text(
            model, history, generation_config={"stop_sequences": ["Observation:"]}
        )
        final = re.search(r"Final Answer:\s*(.*)", text, re.DOTALL)
        if final:
            return re.sub(r"[*_`]", "", final.group(1).strip())
        action = re.search(r'Action:\s*web_search\("([^"]+)"\)', text)
        if action:
            history += f"\n{text}\nObservation: {_web_search(action.group(1))}"
        else:
            history += f"\n{text}"
    return "❌ Reached max steps without a final answer."
# Gradio UI: three free-text inputs, passed positionally to dietary_assistant
# as (user_diseases, patient_preferences, user_question), and one text output.
disease_input = gr.Textbox(
    label="Enter Diseases",
    placeholder="e.g. diabetes, bp, thyroid",
    lines=1,
)
preferences_input = gr.Textbox(
    label="Enter patient food preferences (e.g. vegetarian, no dairy, etc.)"
)
question_input = gr.Textbox(
    label="Enter your question about what the patient can/cannot eat"
)
response_output = gr.Textbox(label="Assistant's Response")

demo = gr.Interface(
    fn=dietary_assistant,
    inputs=[disease_input, preferences_input, question_input],
    outputs=response_output,
    title="🩺 Medical Nutrition Assistant",
    description="Enter diseases, preferences, and your food question. Get diet-safe answers!",
)
demo.launch()