File size: 3,698 Bytes
685d968
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import unittest
import json
import os
from synthetic_data.pipeline import SyntheticDataPipeline
from unittest.mock import MagicMock, patch

class TestSyntheticDataPipeline(unittest.TestCase):
    """Tests for ``SyntheticDataPipeline`` output and the expected record layout."""

    @staticmethod
    def _text_response(payload):
        """Return a mocked chat response whose only content block is *payload* as JSON text."""
        block = MagicMock()
        block.type = "text"
        block.text = json.dumps(payload)

        response = MagicMock()
        response.message.content = [block]
        return response

    @patch('synthetic_data.pipeline.cohere.ClientV2')
    def test_pipeline_structure(self, mock_client):
        """Run a one-item batch against a mocked client and check the record's shape."""
        scenario_payload = {
            "scenario_description": "Test scenario",
            "user_profile": "Test user",
            "key_signals_to_include": ["signal1"],
            "distractor_signals": [],
            "suggested_turn_breakdown": "Test breakdown",
        }

        turns = [
            {"role": "user", "content": "test"},
            {"role": "assistant", "content": "response"},
        ]
        labels = {
            "categories": ["company.brand_core"],
            "persistence_horizon": "long",
            "memory_scope": "company",
            "rationale": "test rationale",
        }
        metadata = {
            "scenario_type": "test",
            "primary_category": "company.brand_core",
            "distractor_present": False,
            "turn_count": 2,
            "signals_present": ["signal1"],
        }
        conversation_payload = {
            "scenario_id": "test_id_001",
            "conversation": turns,
            "labels": labels,
            "metadata": metadata,
        }

        # Successive chat() calls on the patched client hand back these
        # responses in order: the scenario payload, then the conversation.
        client = mock_client.return_value
        client.chat.side_effect = [
            self._text_response(scenario_payload),
            self._text_response(conversation_payload),
        ]

        # The key is a placeholder; the real client class is patched out.
        pipeline = SyntheticDataPipeline(api_key="dummy_key")
        results = pipeline.run_batch(count=1)

        self.assertEqual(len(results), 1)
        record = results[0]
        for key in ('scenario_id', 'conversation', 'labels'):
            self.assertIn(key, record)
        self.assertEqual(record['labels']['categories'], ["company.brand_core"])

    def test_conversation_structure(self):
        """Check that a hand-written sample record carries every required field."""
        labels = {
            "categories": ["company.brand_core"],
            "persistence_horizon": "long",
            "memory_scope": "company",
            "rationale": "test",
        }
        metadata = {
            "scenario_type": "test",
            "primary_category": "company.brand_core",
            "distractor_present": False,
            "turn_count": 2,
            "signals_present": [],
        }
        sample_data = {
            "scenario_id": "brand_core_test_001",
            "conversation": [
                {"role": "user", "content": "test"},
                {"role": "assistant", "content": "response"},
            ],
            "labels": labels,
            "metadata": metadata,
        }

        # Top-level keys every record must expose.
        for field in ('scenario_id', 'conversation', 'labels', 'metadata'):
            self.assertIn(field, sample_data)

        # Keys required inside the nested "labels" mapping.
        for field in ('categories', 'persistence_horizon', 'memory_scope', 'rationale'):
            self.assertIn(field, sample_data['labels'])

# Run the tests in this module when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()