Spaces:
Sleeping
Sleeping
| import numpy as np | |
| from collections import deque | |
| import gradio as gr | |
# ----------------------------------
# Dosha Agent Class
# ----------------------------------
class DoshaStateTrackingAgent:
    """Track a dosha state vector (e.g. Vata/Pitta/Kapha) over observations.

    Each call to :meth:`step`:

    1. normalizes the observation so its values sum to 1;
    2. blends it into the running state with an exponential moving average;
    3. measures drift from the baseline captured at construction;
    4. buckets that drift into a severity;
    5. reports which dosha has been dominant recently, plus alert triggers.

    Args:
        initial_state: Mapping of dosha name -> starting value. It is copied
            and used both as the starting EMA state and as the fixed baseline
            imbalance is measured against. Its keys define which doshas are
            tracked, and every observation must contain them.
        initial_confidence: Confidence reported before the first update. It
            is overwritten on every call to :meth:`update_state`.
        alpha: EMA weight on the previous state (``1 - alpha`` goes to the new
            observation). Higher values react more slowly.
        history_window: Number of normalized observations kept in
            ``self.history``.
        trend_window: Number of consecutive steps the same dosha must be
            dominant before a ``"<dosha>_rising"`` trend is reported.

    Raises:
        ValueError: if ``trend_window`` is less than 1.
    """

    def __init__(
        self,
        initial_state,
        initial_confidence=0.85,
        alpha=0.6,
        history_window=7,
        trend_window=3,
    ):
        if trend_window < 1:
            raise ValueError("trend_window must be at least 1")
        self.alpha = alpha
        self.state = initial_state.copy()
        self.confidence = initial_confidence
        # Frozen at construction; compute_imbalance measures drift from it.
        self.baseline = initial_state.copy()
        # Most recent normalized observations (not read elsewhere in this class).
        self.history = deque(maxlen=history_window)
        self.trend_window = trend_window
        # Dominant dosha of each of the last `trend_window` steps.
        self.trend_history = deque(maxlen=trend_window)

    def _normalize(self, obs):
        """Return a new dict with ``obs`` scaled so its values sum to 1.

        Raises:
            ValueError: if the values sum to zero or less. Dividing would
                otherwise fail with ZeroDivisionError (e.g. all inputs 0) or
                silently flip the signs of the values.
        """
        total = sum(obs.values())
        if total <= 0:
            raise ValueError(
                "observation values must sum to a positive number, "
                f"got {total!r}"
            )
        return {k: v / total for k, v in obs.items()}

    def observe(self, observation):
        """Normalize ``observation``, record it in history and return it."""
        obs = self._normalize(observation)
        self.history.append(obs)
        return obs

    def update_state(self, obs):
        """Blend ``obs`` into the state via EMA and refresh confidence.

        Only keys already present in the state are updated, and each of them
        must exist in ``obs``. Returns the (live) state dict.
        """
        for d in self.state:
            self.state[d] = (
                self.alpha * self.state[d] +
                (1 - self.alpha) * obs[d]
            )
        self._update_confidence()
        return self.state

    def _update_confidence(self):
        """Set confidence to ``1 - variance`` of the state, clamped to [0.4, 0.95].

        NOTE(review): if the state sums to 1 across three doshas, its variance
        is at most 2/9, so confidence never falls below ~0.78 and the 0.4
        floor is unreachable. A more balanced state also yields *higher*
        confidence. Confirm this is the intended meaning.
        """
        variance = np.var(list(self.state.values()))
        self.confidence = max(0.4, min(0.95, 1 - variance))

    def compute_imbalance(self):
        """Return ``(per-dosha |state - baseline|, severity of the largest)``."""
        imbalance = {
            d: abs(self.state[d] - self.baseline[d])
            for d in self.state
        }
        severity = self._bucket_severity(max(imbalance.values()))
        return imbalance, severity

    def _bucket_severity(self, value):
        """Map a drift value to "mild" (<0.05), "moderate" (<0.12) or "severe"."""
        if value < 0.05:
            return "mild"
        elif value < 0.12:
            return "moderate"
        else:
            return "severe"

    def detect_trends(self):
        """Record the current dominant dosha and classify the recent trend.

        Side effect: appends to ``self.trend_history``, so call it once per
        step. Returns "stable" until ``trend_window`` steps are recorded,
        "<dosha>_rising" when the same dosha dominated all of them, and
        "mixed" otherwise. Ties go to the first key in the state's order.
        """
        dominant = max(self.state, key=self.state.get)
        self.trend_history.append(dominant)
        if len(self.trend_history) < self.trend_window:
            return "stable"
        if len(set(self.trend_history)) == 1:
            return f"{dominant}_rising"
        return "mixed"

    def generate_triggers(self, severity, trend):
        """Return alert trigger names for the given severity and trend."""
        triggers = []
        if severity == "severe":
            triggers.append("high_imbalance_alert")
        if "rising" in trend:
            triggers.append(f"{trend}_{self.trend_window}_days")
        return triggers

    def step(self, observation):
        """Run one full update cycle for ``observation`` and return a report.

        Returns:
            dict with keys "State" (a snapshot copy of the state, safe to keep
            across later steps), "Imbalance", "Severity", "Trend",
            "Confidence" (rounded to 3 decimals) and "Triggers".

        Raises:
            ValueError: if the observation values do not sum to a positive
                number. No agent state is modified in that case.
        """
        obs = self.observe(observation)
        state = self.update_state(obs)
        imbalance, severity = self.compute_imbalance()
        trend = self.detect_trends()
        triggers = self.generate_triggers(severity, trend)
        return {
            # Copy so earlier reports are not mutated by later steps.
            "State": dict(state),
            "Imbalance": imbalance,
            "Severity": severity,
            "Trend": trend,
            "Confidence": round(self.confidence, 3),
            "Triggers": triggers
        }
