| from __future__ import annotations |
|
|
| import json |
| import random |
| from collections.abc import Iterable, Sequence |
| from typing import Any |
|
|
| from .prompts import AUTHOR_SYSTEM |
| from .quality import is_situation_grounded |
| from .safety import guard_input |
| from .schema import ForestDraft |
|
|
# Forest creature roster. Each entry is a 3-tuple of
# (creature name, strength it embodies, first-person "spell").
# Order matters: template_forest rotates through this tuple by ``variant``.
CREATURES = (
    (
        "The Patient Fox",
        "patience under pressure",
        "I am allowed to learn.",
    ),
    (
        "The Listening Owl",
        "careful perspective",
        "I can listen before I leap.",
    ),
    (
        "The Brave Snail",
        "quiet courage",
        "I make progress at my pace.",
    ),
    (
        "The Steady Deer",
        "steadiness through change",
        "I can meet one moment.",
    ),
    (
        "The Clear-Voiced Wren",
        "honest self-expression",
        "I can speak gently and clearly.",
    ),
    (
        "The Curious Otter",
        "playful curiosity",
        "I can stay curious here.",
    ),
    (
        "The Gentle Bear",
        "self-compassion",
        "I can be kind to myself.",
    ),
    (
        "The Building Beaver",
        "practical resourcefulness",
        "I can shape one next step.",
    ),
)


# Illustration prompt for each creature, keyed by the exact creature name used
# in CREATURES (template_forest looks prompts up by that name).
IMAGE_PROMPTS = {
    "The Patient Fox": (
        "a gentle russet fox sitting in a mossy clearing, soft kind eyes"
    ),
    "The Listening Owl": (
        "a round tawny owl resting on a low branch, attentive kind eyes"
    ),
    "The Brave Snail": (
        "a tiny snail crossing a fern frond, softly glowing spiral shell"
    ),
    "The Steady Deer": (
        "a small deer standing in morning mist, calm gentle expression"
    ),
    "The Clear-Voiced Wren": (
        "a tiny wren singing beside dusty rose wildflowers"
    ),
    "The Curious Otter": (
        "a small otter floating beside reeds, bright gentle eyes"
    ),
    "The Gentle Bear": (
        "a sleepy bear cub resting under dappled canopy light"
    ),
    "The Building Beaver": (
        "a friendly beaver holding one smooth twig beside a stream"
    ),
}
|
|
# Names cycled through by build_template_examples when creating
# template-coverage examples (40 unique entries).
TEMPLATE_NAMES = (
    "Mika", "Noor", "Ari", "Sam", "Leila", "Jun", "Robin", "Maya",
    "Theo", "Nia", "Alex", "Bao", "Clara", "Dev", "Esme", "Farah",
    "Gabe", "Hana", "Iris", "Jules", "Kai", "Lina", "Mateo", "Niko",
    "Omar", "Priya", "Quinn", "Ravi", "Sora", "Tess", "Uma", "Vera",
    "Will", "Xia", "Yara", "Zane", "An", "Bea", "Cole", "Dara",
)


# Names cycled through by teacher_requests when building teacher-model
# generation requests.
TEACHER_NAMES = ("Cora", "Eli", "Imani", "Pax", "Ren", "Sol", "Tala", "Zoe")
|
|
# Encouraging "line" templates rendered by template_forest with str.format.
# Placeholders: {situation} (sentence-start form), {situation_lower}
# (mid-sentence form), {strength}, and {strength_capitalized}.
# Convention: {situation} only opens a sentence; anywhere else in a sentence
# the template must use {situation_lower} so casing reads naturally.
LINE_TEMPLATES = (
    (
        "{situation} can be difficult without meaning you are failing. "
        "Your {strength} can help you choose one honest next step."
    ),
    (
        # Mid-sentence position, so the lowered form is used here.
        "There is real uncertainty in {situation_lower}. {strength_capitalized} does not erase it; "
        "it gives you room to respond without rushing."
    ),
    (
        "When {situation_lower}, you do not have to solve the whole path today. "
        "Your {strength} is enough for the next part."
    ),
    (
        "{situation} asks a lot of you. Let {strength} make the next choice smaller, "
        "clearer, and kinder."
    ),
    (
        "It makes sense that {situation_lower} feels tender. Your {strength} can sit "
        "beside the difficulty instead of pretending it is easy."
    ),
    (
        "You can acknowledge the hard part of {situation_lower} and still trust your "
        "{strength} to help with what comes next."
    ),
)


