argmin's picture
add files
510a9b0
def _format_features(values, features):
    """
    Render the selected features of one record as a single prompt fragment.

    Args:
        values: Mapping-like object indexed by feature name (e.g. a dict).
        features (list): Feature names to include, in output order.

    Returns:
        str: ``"name: value"`` pairs joined by ``"; "``.

    Raises:
        KeyError: If a dict ``values`` lacks one of ``features``.
    """
    return "; ".join(f"{feature}: {values[feature]}" for feature in features)


def create_classification_prompt(row, label_descriptions, features, example_rows):
    """
    Generates system and user prompts for classification.

    The system prompt lists every label with its description, followed by
    optional few-shot examples. The user prompt renders ``row`` in the same
    ``Input: ...`` format and ends with ``Label:`` for the model to complete.

    Args:
        row (dict): A single row of feature values, indexed by feature name.
        label_descriptions (dict): Mapping of labels to their descriptions.
        features (list): Features to include in the prompt, in output order.
        example_rows (list): Few-shot examples; each must provide a
            ``"features"`` mapping (indexed by feature name) and a
            ``"label"``. May be empty or None, in which case no
            examples section is emitted.

    Returns:
        tuple: (system_prompt, user_prompt)

    Raises:
        KeyError: If ``row`` or an example's ``"features"`` lacks one of
            ``features`` (for dict inputs), or an example lacks ``"label"``.
    """
    # Collect pieces and join once instead of repeated ``+=`` on a string.
    parts = [
        "You are a classifier. Assign one of the following labels based on the input data:\n"
    ]
    parts.extend(f"- {label}: {desc}\n" for label, desc in label_descriptions.items())

    if example_rows:
        parts.append("\nExamples:\n")
        # Examples use the same formatter as the query row, so the model
        # sees one consistent input pattern.
        parts.extend(
            f"Input: {_format_features(example['features'], features)}\n"
            f"Label: {example['label']}\n"
            for example in example_rows
        )
    system_prompt = "".join(parts)

    user_prompt = f"Input: {_format_features(row, features)}\nLabel:"
    return system_prompt, user_prompt
def generate_prompts(row, label_descriptions, features, example_rows):
    """
    Build the prompt pair for one dataset row.

    Convenience wrapper: every argument is forwarded unchanged, by keyword,
    to create_classification_prompt, and its result is returned as-is.

    Args:
        row (dict): Row of the dataset whose features form the user prompt.
        label_descriptions (dict): Mapping of labels to their descriptions.
        features (list): Features to render, in output order.
        example_rows (list): Few-shot examples placed in the system prompt.

    Returns:
        tuple: (system_prompt, user_prompt), exactly as produced by
        create_classification_prompt.
    """
    prompts = create_classification_prompt(
        row=row,
        label_descriptions=label_descriptions,
        features=features,
        example_rows=example_rows,
    )
    return prompts