# ----------------------------------
# Global Agent (persistent state)
# ----------------------------------
# One module-level agent instance. Every call to `predict` advances this same
# object, so its EMA state and trend history carry over between requests.
# NOTE(review): this instance appears to be shared by all users of the app;
# consider per-session state (e.g. gr.State) if histories must not mix.
agent = DoshaStateTrackingAgent(
    initial_state=dict(vata=0.4, pitta=0.35, kapha=0.25),
)
# ----------------------------------
# Function for UI
# ----------------------------------
def predict(vata, pitta, kapha):
    """Feed one slider reading into the shared agent and format the result.

    Args:
        vata, pitta, kapha: Raw slider values. They need not sum to 1; the
            agent normalizes them by their total.

    Returns:
        A 6-tuple matching the Gradio output components, in order: state,
        imbalance, severity, trend, confidence and triggers.

    Raises:
        gr.Error: if the inputs sum to zero or less. The agent normalizes by
            that total, so such input cannot be processed; this shows the
            user a readable message instead of an internal traceback.
    """
    if vata + pitta + kapha <= 0:
        raise gr.Error("Set at least one dosha slider above 0.")
    obs = {"vata": vata, "pitta": pitta, "kapha": kapha}
    output = agent.step(obs)
    return (
        str(output["State"]),
        str(output["Imbalance"]),
        output["Severity"],
        output["Trend"],
        output["Confidence"],
        str(output["Triggers"])
    )
# ----------------------------------
# Gradio UI
# ----------------------------------
# Three 0-1 sliders (one per dosha) feed `predict`. Six text boxes display the
# fields of the tuple it returns, in the same order.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Slider(0, 1, value=default, label=label)
        for label, default in (("Vata", 0.5), ("Pitta", 0.3), ("Kapha", 0.2))
    ],
    outputs=[
        gr.Textbox(label=label)
        for label in (
            "State Vector",
            "Imbalance",
            "Severity",
            "Trend",
            "Confidence",
            "Triggers",
        )
    ],
    title="🧠 Dosha State Tracking Agent",
    description="Track Vata, Pitta, Kapha changes over time using EMA-based AI agent",
)
iface.launch()