Spaces:
Sleeping
Sleeping
| from typing import Any, Literal | |
| from pydantic import BaseModel, Field | |
class TaskData(BaseModel):
    """One state update for a single run of a task.

    Updates that carry the same ``run_id`` are paired together as successive
    updates of the same task run.
    """

    name: str | None = Field(
        description="Name of the task.", default=None, examples=["Check input safety"]
    )
    run_id: str = Field(
        description="ID of the task run to pair state updates to.",
        default="",
        examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
    )
    state: Literal["new", "running", "complete"] | None = Field(
        description="Current state of given task instance.",
        default=None,
        examples=["running"],
    )
    result: Literal["success", "error"] | None = Field(
        description="Result of given task instance.",
        default=None,
        # Examples must be valid values of the Literal type; "running" is a
        # state, not a result.
        examples=["success"],
    )
    data: dict[str, Any] = Field(
        description="Additional data generated by the task.",
        # default_factory is pydantic's idiomatic way to declare a mutable
        # default; each instance still starts with an empty dict.
        default_factory=dict,
    )

    def completed(self) -> bool:
        """Return True if this update reports the task state as ``"complete"``."""
        return self.state == "complete"

    def completed_with_error(self) -> bool:
        """Return True if the task is ``"complete"`` and its result is ``"error"``."""
        return self.state == "complete" and self.result == "error"
class TaskDataStatus:
    """Draw a stream of TaskData updates into one Streamlit status container."""

    def __init__(self) -> None:
        # Imported lazily so streamlit is only required when a widget is built.
        import streamlit as st

        self.status = st.status("")
        # Most recent TaskData received for each run_id.
        self.current_task_data: dict[str, TaskData] = {}

    def add_and_draw_task_data(self, task_data: TaskData) -> None:
        """Append ``task_data`` to the status container and refresh it.

        Writes a headline describing the update, the update's ``data`` and a
        separator. It then records the update under its ``run_id``. The
        container's label is changed only when that ``run_id`` has not been
        seen before. The container's state is recomputed from every tracked
        run: "running" until all are complete, then "error" if any completed
        with an error, else "complete".
        """
        container = self.status

        headline = f"Task **{task_data.name}** "
        if task_data.state == "new":
            headline += "has :blue[started]. Input:"
        elif task_data.state == "running":
            headline += "wrote:"
        elif task_data.state == "complete":
            # NOTE(review): a "complete" update whose result is None is shown
            # as an error here, but completed_with_error() does not count it
            # as one, so the overall state can still end up "complete".
            if task_data.result == "success":
                headline += ":green[completed successfully]. Output:"
            else:
                headline += ":red[ended with error]. Output:"

        container.write(headline)
        container.write(task_data.data)
        container.write("---")

        first_sighting = task_data.run_id not in self.current_task_data
        if first_sighting:
            # The label always names the most recently started task.
            container.update(label=f"Task: {task_data.name}")
        self.current_task_data[task_data.run_id] = task_data

        tracked = self.current_task_data.values()
        if not all(entry.completed() for entry in tracked):
            overall = "running"
        elif any(entry.completed_with_error() for entry in tracked):
            overall = "error"
        else:
            overall = "complete"
        container.update(state=overall)  # type: ignore[arg-type]