# Open reflection questions attached to each clearing; template_forest picks
# one per clearing by rotating through this tuple.
REFLECTION_TEMPLATES = (
    "What would change if you asked only for the next kind and concrete step?",
    "Which part deserves your attention now, and which part can wait?",
    "What could you notice before deciding you must already have the answer?",
    "What is one choice that would make this moment more workable?",
    "How might you make room for both the difficulty and your own agency?",
    "What support would let you move with care instead of pressure?",
)
|
|
|
|
def normalize_positive_frame(row: dict[str, Any]) -> dict[str, str]:
    """Map a SALT-NLP/positive_reframing row onto the shared hint format.

    Args:
        row: Dataset row; ``original_text``, ``reframed_text`` and ``strategy``
            are read if present.

    Returns:
        A dict with stripped ``situation``, ``support_hint`` and ``strategy``
        strings plus a fixed ``source`` tag. Missing or ``None`` fields become
        empty strings rather than the literal text ``"None"``.
    """

    def text(key: str) -> str:
        value = row.get(key)
        return "" if value is None else str(value).strip()

    return {
        "situation": text("original_text"),
        "support_hint": text("reframed_text"),
        "strategy": text("strategy"),
        "source": "SALT-NLP/positive_reframing",
    }
|
|
|
|
def normalize_empathy(row: dict[str, Any]) -> dict[str, str]:
    """Map an Estwld/empathetic_dialogues_llm row onto the shared hint format.

    Args:
        row: Dataset row; ``situation``, ``emotion`` and ``conversations`` (a
            list of ``{"role", "content"}`` messages) are read if present.

    Returns:
        A dict whose ``support_hint`` joins the first two non-blank assistant
        turns with a single space, and whose ``strategy`` is the row's
        ``emotion``. Missing or ``None`` fields become empty strings rather
        than the literal text ``"None"``.
    """

    def text(value: Any) -> str:
        return "" if value is None else str(value).strip()

    assistant_lines: list[str] = []
    for message in row.get("conversations") or []:
        if message.get("role") != "assistant":
            continue
        line = text(message.get("content"))
        # Filter after stripping so whitespace-only turns do not leave stray
        # spaces in the joined hint.
        if line:
            assistant_lines.append(line)
    return {
        "situation": text(row.get("situation")),
        "support_hint": " ".join(assistant_lines[:2]),
        "strategy": text(row.get("emotion")),
        "source": "Estwld/empathetic_dialogues_llm",
    }
|
|
|
|
def validate_synthetic_example(example: dict[str, Any]) -> dict[str, Any] | None:
    """Check one generated example and return a cleaned copy, or ``None`` to drop it.

    The example is rejected (``None``) when:

    * ``name`` or ``situation`` is missing or blank after stripping,
    * ``guard_input(name, situation)`` reports it as not allowed,
    * ``ForestDraft.model_validate`` raises ``TypeError``/``ValueError`` on the
      ``forest`` payload, or
    * any clearing's ``line`` fails ``is_situation_grounded`` for the situation.

    Args:
        example: Raw example dict with ``name``, ``situation``, ``forest`` and
            optional ``source`` / ``teacher_model`` entries.

    Returns:
        A dict with ``name``, ``situation``, the validated ``forest`` dumped
        back to plain data, ``source`` (default ``"synthetic"``) and
        ``teacher_model`` (default ``""``). ``None`` fields are treated as
        absent instead of becoming the literal string ``"None"``.
    """

    def field(key: str, default: str = "") -> str:
        value = example.get(key)
        return default if value is None else str(value)

    name = field("name").strip()
    situation = field("situation").strip()
    # A blank name or situation cannot produce a grounded training example.
    if not name or not situation:
        return None
    if not guard_input(name, situation).allowed:
        return None
    try:
        forest = ForestDraft.model_validate(example.get("forest"))
    except (TypeError, ValueError):
        return None
    if any(not is_situation_grounded(clearing.line, situation) for clearing in forest.clearings):
        return None
    return {
        "name": name,
        "situation": situation,
        "forest": forest.model_dump(),
        "source": field("source", "synthetic"),
        "teacher_model": field("teacher_model"),
    }
