// darkmedia-x-api — engine/src/progress_tracker.rs
// Uploaded by cybermedia via huggingface_hub (commit 343eed9).
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::time::{SystemTime, UNIX_EPOCH};
/// Snapshot of a story-production task as persisted to the JSON task file.
///
/// The declaration order of the fields below is also the key order of the JSON
/// produced by `serde_json`, so keep it stable if external readers depend on it.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TaskProgress {
    /// Story label; `TaskProgress::new` initialises it to a copy of `story_id`.
    pub story: String,
    /// Story identifier supplied at construction.
    pub story_id: String,
    /// Story location supplied at construction, kept as a plain string.
    pub story_path: String,
    /// Label of the current step (starts as `"Initialisation"`).
    pub step: String,
    /// Progress value for the current step; presumably a percentage (0–100) — TODO confirm.
    pub progress: u32,
    /// Narration content; empty on creation.
    pub narration: String,
    /// Prompt content; empty on creation.
    pub prompt: String,
    /// VFX mode label (defaults to `"grain"`).
    pub vfx: String,
    /// Image mode label (defaults to `"api"`).
    pub img: String,
    /// Voice mode label (defaults to `"api"`).
    pub voc: String,
    /// Engine variant that created the record (`"Rust (Native)"` from `new`).
    pub core_version: String,
    /// Unix time in seconds, set on creation and refreshed by `update_step`.
    pub timestamp: u64,
}
impl TaskProgress {
    /// Builds a fresh record for `story_id`, located at `story_path`.
    ///
    /// Defaults:
    /// - `story` is a copy of `story_id`.
    /// - The step is `"Initialisation"` at progress `0`.
    /// - Narration and prompt are empty.
    /// - The modes are `vfx = "grain"`, `img = "api"` and `voc = "api"`.
    /// - The timestamp is the current Unix time in seconds.
    pub fn new(story_id: String, story_path: String) -> Self {
        let story = story_id.clone();
        let timestamp = current_timestamp();
        Self {
            story,
            story_id,
            story_path,
            step: String::from("Initialisation"),
            progress: 0,
            narration: String::default(),
            prompt: String::default(),
            vfx: String::from("grain"),
            img: String::from("api"),
            voc: String::from("api"),
            core_version: String::from("Rust (Native)"),
            timestamp,
        }
    }

    /// Moves to `step` with the given `progress` value and refreshes the timestamp.
    pub fn update_step(&mut self, step: String, progress: u32) {
        self.timestamp = current_timestamp();
        self.progress = progress;
        self.step = step;
    }

    /// Replaces the image mode label; the timestamp is left untouched.
    pub fn set_img(&mut self, img_mode: String) {
        self.img = img_mode
    }

    /// Replaces the voice mode label; the timestamp is left untouched.
    pub fn set_voc(&mut self, voc_mode: String) {
        self.voc = voc_mode
    }

    /// Serializes the whole record into a compact JSON string.
    ///
    /// # Errors
    /// Propagates any `serde_json` serialization error.
    pub fn to_json(&self) -> Result<String, serde_json::Error> {
        serde_json::to_string::<Self>(self)
    }
}
/// Holds the current [`TaskProgress`] in memory and mirrors it to a JSON task file.
pub struct ProgressTracker {
    /// Path of the JSON snapshot written by the tracker's update methods.
    pub task_file: std::path::PathBuf,
    /// Shared in-memory progress behind a Tokio (async-aware) mutex.
    pub current: std::sync::Arc<tokio::sync::Mutex<TaskProgress>>,
}
impl ProgressTracker {
    /// Creates a tracker for `story_id` that persists snapshots to `task_file`.
    ///
    /// Nothing is written to disk until an update method or
    /// [`ProgressTracker::write_progress`] is called.
    pub fn new(task_file: impl AsRef<Path>, story_id: String, story_path: String) -> Self {
        Self {
            task_file: task_file.as_ref().to_path_buf(),
            current: std::sync::Arc::new(tokio::sync::Mutex::new(TaskProgress::new(
                story_id, story_path,
            ))),
        }
    }

    /// Sets the current step and progress value (refreshing the timestamp), then
    /// persists the snapshot.
    ///
    /// # Errors
    /// Fails if serialization or writing the task file fails. The in-memory state
    /// has already been updated in that case.
    pub async fn update_task_status(&self, step: &str, progress_val: u32) -> anyhow::Result<()> {
        self.modify_and_persist(|p| p.update_step(step.to_string(), progress_val))
            .await
    }

