HCPDigitalTwin / app.py
jmisak's picture
Update app.py
f7868c9 verified
import gradio as gr
import yaml
import json
import os
import traceback
import matplotlib.pyplot as plt
import numpy as np
from engine.loader import load_persona
from engine.drift import apply_stimuli
from engine.responder import generate_response
from engine.utils import safe_log
from engine.logger import log_transcript
# Paths. All are "./"-relative, so they resolve against the current working
# directory (presumably the project root — confirm how the app is launched).
# Folder scanned for persona files ending in ".yml"; the selected file name
# is joined onto it when a simulation runs.
persona_dir = "./personas"
# JSON file of market stimuli; entries are iterated and matched by their
# "event" key.
stimuli_path = "./stimuli/events.json"
# NOTE(review): not referenced anywhere in this module's visible code —
# possibly dead, or intended for the error logger; confirm before removing.
error_log_path = "./driftline_errors.log"
# Load available personas
def get_persona_choices(directory=None):
    """Return the persona file names available for selection.

    Scans ``directory`` for files whose names end in ``.yml``. The result is
    sorted so that the dropdown order is deterministic. This matters because
    the UI uses the first entry as the default persona, and ``os.listdir``
    returns entries in arbitrary, filesystem-dependent order.

    Args:
        directory: Folder to scan. ``None`` (the default) means the
            module-level ``persona_dir``.

    Returns:
        Sorted list of bare file names (not full paths). Returns an empty
        list if the directory cannot be listed; the error is logged via
        ``safe_log``.
    """
    if directory is None:
        directory = persona_dir
    try:
        entries = os.listdir(directory)
    except OSError as exc:
        # Previously a missing/unreadable folder raised while the UI was
        # being built and the app never started; degrade to an empty
        # dropdown instead, matching how stimuli load errors are handled.
        safe_log("Persona load error", str(exc))
        return []
    return sorted(name for name in entries if name.endswith(".yml"))
# Load available stimuli
def get_stimuli_choices(path=None):
    """Return the names of the market stimuli defined in the events file.

    Args:
        path: JSON file to read. ``None`` (the default) means the
            module-level ``stimuli_path``. The file is expected to contain
            an array of objects, each with an ``"event"`` key.

    Returns:
        The ``"event"`` values in file order. Returns an empty list if the
        file cannot be read or parsed, or if any entry lacks an ``"event"``
        key; the error is logged via ``safe_log``.
    """
    if path is None:
        path = stimuli_path
    try:
        # JSON text is UTF-8 by spec; without an explicit encoding, open()
        # uses the platform default (e.g. cp1252 on Windows), which can fail
        # on or mangle non-ASCII event names.
        with open(path, "r", encoding="utf-8") as f:
            events = json.load(f)
        return [e["event"] for e in events]
    except Exception as e:
        # Best-effort: the UI can still be built with an empty stimulus list.
        safe_log("Stimuli load error", str(e))
        return []
# Generate radar chart for traits
def plot_traits(state, traits=None, chart_path="./trait_chart.png"):
    """Render a radar (polar) chart of persona traits and save it as an image.

    Args:
        state: Mapping of trait name to numeric value. ``simulate`` passes
            the persona's ``dynamic_state``. Traits missing from the mapping
            are plotted as 0.0.
        traits: Trait names to plot, one spoke each, in angular order.
            ``None`` uses innovation, openness, risk_tolerance and
            peer_influence (the original fixed set).
        chart_path: File to write. The default is the original fixed path.
            NOTE(review): a fixed path is shared by every caller, so
            overlapping runs could overwrite each other's chart.

    Returns:
        ``chart_path``.
    """
    if traits is None:
        traits = ["innovation", "openness", "risk_tolerance", "peer_influence"]
    values = [state.get(t, 0.0) for t in traits]
    # Evenly spaced spokes; repeat the first point so the polygon closes.
    angles = np.linspace(0, 2 * np.pi, len(traits), endpoint=False).tolist()
    values += values[:1]
    angles += angles[:1]
    fig, ax = plt.subplots(figsize=(4, 4), subplot_kw=dict(polar=True))
    try:
        ax.plot(angles, values, color="blue", linewidth=2)
        ax.fill(angles, values, color="blue", alpha=0.25)
        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(traits)
        ax.set_yticklabels([])
        ax.set_title("Dynamic Trait Profile", fontsize=12)
        fig.tight_layout()
        fig.savefig(chart_path)
    finally:
        # pyplot keeps every figure alive until explicitly closed; close it
        # even if drawing or saving raises, so failed runs don't leak figures.
        plt.close(fig)
    return chart_path
