# Sk1306's picture
# Create app.py
# b0b8d4d verified
import numpy as np
from collections import deque
import gradio as gr
# ----------------------------------
# Dosha Agent Class
# ----------------------------------
class DoshaStateTrackingAgent:
    """Track a dosha distribution over time with an exponential moving average.

    The agent keeps a smoothed state (one share per dosha), compares it with
    the baseline captured at construction, and on every step derives a
    severity bucket, a dominance trend and a list of alert triggers.
    """

    def __init__(
        self,
        initial_state,
        initial_confidence=0.85,
        alpha=0.6,
        history_window=7
    ):
        """Create an agent.

        Args:
            initial_state: Mapping of dosha name to its starting share. It is
                copied twice: once as the mutable state and once as the fixed
                baseline that imbalance is measured against.
            initial_confidence: Confidence value held until the first state
                update recomputes it.
            alpha: Weight kept on the previous state in each update; the new
                observation is weighted by ``1 - alpha``.
            history_window: Maximum number of normalized observations kept.
        """
        self.alpha = alpha
        self.state = initial_state.copy()
        self.confidence = initial_confidence
        self.baseline = initial_state.copy()
        self.history = deque(maxlen=history_window)
        # Dominant dosha of the most recent three trend checks.
        self.trend_history = deque(maxlen=3)

    def _normalize(self, obs):
        """Return a copy of *obs* scaled so that its values sum to 1.

        An observation whose values sum to zero carries no relative
        information, so it is mapped to equal shares instead of raising
        ``ZeroDivisionError``. An empty mapping yields an empty mapping.
        """
        if not obs:
            return {}
        total = sum(obs.values())
        if total == 0:
            equal_share = 1 / len(obs)
            return {k: equal_share for k in obs}
        return {k: v / total for k, v in obs.items()}

    def observe(self, observation):
        """Normalize *observation*, record it in the history and return it."""
        obs = self._normalize(observation)
        self.history.append(obs)
        return obs

    def update_state(self, obs):
        """Blend *obs* into the state and refresh the confidence.

        Each entry becomes ``alpha * old + (1 - alpha) * obs[d]``.

        Args:
            obs: Mapping that must contain every key present in the state.

        Returns:
            The agent's own state mapping (updated in place, not a copy).

        Raises:
            KeyError: If *obs* lacks one of the state's keys.
        """
        for d in self.state:
            self.state[d] = (
                self.alpha * self.state[d] +
                (1 - self.alpha) * obs[d]
            )
        self._update_confidence()
        return self.state

    def _update_confidence(self):
        """Set confidence to ``1 - variance`` of the state, clamped to [0.4, 0.95]."""
        # Cast to a built-in float so the stored value is not a NumPy scalar.
        variance = float(np.var(list(self.state.values())))
        self.confidence = max(0.4, min(0.95, 1 - variance))

    def compute_imbalance(self):
        """Measure how far the state has drifted from the baseline.

        Returns:
            A ``(imbalance, severity)`` pair: the per-dosha absolute
            deviation from the baseline, and the severity bucket of the
            largest deviation.
        """
        imbalance = {
            d: abs(self.state[d] - self.baseline[d])
            for d in self.state
        }
        severity = self._bucket_severity(max(imbalance.values()))
        return imbalance, severity

    def _bucket_severity(self, value):
        """Map a deviation to "mild" (< 0.05), "moderate" (< 0.12) or "severe"."""
        if value < 0.05:
            return "mild"
        elif value < 0.12:
            return "moderate"
        else:
            return "severe"

    def detect_trends(self):
        """Record the currently dominant dosha and classify the recent trend.

        Every call appends to the trend history, so this is not a pure query.

        Returns:
            "stable" while fewer than three entries have been recorded,
            "<dosha>_rising" when the last three entries name the same dosha,
            and "mixed" otherwise.
        """
        dominant = max(self.state, key=self.state.get)
        self.trend_history.append(dominant)
        if len(self.trend_history) < 3:
            return "stable"
        if len(set(self.trend_history)) == 1:
            return f"{dominant}_rising"
        return "mixed"

    def generate_triggers(self, severity, trend):
        """Build the list of alert names implied by *severity* and *trend*."""
        triggers = []
        if severity == "severe":
            triggers.append("high_imbalance_alert")
        if "rising" in trend:
            triggers.append(f"{trend}_3_days")
        return triggers

    def step(self, observation):
        """Process one raw observation and return a report.

        Args:
            observation: Mapping of dosha name to a raw (unnormalized) value.

        Returns:
            Dict with the keys "State", "Imbalance", "Severity", "Trend",
            "Confidence" (rounded to three decimals) and "Triggers".
        """
        obs = self.observe(observation)
        state = self.update_state(obs)
        imbalance, severity = self.compute_imbalance()
        trend = self.detect_trends()
        triggers = self.generate_triggers(severity, trend)
        return {
            # Snapshot: the live state dict is mutated by later steps, which
            # would otherwise rewrite reports that were already returned.
            "State": dict(state),
            "Imbalance": imbalance,
            "Severity": severity,
            "Trend": trend,
            "Confidence": round(self.confidence, 3),
            "Triggers": triggers
        }
# ----------------------------------
# Global Agent (persistent state)
# ----------------------------------
# One module-level agent, so the smoothed state carries over between calls
# to predict() instead of being reset each time.
agent = DoshaStateTrackingAgent(
    initial_state=dict(vata=0.4, pitta=0.35, kapha=0.25),
)
# ----------------------------------
# Function for UI
# ----------------------------------
def predict(vata, pitta, kapha):
    """Feed one set of slider values to the shared agent and format the report.

    Args:
        vata: Raw Vata value from the UI.
        pitta: Raw Pitta value from the UI.
        kapha: Raw Kapha value from the UI.

    Returns:
        A 6-tuple in output-widget order: state (as text), imbalance (as
        text), severity, trend, confidence, and triggers (as text).
    """
    report = agent.step({"vata": vata, "pitta": pitta, "kapha": kapha})

    # Dicts and lists are stringified for display in the text boxes.
    state_text = str(report["State"])
    imbalance_text = str(report["Imbalance"])
    triggers_text = str(report["Triggers"])

    return (
        state_text,
        imbalance_text,
        report["Severity"],
        report["Trend"],
        report["Confidence"],
        triggers_text,
    )
# ----------------------------------
# Gradio UI
# ----------------------------------
# Three 0-1 sliders feed predict(); each of its six return values is shown
# in its own text box, in the same order.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Slider(0, 1, value=default, label=name)
        for name, default in (("Vata", 0.5), ("Pitta", 0.3), ("Kapha", 0.2))
    ],
    outputs=[
        gr.Textbox(label=name)
        for name in (
            "State Vector",
            "Imbalance",
            "Severity",
            "Trend",
            "Confidence",
            "Triggers",
        )
    ],
    title="🧠 Dosha State Tracking Agent",
    description="Track Vata, Pitta, Kapha changes over time using EMA-based AI agent",
)

# Start the web UI.
iface.launch()