File size: 2,291 Bytes
685d968
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""Quick test of balanced generation for underrepresented categories."""

import json
import os
from dotenv import load_dotenv
load_dotenv()

import cohere

# Cohere v2 client; the API key is looked up in the process environment
# under COHERE_API_KEY (None is passed through if the variable is unset).
client = cohere.ClientV2(api_key=os.environ.get("COHERE_API_KEY"))

# Underrepresented categories exercised by this quick test.
test_categories = [
    "company.tools_config",
    "company.knowledge_artifacts",
    "none",
]

# For each target category: build a prompt, request one generated
# conversation, and report whether the returned labels include the target.
for category in test_categories:
    rule = "=" * 60
    print(f"\n{rule}")
    print(f"Testing: {category}")
    print(rule)

    if category != "none":
        # Doubled braces below are literal JSON braces inside the f-string.
        prompt = f"""Generate a marketing conversation that clearly demonstrates: {category}

The conversation MUST contain clear signals for this category.
4-6 turns, start mid-conversation (no greetings).

CRITICAL: The categories array MUST include "{category}".

OUTPUT FORMAT (JSON only):
{{
  "scenario_id": "{category.replace('.', '_')}_001",
  "conversation": [
    {{"role": "user", "content": "..."}},
    {{"role": "assistant", "content": "..."}}
  ],
  "labels": {{
    "categories": ["{category}"],
    "rationale": "..."
  }}
}}"""
    else:
        # The "none" category uses a fixed prompt with no interpolation.
        prompt = """Generate a marketing conversation that has NO long-term memory value.

The conversation should be transactional, vague, or temporary.
Examples: checking status, scheduling, confirming receipt.

Generate 4 turns. Start mid-conversation (no greetings).

OUTPUT FORMAT (JSON only):
{
  "scenario_id": "none_001",
  "conversation": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "labels": {
    "categories": ["none"],
    "rationale": "..."
  }
}"""

    # The API call, JSON parsing and reporting share one try so that a
    # failure for one category is printed and the loop moves on to the next.
    try:
        reply = client.chat(
            model="command-r-plus-08-2024",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7,
            response_format={"type": "json_object"},
        )

        # The first content item's text is parsed as JSON.
        raw_text = reply.message.content[0].text
        parsed = json.loads(raw_text)

        labels = parsed.get("labels", {})
        returned_cats = labels.get("categories", [])
        print(f"Target: {category}")
        print(f"Output: {returned_cats}")
        verdict = "YES" if category in returned_cats else "NO"
        print(f"Match: {verdict}")

        # Preview the first turn, truncated to 80 characters.
        turns = parsed.get("conversation")
        if turns:
            opening = turns[0]["content"][:80]
            print(f"First turn: {opening}...")
    except Exception as exc:
        print(f"Error: {exc}")