    /// Sets the image mode label and persists the snapshot.
    ///
    /// # Errors
    /// Fails if serialization or writing the task file fails.
    pub async fn set_img_mode(&self, mode: &str) -> anyhow::Result<()> {
        self.modify_and_persist(|p| p.set_img(mode.to_string())).await
    }

    /// Sets the voice mode label and persists the snapshot.
    ///
    /// # Errors
    /// Fails if serialization or writing the task file fails.
    pub async fn set_voc_mode(&self, mode: &str) -> anyhow::Result<()> {
        self.modify_and_persist(|p| p.set_voc(mode.to_string())).await
    }

    /// Sets the VFX mode label and persists the snapshot.
    ///
    /// # Errors
    /// Fails if serialization or writing the task file fails.
    pub async fn set_vfx_mode(&self, mode: &str) -> anyhow::Result<()> {
        self.modify_and_persist(|p| {
            p.vfx = mode.to_string();
        })
        .await
    }

    /// Serializes `progress` to JSON and writes it to the task file.
    ///
    /// The file is replaced via a temporary sibling and a rename where possible, so
    /// concurrent readers do not observe a truncated file (see `write_file_atomic`).
    ///
    /// # Errors
    /// Fails if serialization or the file write fails.
    pub fn write_progress(&self, progress: &TaskProgress) -> anyhow::Result<()> {
        let json = serde_json::to_string(progress)?;
        Self::write_file_atomic(&self.task_file, json.as_bytes())?;
        Ok(())
    }

    /// Deletes the task file. A file that is already absent is not an error.
    ///
    /// # Errors
    /// Returns any other I/O error from the removal (e.g. permission denied)
    /// instead of silently ignoring it.
    pub fn clear(&self) -> anyhow::Result<()> {
        // Remove directly rather than checking `exists()` first: that check races
        // with other processes and doubles the syscalls.
        match std::fs::remove_file(&self.task_file) {
            Ok(()) => Ok(()),
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
            Err(err) => Err(err.into()),
        }
    }

    /// Reads and deserializes the snapshot currently stored in the task file.
    ///
    /// # Errors
    /// Fails if the file cannot be read or does not contain a valid `TaskProgress`.
    pub fn read_progress(&self) -> anyhow::Result<TaskProgress> {
        let json = std::fs::read_to_string(&self.task_file)?;
        let progress = serde_json::from_str(&json)?;
        Ok(progress)
    }

    /// Applies `mutate` to the shared progress and writes the result to disk.
    ///
    /// The async lock is held through the write, so concurrent updates reach the
    /// file in the same order they were applied in memory.
    async fn modify_and_persist<F>(&self, mutate: F) -> anyhow::Result<()>
    where
        F: FnOnce(&mut TaskProgress),
    {
        let mut guard = self.current.lock().await;
        mutate(&mut *guard);
        // NOTE(review): this is blocking std::fs I/O on the executor thread. It is
        // fine while the snapshot stays small; consider `spawn_blocking` if it grows.
        self.write_progress(&*guard)
    }

