import json
import os
import unittest
from unittest.mock import MagicMock, patch

from synthetic_data.pipeline import SyntheticDataPipeline


class TestSyntheticDataPipeline(unittest.TestCase):
    """Structural checks on the records handled by ``SyntheticDataPipeline``."""

    @patch('synthetic_data.pipeline.cohere.ClientV2')
    def test_pipeline_structure(self, mock_client):
        """Run a one-item batch against a mocked client and check the record.

        ``cohere.ClientV2`` is patched where ``synthetic_data.pipeline`` looks
        it up, so no real API call is made. ``mock_client`` is the patched
        class; its ``return_value`` is the instance the pipeline receives.
        """
        mock_instance = mock_client.return_value

        # First canned chat() reply: a scenario, JSON-encoded in a text block.
        scenario_block = MagicMock()
        scenario_block.type = "text"
        scenario_block.text = json.dumps({
            "scenario_description": "Test scenario",
            "user_profile": "Test user",
            "key_signals_to_include": ["signal1"],
            "distractor_signals": [],
            "suggested_turn_breakdown": "Test breakdown"
        })
        mock_scenario_response = MagicMock()
        mock_scenario_response.message.content = [scenario_block]

        # Second canned chat() reply: the labelled conversation record.
        conv_block = MagicMock()
        conv_block.type = "text"
        conv_block.text = json.dumps({
            "scenario_id": "test_id_001",
            "conversation": [
                {"role": "user", "content": "test"},
                {"role": "assistant", "content": "response"}
            ],
            "labels": {
                "categories": ["company.brand_core"],
                "persistence_horizon": "long",
                "memory_scope": "company",
                "rationale": "test rationale"
            },
            "metadata": {
                "scenario_type": "test",
                "primary_category": "company.brand_core",
                "distractor_present": False,
                "turn_count": 2,
                "signals_present": ["signal1"]
            }
        })
        mock_conv_response = MagicMock()
        mock_conv_response.message.content = [conv_block]

        # side_effect hands the replies out in call order: scenario first,
        # conversation second. A third chat() call would raise StopIteration.
        # NOTE(review): presumably the pipeline makes exactly these two calls
        # per generated item -- confirm against run_batch.
        mock_instance.chat.side_effect = [mock_scenario_response, mock_conv_response]

        pipeline = SyntheticDataPipeline(api_key="dummy_key")
        results = pipeline.run_batch(count=1)

        self.assertEqual(len(results), 1)
        self.assertIn('scenario_id', results[0])
        self.assertIn('conversation', results[0])
        self.assertIn('labels', results[0])
        self.assertEqual(results[0]['labels']['categories'], ["company.brand_core"])

    def test_conversation_structure(self):
        """Check that a hand-written sample record has the expected keys.

        This exercises only the literal below, not pipeline output; it pins
        the key names a record is expected to carry.
        """
        sample_data = {
            "scenario_id": "brand_core_test_001",
            "conversation": [
                {"role": "user", "content": "test"},
                {"role": "assistant", "content": "response"}
            ],
            "labels": {
                "categories": ["company.brand_core"],
                "persistence_horizon": "long",
                "memory_scope": "company",
                "rationale": "test"
            },
            "metadata": {
                "scenario_type": "test",
                "primary_category": "company.brand_core",
                "distractor_present": False,
                "turn_count": 2,
                "signals_present": []
            }
        }

        # subTest keeps iterating after a failure, so every missing key is
        # reported in one run rather than only the first.
        required_fields = ['scenario_id', 'conversation', 'labels', 'metadata']
        for field in required_fields:
            with self.subTest(field=field):
                self.assertIn(field, sample_data)

        label_fields = ['categories', 'persistence_horizon', 'memory_scope', 'rationale']
        for field in label_fields:
            with self.subTest(label_field=field):
                self.assertIn(field, sample_data['labels'])


# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()