|
|
|
|
def build_sft_record(example: dict[str, Any]) -> dict[str, Any]:
    """Turn a validated example into a chat-format supervised fine-tuning record.

    The conversation is: the ``AUTHOR_SYSTEM`` prompt, a user turn holding the
    JSON-encoded ``{"name", "situation"}`` request, and an assistant turn
    holding the JSON-encoded forest. JSON is written with
    ``ensure_ascii=False`` so non-ASCII text stays readable.

    Raises:
        KeyError: if ``name``, ``situation`` or ``forest`` is missing.
    """
    name = example["name"]
    situation = example["situation"]
    request_json = json.dumps({"name": name, "situation": situation}, ensure_ascii=False)
    forest_json = json.dumps(example["forest"], ensure_ascii=False)

    conversation = [
        {"role": "system", "content": AUTHOR_SYSTEM},
        {"role": "user", "content": request_json},
        {"role": "assistant", "content": forest_json},
    ]
    return {
        "name": name,
        "situation": situation,
        "source": example.get("source", "synthetic"),
        "teacher_model": example.get("teacher_model", ""),
        "messages": conversation,
    }
|
|
|
|
def deduplicate_records(records: Iterable[dict[str, Any]]) -> list[dict[str, Any]]:
    """Drop records whose (name, situation) pair was already seen.

    Comparison ignores surrounding whitespace and case (``str.casefold``).
    The first occurrence of each pair is kept and input order is preserved.
    """

    def identity_of(record: dict[str, Any]) -> tuple[str, ...]:
        return tuple(
            str(record[field]).strip().casefold() for field in ("name", "situation")
        )

    # dicts keep insertion order, so the first record per identity wins and
    # values() comes back in original order.
    first_by_identity: dict[tuple[str, ...], dict[str, Any]] = {}
    for record in records:
        first_by_identity.setdefault(identity_of(record), record)
    return list(first_by_identity.values())
|
|
|
|
def split_records(
    records: Sequence[dict[str, Any]],
    *,
    validation_fraction: float = 0.1,
    seed: int = 42,
) -> dict[str, list[dict[str, Any]]]:
    """Shuffle *records* deterministically and split them into train/validation.

    Args:
        records: Records to split; the input sequence is not modified.
        validation_fraction: Target share of records for validation, strictly
            between 0 and 1.
        seed: Seed for the private ``random.Random`` used to shuffle, so the
            split is reproducible.

    Returns:
        ``{"train": [...], "validation": [...]}``. With two or more records,
        validation gets ``round(n * validation_fraction)`` records, clamped so
        that both splits are non-empty. With fewer than two records everything
        goes to train, because a split needs at least one training example.

    Raises:
        ValueError: if ``validation_fraction`` is not in the open interval (0, 1).
    """
    if not 0 < validation_fraction < 1:
        raise ValueError("validation_fraction must be between zero and one")
    shuffled = list(records)
    random.Random(seed).shuffle(shuffled)
    total = len(shuffled)
    if total < 2:
        validation_count = 0
    else:
        # Clamp to [1, total - 1]: round() can reach `total` for high
        # fractions (e.g. round(10 * 0.95) == 10), which would empty train.
        validation_count = min(max(1, round(total * validation_fraction)), total - 1)
    return {
        "train": shuffled[validation_count:],
        "validation": shuffled[:validation_count],
    }
|
|
|
|
def _lower_first(text: str) -> str:
    """Lower-case only the first character of *text* for mid-sentence use.

    Unlike ``str.lower`` on the whole string, the rest of the text (names,
    places, acronyms) is left intact. Text that starts with the pronoun "I"
    (including contractions such as "I'm") or with an all-caps word of two or
    more letters is returned unchanged.
    """
    words = text.split(maxsplit=1)
    if not words:
        return text
    first_word = words[0]
    if first_word == "I" or first_word.startswith(("I'", "I\u2019")):
        return text
    if len(first_word) > 1 and first_word.isupper():
        return text
    return text[:1].lower() + text[1:]