    /// Writes `contents` to `path` by staging it in `<path>.tmp` and renaming it over
    /// `path`.
    ///
    /// On POSIX filesystems this means readers see either the old or the new file,
    /// never a half-written one.
    ///
    /// If staging or the rename fails, it falls back to the previous direct write.
    /// The rename can fail, for example, on Windows when another process has the
    /// target open without delete sharing.
    fn write_file_atomic(path: &Path, contents: &[u8]) -> std::io::Result<()> {
        let mut staged = path.as_os_str().to_owned();
        staged.push(".tmp");
        let staged = std::path::PathBuf::from(staged);

        let swapped = std::fs::write(&staged, contents).and_then(|()| std::fs::rename(&staged, path));
        if swapped.is_err() {
            // Best-effort cleanup of the staging file; the fallback result is what matters.
            let _ = std::fs::remove_file(&staged);
            std::fs::write(path, contents)?;
        }
        Ok(())
    }
}
/// Checks whether production is paused, signalled by a control file.
///
/// Returns `true` while `pause_file` exists. If its metadata cannot be read (it is
/// missing, access is denied, or it is a broken symlink), the result is `false`.
pub fn check_pause(pause_file: &Path) -> bool {
    std::fs::metadata(pause_file).is_ok()
}
/// Waits, without blocking the executor, until the pause control file is removed.
///
/// Checks `pause_file` every two seconds and returns immediately if it is already
/// absent.
pub async fn wait_for_resume(pause_file: &Path) {
    let poll_every = std::time::Duration::from_secs(2);
    loop {
        if !pause_file.exists() {
            break;
        }
        tokio::time::sleep(poll_every).await;
    }
}
/// Current Unix time in whole seconds.
///
/// Returns `0` if the system clock reads earlier than the Unix epoch.
fn current_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deletes the wrapped path when dropped, so temp files are cleaned up even if
    /// an assertion panics mid-test.
    struct RemoveOnDrop(std::path::PathBuf);

    impl Drop for RemoveOnDrop {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn test_progress_creation() {
        let progress = TaskProgress::new("test_story".to_string(), "/path/to/story".to_string());
        assert_eq!(progress.story_id, "test_story");
        assert_eq!(progress.story, "test_story");
        assert_eq!(progress.story_path, "/path/to/story");
        assert_eq!(progress.step, "Initialisation");
        assert_eq!(progress.progress, 0);
        assert_eq!(progress.vfx, "grain");
        assert_eq!(progress.img, "api");
        assert_eq!(progress.voc, "api");
    }

    #[test]
    fn test_progress_json_serialization() -> Result<(), serde_json::Error> {
        let mut progress = TaskProgress::new("test".to_string(), "/path".to_string());
        progress.set_img("local".to_string());
        progress.set_voc("edge-tts".to_string());
        let json_str = progress.to_json()?;
        assert!(json_str.contains("test"));
        assert!(json_str.contains("local"));
        assert!(json_str.contains("edge-tts"));

        // Round-trip: the serialized form must deserialize back to the same values.
        let back: TaskProgress = serde_json::from_str(&json_str)?;
        assert_eq!(back.story_id, progress.story_id);
        assert_eq!(back.img, "local");
        assert_eq!(back.voc, "edge-tts");
        assert_eq!(back.timestamp, progress.timestamp);
        Ok(())
    }

    #[tokio::test]
    async fn test_progress_tracker_updates() -> anyhow::Result<()> {
        // Per-process name so concurrent test runs do not clobber each other's file.
        let tmp_file = std::env::temp_dir()
            .join(format!("test_progress_{}.json", std::process::id()));
        let _cleanup = RemoveOnDrop(tmp_file.clone());
        let tracker = ProgressTracker::new(&tmp_file, "story_test".to_string(), "/test".to_string());

        tracker.update_task_status("In progress", 50).await?;
        let read = tracker.read_progress()?;
        assert_eq!(read.progress, 50);
        assert_eq!(read.step, "In progress");

        tracker.set_img_mode("ssd-1b").await?;
        let read = tracker.read_progress()?;
        assert_eq!(read.img, "ssd-1b");

        tracker.set_voc_mode("elevenlabs").await?;
        let read = tracker.read_progress()?;
        assert_eq!(read.voc, "elevenlabs");

        tracker.set_vfx_mode("none").await?;
        let read = tracker.read_progress()?;
        assert_eq!(read.vfx, "none");
        // Earlier updates must still be present in later snapshots.
        assert_eq!(read.img, "ssd-1b");
        assert_eq!(read.voc, "elevenlabs");
        assert_eq!(read.progress, 50);

        tracker.clear()?;
        assert!(!tmp_file.exists());
        assert!(tracker.read_progress().is_err());
        // Clearing an already-absent file is not an error.
        tracker.clear()?;
        Ok(())
    }

    #[tokio::test]
    async fn test_pause_file_detection() {
        let pause = std::env::temp_dir().join(format!("test_pause_{}.flag", std::process::id()));
        let _cleanup = RemoveOnDrop(pause.clone());
        let _ = std::fs::remove_file(&pause);

        assert!(!check_pause(&pause));
        // Not paused: must return without sleeping.
        wait_for_resume(&pause).await;

        std::fs::write(&pause, b"").expect("create pause flag");
        assert!(check_pause(&pause));
    }
}