# Main simulation function
def simulate(prompt, selected_event, selected_persona_file):
    """Run one simulated interview turn for the chosen persona.

    Steps:
        1. Load the persona file.
        2. Apply the selected market stimulus, if one with a matching
           ``"event"`` key exists in the stimuli file.
        3. Generate a response to ``prompt``.
        4. Dump and chart the persona's ``dynamic_state``.
        5. Log a transcript.

    Args:
        prompt: Interviewer question text.
        selected_event: Event name to look up in the stimuli file. If no
            entry matches (or the match is falsy), the persona is used
            unmodified.
        selected_persona_file: File name inside ``persona_dir``.

    Returns:
        ``(response, state_yaml, chart_path)`` on success. If any step
        before transcript logging raises, the traceback is logged and
        ``("[ERROR] Simulation failed. Check logs.", "", None)`` is returned.
    """
    try:
        persona_path = os.path.join(persona_dir, selected_persona_file)
        persona = load_persona(persona_path)
        # JSON is UTF-8 by spec; don't depend on the platform default.
        with open(stimuli_path, "r", encoding="utf-8") as f:
            events = json.load(f)
        event = next((e for e in events if e["event"] == selected_event), None)
        if event:
            persona = apply_stimuli(persona, event)
        response = generate_response(prompt, persona)
        state_yaml = yaml.dump(persona["dynamic_state"], sort_keys=False)
        chart_path = plot_traits(persona["dynamic_state"])
    except Exception:
        safe_log("Simulation error", traceback.format_exc())
        return "[ERROR] Simulation failed. Check logs.", "", None
    # Transcript logging is best-effort and runs after the response exists.
    # Previously a logging failure discarded the generated response and
    # reported the whole simulation as failed.
    try:
        log_transcript(persona, prompt, selected_event, response)
    except Exception:
        safe_log("Transcript log error", traceback.format_exc())
    return response, state_yaml, chart_path
# Gradio UI: persona/stimulus pickers, prompt box, and the three outputs
# produced by simulate(). Built at import time and launched at module level.
# NOTE(review): row nesting was reconstructed; confirm the prompt box is
# meant to sit below (not inside) the dropdown row.
with gr.Blocks(title="Driftline HCP Simulator") as ui:
    gr.Markdown("## 🧠 Driftline: Adaptive HCP Simulation")
    gr.Markdown("Simulate how healthcare personas evolve in response to market stimuli.")
    with gr.Row():
        persona_files = get_persona_choices()
        default_persona = persona_files[0] if persona_files else None
        persona_selector = gr.Dropdown(
            label="Choose Persona",
            choices=persona_files,
            value=default_persona,
            allow_custom_value=False
        )
        event_choices = get_stimuli_choices()
        # Only preselect "FDA_approval" if the events file actually defines
        # it. A default outside `choices` is not a valid selection and
        # triggers a Gradio warning. Falling back to None keeps the previous
        # effective behavior of simulate() finding no matching stimulus.
        default_event = "FDA_approval" if "FDA_approval" in event_choices else None
        event_selector = gr.Dropdown(label="Market Stimulus", choices=event_choices, value=default_event)
    prompt = gr.Textbox(label="Interviewer Prompt", lines=2, placeholder="Ask about a new therapy, trial data, or prescribing behavior...")
    with gr.Row():
        simulate_btn = gr.Button("Run Simulation")
        clear_btn = gr.Button("Clear")
    response_output = gr.Textbox(label="Simulated HCP Response", lines=6)
    state_output = gr.Textbox(label="Updated Persona State", lines=10)
    trait_chart = gr.Image(label="Trait Radar Chart")
    simulate_btn.click(
        fn=simulate,
        inputs=[prompt, event_selector, persona_selector],
        outputs=[response_output, state_output, trait_chart]
    )
    # Clears the three output components only; the inputs are left as-is.
    clear_btn.click(
        fn=lambda: ("", "", None),
        inputs=[],
        outputs=[response_output, state_output, trait_chart]
    )
ui.launch()