# Driftline / app.py
# Author: jmisak — "Upload 2 files" (commit aa32e5f, verified)
import json
import os
import tempfile
import traceback

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import yaml

from engine.loader import load_persona
from engine.drift import apply_stimuli
from engine.responder import generate_response
from engine.utils import safe_log
from engine.logger import log_transcript
# Filesystem locations, resolved relative to the working directory the app
# is started from:
#   persona_dir    - folder scanned for persona "*.yml" files
#   stimuli_path   - JSON list of market-stimulus event objects
#   error_log_path - NOTE(review): not referenced anywhere else in this module
persona_dir, stimuli_path, error_log_path = (
    "./personas",
    "./stimuli/events.json",
    "./driftline_errors.log",
)
# Load available personas
def get_persona_choices():
return [f for f in os.listdir(persona_dir) if f.endswith(".yml")]
# Extract specialties from personas
def get_specialties():
    """Collect the distinct persona ``role`` values for the specialty dropdown.

    Every ``*.yml`` file in ``persona_dir`` is parsed with ``yaml.safe_load``.
    Files that cannot be opened or parsed, or whose top-level value cannot be
    indexed by ``"role"``, are reported through ``safe_log`` and skipped.

    Returns:
        ``["All Specialties"]`` followed by the distinct roles in sorted order.
    """
    specialties = set()
    for filename in os.listdir(persona_dir):
        if not filename.endswith(".yml"):
            continue
        try:
            persona_path = os.path.join(persona_dir, filename)
            with open(persona_path, "r", encoding="utf-8") as f:
                persona_data = yaml.safe_load(f)
            if "role" in persona_data:
                role = persona_data["role"]
                # A blank ``role:`` loads as None (or ""); neither is a usable choice.
                if role is not None and role != "":
                    specialties.add(role)
        except Exception as e:
            safe_log("Specialty extraction error", f"{filename}: {e}")
    # key=str keeps sorted() from raising TypeError when YAML yields roles of
    # mixed types (e.g. a numeric role next to string roles). This function is
    # called while the UI is being built, outside any error handler.
    return ["All Specialties"] + sorted(specialties, key=str)
# Filter personas by specialty
def get_personas_by_specialty(specialty):
    """Return the persona file names whose ``role`` equals *specialty*.

    ``"All Specialties"`` short-circuits to ``get_persona_choices()``.
    Otherwise each ``*.yml`` file in ``persona_dir`` is parsed with
    ``yaml.safe_load``. Unreadable or malformed files are reported through
    ``safe_log`` and skipped.

    Returns:
        Matching file names, sorted so that the first entry (used as the
        dropdown's default selection) is stable across runs.
    """
    if specialty == "All Specialties":
        return get_persona_choices()
    matches = []
    for filename in os.listdir(persona_dir):
        if not filename.endswith(".yml"):
            continue
        try:
            persona_path = os.path.join(persona_dir, filename)
            with open(persona_path, "r", encoding="utf-8") as f:
                persona_data = yaml.safe_load(f)
            if persona_data.get("role") == specialty:
                matches.append(filename)
        except Exception as e:
            safe_log("Persona filter error", f"{filename}: {e}")
    return sorted(matches)
# Load available stimuli
def get_stimuli_choices(path=None):
    """Return dropdown labels for the market-stimulus events.

    Each event object in the JSON list becomes ``"<event> - <description>"``
    when it has a non-empty description, otherwise just ``"<event>"``.
    ``extract_event_name`` reverses this formatting.

    Args:
        path: JSON file to read. Defaults to the module-level ``stimuli_path``.

    Returns:
        Labels in file order. Entries that are not objects with an ``"event"``
        key are logged and skipped. If the file cannot be read or parsed, the
        error is logged and an empty list is returned.
    """
    if path is None:
        path = stimuli_path
    try:
        # JSON is UTF-8 by spec; don't rely on the platform's default encoding.
        with open(path, "r", encoding="utf-8") as f:
            events = json.load(f)
        choices = []
        for entry in events:
            if not isinstance(entry, dict) or "event" not in entry:
                # Skip just this entry instead of losing every stimulus.
                safe_log("Stimuli load error", f"skipping malformed event entry: {entry!r}")
                continue
            event_name = entry["event"]
            description = entry.get("description", "")
            choices.append(f"{event_name} - {description}" if description else event_name)
        return choices
    except Exception as e:
        safe_log("Stimuli load error", str(e))
        return []