def template_forest(name: str, situation: str, variant: int) -> dict[str, Any]:
    """Build a deterministic, template-based Compliment Forest.

    Args:
        name: Person the forest is written for; used in the forest title.
        situation: Free-text situation. Surrounding whitespace and trailing
            ``.``, ``!`` and ``?`` are removed before it is spliced into the
            line templates. The sentence-start form gets a capital first
            letter, and the mid-sentence form gets a lower-case first letter
            (see ``_lower_first``).
        variant: Rotation offset. It selects which creatures (five, starting
            at ``variant`` modulo the roster size), which line and reflection
            templates, and which of the two title forms are used.

    Returns:
        A dict with ``forest_title``, ``proposed_strengths`` (the selected
        creatures' strengths, in order) and ``clearings`` (one dict per
        selected creature with ``creature``, ``strength``, ``line``,
        ``reflection``, ``spell`` and ``image_prompt``).
    """
    start = variant % len(CREATURES)
    # Rotate the roster so each variant starts from a different creature.
    rotated = list(CREATURES[start:]) + list(CREATURES[:start])
    selected = rotated[:5]

    clause = situation.strip().rstrip(".!?").rstrip()
    clause_start = clause[:1].upper() + clause[1:]
    # Lower only the leading letter: whole-string lower() would also mangle
    # "I", names and acronyms ("i lost my job", "nasa").
    clause_inline = _lower_first(clause)

    clearings = []
    for clearing_index, (creature, strength, spell) in enumerate(selected):
        slot = variant + clearing_index
        line_template = LINE_TEMPLATES[slot % len(LINE_TEMPLATES)]
        clearings.append(
            {
                "creature": creature,
                "strength": strength,
                "line": line_template.format(
                    situation=clause_start,
                    situation_lower=clause_inline,
                    strength=strength,
                    strength_capitalized=strength.capitalize(),
                ),
                "reflection": REFLECTION_TEMPLATES[slot % len(REFLECTION_TEMPLATES)],
                "spell": spell,
                "image_prompt": IMAGE_PROMPTS[creature],
            }
        )
    return {
        "forest_title": (
            f"{name}'s Path Through This Moment"
            if variant % 2 == 0
            else f"{name}'s Clearing for What Comes Next"
        ),
        "proposed_strengths": [item[1] for item in selected],
        "clearings": clearings,
    }
|
|
|
|
def build_template_examples(
    situations: Sequence[str],
    *,
    variants_per_situation: int = 4,
) -> list[dict[str, Any]]:
    """Create template-coverage examples: ``variants_per_situation`` per situation.

    Names come from TEMPLATE_NAMES, offset by both the situation's position
    and the variant number so neighbouring examples get different names.
    Output is ordered by situation, then by variant.
    """

    def make_example(situation_index: int, situation: str, variant: int) -> dict[str, Any]:
        name = TEMPLATE_NAMES[(situation_index + variant) % len(TEMPLATE_NAMES)]
        return {
            "name": name,
            "situation": situation,
            "forest": template_forest(name, situation, variant),
            "source": "template_coverage",
        }

    return [
        make_example(situation_index, situation, variant)
        for situation_index, situation in enumerate(situations)
        for variant in range(variants_per_situation)
    ]
|
|
|
|
def teacher_requests(situations: Sequence[str], *, start: int) -> list[dict[str, str]]:
    """Pair each situation with a name drawn cyclically from TEACHER_NAMES.

    The first situation gets the name at index ``start`` (modulo the pool
    size), and each following situation advances one position.
    """
    requests: list[dict[str, str]] = []
    for offset, situation in enumerate(situations):
        slot = (start + offset) % len(TEACHER_NAMES)
        requests.append({"name": TEACHER_NAMES[slot], "situation": situation})
    return requests
