# Source: Hugging Face repository file view (uploader: MuratcanKoylan).
# Commit message: "Upload folder using huggingface_hub"
# Commit: 685d968 (verified)
import unittest
import json
import os
from synthetic_data.pipeline import SyntheticDataPipeline
from unittest.mock import MagicMock, patch
class TestSyntheticDataPipeline(unittest.TestCase):
    """Structural checks for the records produced by ``SyntheticDataPipeline``.

    ``test_pipeline_structure`` patches ``synthetic_data.pipeline.cohere.ClientV2``
    with a mock and feeds it canned chat responses;
    ``test_conversation_structure`` validates a hand-written sample record.
    """

    @staticmethod
    def _text_response(payload):
        """Return a mocked chat response wrapping ``payload`` as JSON text.

        The response exposes ``message.content`` as a list holding a single
        block whose ``type`` is ``"text"`` and whose ``text`` is the
        JSON-encoded payload.
        """
        block = MagicMock()
        block.type = "text"
        block.text = json.dumps(payload)

        response = MagicMock()
        response.message.content = [block]
        return response

    @patch('synthetic_data.pipeline.cohere.ClientV2')
    def test_pipeline_structure(self, mock_client):
        """Run a one-item batch against the mocked client and check the record.

        Args:
            mock_client: Mock injected by ``patch`` in place of ``ClientV2``.
        """
        # Payload returned for the scenario generation request.
        scenario_payload = {
            "scenario_description": "Test scenario",
            "user_profile": "Test user",
            "key_signals_to_include": ["signal1"],
            "distractor_signals": [],
            "suggested_turn_breakdown": "Test breakdown",
        }
        # Payload returned for the conversation generation request.
        conversation_payload = {
            "scenario_id": "test_id_001",
            "conversation": [
                {"role": "user", "content": "test"},
                {"role": "assistant", "content": "response"},
            ],
            "labels": {
                "categories": ["company.brand_core"],
                "persistence_horizon": "long",
                "memory_scope": "company",
                "rationale": "test rationale",
            },
            "metadata": {
                "scenario_type": "test",
                "primary_category": "company.brand_core",
                "distractor_present": False,
                "turn_count": 2,
                "signals_present": ["signal1"],
            },
        }

        # A list side_effect hands out one response per chat() call, in order:
        # the scenario response first, then the conversation response.
        mock_client.return_value.chat.side_effect = [
            self._text_response(scenario_payload),
            self._text_response(conversation_payload),
        ]

        # Run pipeline with dummy key
        pipeline = SyntheticDataPipeline(api_key="dummy_key")
        results = pipeline.run_batch(count=1)

        self.assertEqual(len(results), 1)
        record = results[0]
        # subTest reports every missing key instead of stopping at the first.
        for field in ('scenario_id', 'conversation', 'labels'):
            with self.subTest(field=field):
                self.assertIn(field, record)
        self.assertEqual(record['labels']['categories'], ["company.brand_core"])

    def test_conversation_structure(self):
        """Check that a sample record carries the required top-level and label keys."""
        # Validate structure of a sample output
        sample_data = {
            "scenario_id": "brand_core_test_001",
            "conversation": [
                {"role": "user", "content": "test"},
                {"role": "assistant", "content": "response"},
            ],
            "labels": {
                "categories": ["company.brand_core"],
                "persistence_horizon": "long",
                "memory_scope": "company",
                "rationale": "test",
            },
            "metadata": {
                "scenario_type": "test",
                "primary_category": "company.brand_core",
                "distractor_present": False,
                "turn_count": 2,
                "signals_present": [],
            },
        }

        # Check required fields
        required_fields = ['scenario_id', 'conversation', 'labels', 'metadata']
        for field in required_fields:
            with self.subTest(required_field=field):
                self.assertIn(field, sample_data)

        # Check label structure
        label_fields = ['categories', 'persistence_horizon', 'memory_scope', 'rationale']
        for field in label_fields:
            with self.subTest(label_field=field):
                self.assertIn(field, sample_data['labels'])
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()