# Extract event name from dropdown choice
def extract_event_name(choice):
    """Recover the bare event name from a stimulus dropdown label.

    Labels look like ``"<event> - <description>"`` or just ``"<event>"``.
    Everything before the first ``" - "`` is returned; a label without that
    separator is returned unchanged.

    NOTE(review): an event name that itself contains " - " is truncated here.
    """
    if " - " not in choice:
        return choice
    name, _separator, _description = choice.partition(" - ")
    return name
# Generate radar chart for traits
def plot_traits(state, chart_path=None):
    """Render a radar chart of the four tracked traits and save it as a PNG.

    Args:
        state: Mapping of trait name -> numeric value (the persona's
            ``dynamic_state``). Traits missing from it are plotted as 0.0.
        chart_path: Where to write the PNG. By default a fresh, uniquely named
            file in the system temp directory is created, so concurrent
            simulations never overwrite one another's chart. Files are not
            deleted by this function.

    Returns:
        The path of the written PNG.
    """
    traits = ["innovation", "openness", "risk_tolerance", "peer_influence"]
    values = [state.get(t, 0.0) for t in traits]
    angles = np.linspace(0, 2 * np.pi, len(traits), endpoint=False).tolist()
    # Repeat the first point so the plotted polygon closes on itself.
    values += values[:1]
    angles += angles[:1]

    if chart_path is None:
        fd, chart_path = tempfile.mkstemp(prefix="trait_chart_", suffix=".png")
        os.close(fd)  # savefig reopens the file by path

    fig, ax = plt.subplots(figsize=(4, 4), subplot_kw=dict(polar=True))
    try:
        ax.plot(angles, values, color="blue", linewidth=2)
        ax.fill(angles, values, color="blue", alpha=0.25)
        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(traits)
        ax.set_yticklabels([])
        ax.set_title("Dynamic Trait Profile", fontsize=12)
        fig.tight_layout()
        fig.savefig(chart_path)
    finally:
        # Always release the figure: pyplot keeps open figures alive globally.
        plt.close(fig)
    return chart_path
# Main simulation function
def simulate(prompt, selected_event, selected_persona_file):
    """Run one simulation turn for the UI.

    Loads the chosen persona and applies the selected market stimulus when it
    matches an event in ``stimuli_path``. It then generates the persona's
    response to *prompt*, renders the trait chart and logs the transcript.

    Args:
        prompt: Interviewer prompt text.
        selected_event: Stimulus dropdown label (``"<event> - <description>"``
            or ``"<event>"``).
        selected_persona_file: Persona file name inside ``persona_dir``.

    Returns:
        ``(response, state_yaml, chart_path)``. On any failure the traceback is
        written via ``safe_log`` and
        ``("[ERROR] Simulation failed. Check logs.", "", None)`` is returned,
        so the UI never receives an exception.
    """
    try:
        # The file name arrives from the client; only accept a listed persona
        # so it cannot point outside persona_dir.
        if selected_persona_file not in get_persona_choices():
            raise ValueError(f"Unknown persona file: {selected_persona_file!r}")
        persona = load_persona(os.path.join(persona_dir, selected_persona_file))

        event_name = extract_event_name(selected_event)
        with open(stimuli_path, "r", encoding="utf-8") as f:
            events = json.load(f)
        # .get() so a single malformed entry cannot abort the whole lookup.
        event = next(
            (e for e in events if isinstance(e, dict) and e.get("event") == event_name),
            None,
        )
        if event:
            persona = apply_stimuli(persona, event)
        else:
            safe_log("Simulation warning", f"No stimulus matched {selected_event!r}; no drift applied")

        response = generate_response(prompt, persona, event)
        state_yaml = yaml.dump(persona["dynamic_state"], sort_keys=False)
        chart_path = plot_traits(persona["dynamic_state"])
        # Log the transcript only once a response has been generated.
        log_transcript(persona, prompt, selected_event, response)
        return response, state_yaml, chart_path
    except Exception:
        safe_log("Simulation error", traceback.format_exc())
        return "[ERROR] Simulation failed. Check logs.", "", None