|
|
|
|
def forest_batch_json_schema() -> dict[str, Any]:
    """Return the JSON Schema for a batch response ``{"examples": [...]}``.

    Each example is ``{"name", "situation", "forest"}``. A forest has
    ``forest_title``, ``proposed_strengths`` (array of strings) and
    ``clearings``. Each clearing holds six string fields. Every object is
    closed (``additionalProperties: False``) and requires all of its
    properties. A fresh dict is built on every call.
    """

    def closed_object(required: Sequence[str], properties: dict[str, Any]) -> dict[str, Any]:
        return {
            "type": "object",
            "additionalProperties": False,
            "required": list(required),
            "properties": properties,
        }

    def string_fields(*names: str) -> dict[str, Any]:
        return {field_name: {"type": "string"} for field_name in names}

    def array_of(items: dict[str, Any]) -> dict[str, Any]:
        return {"type": "array", "items": items}

    clearing_fields = ("creature", "strength", "line", "reflection", "spell", "image_prompt")
    clearing_schema = closed_object(clearing_fields, string_fields(*clearing_fields))

    forest_schema = closed_object(
        ("forest_title", "proposed_strengths", "clearings"),
        {
            "forest_title": {"type": "string"},
            "proposed_strengths": array_of({"type": "string"}),
            "clearings": array_of(clearing_schema),
        },
    )
    example_schema = closed_object(
        ("name", "situation", "forest"),
        {**string_fields("name", "situation"), "forest": forest_schema},
    )
    return closed_object(("examples",), {"examples": array_of(example_schema)})
|
|
|
|
class CohereForestGenerator:
    """Ask a Cohere chat model to author Compliment Forests in batches.

    Each ``generate`` call sends one batch of ``{"name", "situation"}``
    requests and constrains the reply with ``forest_batch_json_schema``.
    """

    def __init__(
        self,
        api_key: str,
        *,
        model: str = "command-a-03-2025",
        temperature: float = 0.65,
        max_tokens: int = 5000,
    ) -> None:
        """Create the Cohere client.

        Args:
            api_key: Cohere API key passed to ``cohere.ClientV2``.
            model: Model name used for every chat request.
            temperature: Sampling temperature sent with each chat request.
            max_tokens: Generation cap sent with each chat request.
        """
        # Imported here rather than at module level, presumably so the rest of
        # the module is usable without the ``cohere`` SDK installed.
        import cohere

        self.client = cohere.ClientV2(api_key=api_key)
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens

    def generate(
        self,
        requests: Sequence[dict[str, str]],
        *,
        source_hints: Sequence[dict[str, str]] = (),
        seed: int = 42,
    ) -> list[dict[str, Any]]:
        """Generate one forest per request in a single structured-output call.

        Args:
            requests: ``{"name", "situation"}`` dicts to write forests for.
            source_hints: Optional voice examples; only the first 8 are sent.
            seed: Sampling seed forwarded to the chat call.

        Returns:
            The ``examples`` objects from the model's JSON reply. Returns an
            empty list when the reply has no ``examples`` list. Non-object
            entries are dropped so downstream validation can rely on dict
            access. The examples are not validated here.

        Raises:
            ValueError: if the reply contains no text, or (as
                ``json.JSONDecodeError``) if the text is not valid JSON.
        """
        prompt = {
            "task": (
                "Write one complete Compliment Forest for each request. Return 4-6 distinct "
                "clearings. Every line must repeat at least one concrete noun or phrase from "
                "its situation. Acknowledge difficulty without diagnosis, guarantees, hollow "
                "praise, or toxic positivity. Spells begin with 'I' and use at most 12 words. "
                "Image prompts describe one creature only and contain no style words."
            ),
            "requests": list(requests),
            "voice_hints": list(source_hints)[:8],
        }
        response = self.client.chat(
            model=self.model,
            messages=[
                {"role": "system", "content": AUTHOR_SYSTEM},
                {"role": "user", "content": json.dumps(prompt, ensure_ascii=False)},
            ],
            response_format={
                "type": "json_object",
                "json_schema": forest_batch_json_schema(),
            },
            safety_mode="CONTEXTUAL",
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            seed=seed,
        )
        # Join every text part instead of assuming content[0] exists and is
        # text; an empty or non-text first part previously raised
        # IndexError/TypeError/AttributeError.
        text = "".join(
            part.text
            for part in (response.message.content or [])
            if getattr(part, "text", None)
        )
        if not text.strip():
            raise ValueError("Cohere response contained no text content")
        payload = json.loads(text)
        examples = payload.get("examples") if isinstance(payload, dict) else None
        if not isinstance(examples, list):
            return []
        return [example for example in examples if isinstance(example, dict)]
|
|