File size: 3,253 Bytes
b87a828
 
 
 
 
 
 
 
 
 
 
 
 
d1f1387
b87a828
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ae4542
b87a828
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import time
import threading
from huggingface_hub import HfApi

# 🟢 Safe Import for AutoTrain
try:
    from autotrain.api import AutoTrainClient
    from datasets import load_dataset
    AUTOTRAIN_AVAILABLE = True
except ImportError:
    print("⚠️ AutoTrain not installed. Watcher Agent will sleep.")
    AUTOTRAIN_AVAILABLE = False

class AgentWatcher:
    """Background agent that launches an AutoTrain job as the dataset grows.

    Every ``check_interval`` seconds the watcher counts the rows in the
    ``train`` split of ``config.DATASET_ID``. Once at least ``threshold`` new
    rows have appeared since the last successful trigger (0 on first run), it
    creates an AutoTrain speech-recognition project. Use ``start()`` to run
    the loop in a daemon thread and ``stop()`` to end it.
    """

    def __init__(self, config):
        """Read watcher settings from ``config``.

        Args:
            config: object exposing ``HF_TOKEN`` and ``DATASET_ID`` attributes.
        """
        print("🕵️‍♂️ Watcher Agent (AutoTrain) Online.")
        self.config = config
        self.hf_token = config.HF_TOKEN
        self.dataset_id = config.DATASET_ID
        self.threshold = 1000  # Trigger training after this many *new* rows
        self.check_interval = 3600  # Seconds between checks (1 hour)
        self._stop_event = threading.Event()
        self._thread = None  # Handle to the background loop thread, if started
        # Row count observed at the last successful trigger. Kept in memory
        # only, so it resets to 0 when the process restarts.
        self._last_trigger_count = 0

    def trigger_autotrain(self, data_count):
        """Launch a training job via the AutoTrain API.

        Args:
            data_count: number of dataset rows; used only in the log message.

        Returns:
            True if ``create_project`` returned without raising, else False.
        """
        try:
            print(f"🚀 [WATCHER] Triggering AutoTrain for {data_count} items...")
            client = AutoTrainClient(hf_token=self.hf_token)

            # Unique project name per trigger (second resolution).
            project_name = f"pure-versation-finetune-{int(time.time())}"

            # NOTE(review): only create_project is called here; confirm against
            # the installed autotrain version that this also starts training.
            client.create_project(project_name, task="speech-recognition")

            print(f"🔥 [WATCHER] Training job '{project_name}' started successfully!")
            return True
        except Exception as e:
            # Boundary handler: a failed trigger is logged, not propagated,
            # so the watcher loop keeps running.
            print(f"❌ [WATCHER] AutoTrain Trigger Failed: {e}")
            return False

    def _training_due(self, count):
        """Return True when ``count`` exceeds the last trigger baseline by ``threshold``."""
        return count - self._last_trigger_count >= self.threshold

    def check_dataset_status(self):
        """Count dataset rows and trigger AutoTrain when enough new rows exist.

        The dataset is read in streaming mode and every row is counted. An
        exact count is needed so the post-trigger baseline is accurate; a
        capped count would cause one job per threshold-sized step of backlog.
        All errors are logged and swallowed.
        """
        print("🔍 [WATCHER] Checking Pure Chain dataset size...")

        if not self.hf_token:
            print("⚠️ [WATCHER] No HF_TOKEN found. Skipping check.")
            return

        try:
            # Streaming avoids downloading the whole dataset to disk.
            ds = load_dataset(self.dataset_id, split="train", streaming=True, token=self.hf_token)
            count = sum(1 for _ in ds)

            print(f"✅ [WATCHER] Found {count} rows (Threshold: {self.threshold}).")

            if self._training_due(count):
                # Only advance the baseline when the job was actually created,
                # so a failed trigger is retried on the next check.
                if self.trigger_autotrain(count):
                    self._last_trigger_count = count
            else:
                new_rows = count - self._last_trigger_count
                print(f"💤 [WATCHER] Not enough data yet ({new_rows}/{self.threshold}).")

        except Exception as e:
            print(f"❌ [WATCHER] Error checking dataset: {e}")

    def _loop(self):
        """Run checks every ``check_interval`` seconds until ``stop()`` is called."""
        while not self._stop_event.is_set():
            if AUTOTRAIN_AVAILABLE:
                try:
                    self.check_dataset_status()
                except Exception as e:
                    print(f"❌ [WATCHER] Loop Error: {e}")

            # Interruptible sleep: returns immediately once stop() sets the event.
            self._stop_event.wait(self.check_interval)

    def start(self):
        """Start the background daemon thread; no-op if it is already running."""
        if self._thread is not None and self._thread.is_alive():
            print("ℹ️ Watcher Agent already running.")
            return
        self._stop_event.clear()
        self._thread = threading.Thread(target=self._loop, name="AgentWatcher", daemon=True)
        self._thread.start()
        print("🕵️‍♂️ Watcher Agent background thread started.")

    def stop(self, join_timeout=5.0):
        """Signal the background loop to exit and wait up to ``join_timeout`` seconds.

        Safe to call when the thread was never started. If a dataset check is
        in progress, the loop exits once that check finishes.
        """
        self._stop_event.set()
        thread = self._thread
        if thread is not None and thread is not threading.current_thread():
            thread.join(join_timeout)