|
|
from typing import Any, Literal |
|
|
|
|
|
from pydantic import BaseModel, Field |
|
|
|
|
|
|
|
|
class TaskData(BaseModel):
    """One state update for a single task run.

    Several updates may share a ``run_id``; each describes the same task run at
    a later point (``new`` -> ``running`` -> ``complete``).
    """

    # Human-readable task name shown in the UI; may be absent.
    name: str | None = Field(
        description="Name of the task.", default=None, examples=["Check input safety"]
    )
    # Identifier used to pair successive updates of the same task run.
    run_id: str = Field(
        description="ID of the task run to pair state updates to.",
        default="",
        examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
    )
    # Lifecycle stage of the run; None when not reported.
    state: Literal["new", "running", "complete"] | None = Field(
        description="Current state of given task instance.",
        default=None,
        examples=["running"],
    )
    # Outcome of the run; only meaningful once state == "complete".
    result: Literal["success", "error"] | None = Field(
        description="Result of given task instance.",
        default=None,
        # The example must be a legal value of the Literal ("running" was not).
        examples=["success"],
    )
    # Free-form payload produced by the task (input, intermediate or final output).
    data: dict[str, Any] = Field(
        description="Additional data generated by the task.",
        # Pydantic copies mutable defaults per instance, so this dict is not shared;
        # kept as a literal default so the JSON schema still advertises {}.
        default={},
    )

    def completed(self) -> bool:
        """Return True when this update reports the run as complete."""
        return self.state == "complete"

    def completed_with_error(self) -> bool:
        """Return True when the run is complete and its result is ``"error"``."""
        return self.completed() and self.result == "error"
|
|
|
|
|
|
|
|
class TaskDataStatus: |
|
|
def __init__(self) -> None: |
|
|
import streamlit as st |
|
|
|
|
|
self.status = st.status("") |
|
|
self.current_task_data: dict[str, TaskData] = {} |
|
|
|
|
|
def add_and_draw_task_data(self, task_data: TaskData) -> None: |
|
|
status = self.status |
|
|
status_str = f"Task **{task_data.name}** " |
|
|
match task_data.state: |
|
|
case "new": |
|
|
status_str += "has :blue[started]. Input:" |
|
|
case "running": |
|
|
status_str += "wrote:" |
|
|
case "complete": |
|
|
if task_data.result == "success": |
|
|
status_str += ":green[completed successfully]. Output:" |
|
|
else: |
|
|
status_str += ":red[ended with error]. Output:" |
|
|
status.write(status_str) |
|
|
status.write(task_data.data) |
|
|
status.write("---") |
|
|
if task_data.run_id not in self.current_task_data: |
|
|
|
|
|
status.update(label=f"""Task: {task_data.name}""") |
|
|
self.current_task_data[task_data.run_id] = task_data |
|
|
if all(entry.completed() for entry in self.current_task_data.values()): |
|
|
|
|
|
if any(entry.completed_with_error() for entry in self.current_task_data.values()): |
|
|
state = "error" |
|
|
|
|
|
else: |
|
|
state = "complete" |
|
|
|
|
|
else: |
|
|
state = "running" |
|
|
status.update(state=state) |
|
|
|