# Gradio UI
# Compute the startup choice lists once: the stimuli file is read a single
# time, and each default selection is guaranteed to be one of the choices.
_persona_choices = get_persona_choices()
_stimuli_choices = get_stimuli_choices()
_PREFERRED_PERSONA = "hcp_drift_trailblazer.yml"
if _PREFERRED_PERSONA in _persona_choices:
    _default_persona = _PREFERRED_PERSONA
else:
    _default_persona = _persona_choices[0] if _persona_choices else None

with gr.Blocks(
    title="Driftline HCP Simulator",
    head="""
<link rel="manifest" href="/file=manifest.json">
<meta name="apple-mobile-web-app-capable" content="yes">
<meta name="apple-mobile-web-app-status-bar-style" content="black-translucent">
<meta name="apple-mobile-web-app-title" content="Driftline">
<link rel="apple-touch-icon" href="/file=driftline.png">
<meta name="mobile-web-app-capable" content="yes">
<meta name="theme-color" content="#2563eb">
""",
) as ui:
    gr.Markdown("## 🧠 Driftline: Adaptive HCP Simulation")
    gr.Markdown("Simulate how healthcare personas evolve in response to market stimuli.")
    gr.Markdown("💡 **Tip:** Install this app to your device for quick access!")

    with gr.Row():
        specialty_selector = gr.Dropdown(
            label="Medical Specialty",
            choices=get_specialties(),
            value="All Specialties",
            interactive=True,
        )
        persona_selector = gr.Dropdown(
            label="Choose Persona",
            choices=_persona_choices,
            value=_default_persona,
            interactive=True,
        )
    event_selector = gr.Dropdown(
        label="Market Stimulus",
        choices=_stimuli_choices,
        value=_stimuli_choices[0] if _stimuli_choices else None,
    )
    prompt = gr.Textbox(label="Interviewer Prompt", lines=2, placeholder="Ask about a new therapy, trial data, or prescribing behavior...")
    with gr.Row():
        simulate_btn = gr.Button("Run Simulation")
        clear_btn = gr.Button("Clear")
    response_output = gr.Textbox(label="Simulated HCP Response", lines=6)
    state_output = gr.Textbox(label="Updated Persona State", lines=10)
    trait_chart = gr.Image(label="Trait Radar Chart")

    # Update persona dropdown when specialty changes
    def update_personas(specialty):
        """Refresh the persona dropdown for *specialty*, pre-selecting the first match."""
        personas = get_personas_by_specialty(specialty)
        return gr.Dropdown(choices=personas, value=personas[0] if personas else None)

    specialty_selector.change(
        fn=update_personas,
        inputs=[specialty_selector],
        outputs=[persona_selector],
    )
    simulate_btn.click(
        fn=simulate,
        inputs=[prompt, event_selector, persona_selector],
        outputs=[response_output, state_output, trait_chart],
    )
    # Clear only the outputs; the selections and prompt are left as they are.
    clear_btn.click(
        fn=lambda: ("", "", None),
        inputs=[],
        outputs=[response_output, state_output, trait_chart],
    )
# Start the server with Progressive Web App support enabled and
# driftline.png as the favicon.
ui.launch(favicon_path="driftline.png", pwa=True)