|
|
""" |
|
|
Utility for computing F1 scores at file level for ranking generated outputs. |
|
|
This helps create preference pairs for DPO training. |
|
|
""" |
|
|
|
|
|
import json |
|
|
import re |
|
|
from typing import List, Set, Tuple, Dict |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
def extract_files_from_selection(output_text: str) -> Set[str]:
    """
    Extract file paths from the ##SELECT section of a model output.

    Expected line format inside the section:
        modify::crates/path/to/file.rs::impl::ComponentName
    i.e. the file path is the second '::'-delimited field.

    Args:
        output_text: Full model output containing a ##SELECT section.

    Returns:
        Set of unique file paths; empty if no ##SELECT section is present.
    """
    files: Set[str] = set()

    # Capture everything between ##SELECT and <EOS>. Fall back to end of
    # string (\Z) so generations that were truncated before emitting <EOS>
    # are still scored on the files they selected instead of silently
    # yielding an empty set.
    select_match = re.search(
        r'##SELECT\s*(.*?)(?:<EOS>|\Z)', output_text, re.DOTALL | re.IGNORECASE
    )
    if not select_match:
        return files

    for line in select_match.group(1).strip().split('\n'):
        line = line.strip()
        if not line:
            continue

        # action::file_path::kind::name -> field index 1 is the file path
        parts = line.split('::')
        if len(parts) >= 2:
            files.add(parts[1])

    return files
|
|
|
|
|
|
|
|
def compute_file_level_f1(predicted: str, ground_truth: str) -> Dict[str, float]:
    """
    Compute precision/recall/F1 over the sets of files named in the
    ##SELECT sections of two outputs.

    Args:
        predicted: Model output with a ##SELECT section.
        ground_truth: Reference output with a ##SELECT section.

    Returns:
        Dict with keys: precision, recall, f1, true_positives,
        false_positives, false_negatives, pred_files, gt_files.
        Every key is present on every path so callers can index them
        unconditionally (the previous early returns omitted the count
        and file-list keys, making e.g. metrics["true_positives"] raise
        KeyError whenever either selection was empty).
    """
    pred_files = extract_files_from_selection(predicted)
    gt_files = extract_files_from_selection(ground_truth)

    true_positives = len(pred_files & gt_files)
    false_positives = len(pred_files - gt_files)
    false_negatives = len(gt_files - pred_files)

    # Degenerate cases keep the original scoring convention:
    #  - nothing expected and nothing predicted is a perfect match;
    #  - predicting files when none were expected zeroes precision/F1 but
    #    leaves recall at 1.0 (nothing required was missed);
    #  - predicting nothing when files were expected scores 0 across the board.
    if not gt_files:
        precision = 1.0 if not pred_files else 0.0
        recall = 1.0
        f1 = precision
    elif not pred_files:
        precision = recall = f1 = 0.0
    else:
        # Both sets are non-empty here, so neither denominator can be zero.
        precision = true_positives / (true_positives + false_positives)
        recall = true_positives / (true_positives + false_negatives)
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom > 0 else 0.0

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "true_positives": true_positives,
        "false_positives": false_positives,
        "false_negatives": false_negatives,
        "pred_files": list(pred_files),
        "gt_files": list(gt_files),
    }
|
|
|
|
|
|
|
|
def rank_outputs_by_f1(outputs: List[str], ground_truth: str) -> List[Tuple[str, float, Dict]]:
    """
    Order candidate outputs from best to worst by file-level F1 against
    a reference output.

    Args:
        outputs: Candidate model outputs to score.
        ground_truth: Reference output used for scoring.

    Returns:
        List of (output, f1_score, metrics_dict) tuples, highest F1 first.
    """
    scored = [
        (candidate, compute_file_level_f1(candidate, ground_truth))
        for candidate in outputs
    ]
    # sorted() is stable, so equal-F1 candidates keep their input order.
    return sorted(
        ((text, metrics["f1"], metrics) for text, metrics in scored),
        key=lambda entry: entry[1],
        reverse=True,
    )
|
|
|
|
|
|
|
|
def create_dpo_pairs_from_generations(
    prompt: str,
    generations: List[str],
    ground_truth: str,
    min_f1_difference: float = 0.1
) -> List[Dict[str, str]]:
    """
    Build DPO preference pairs from several generations for one prompt.

    Generations are ranked by file-level F1 against the ground truth; every
    ordered pair whose F1 gap is at least `min_f1_difference` becomes a
    chosen/rejected example.

    Args:
        prompt: Input prompt/task.
        generations: Candidate outputs for the prompt.
        ground_truth: Reference output used for ranking.
        min_f1_difference: Minimum F1 gap required to emit a pair.

    Returns:
        List of dicts with keys: prompt, chosen, rejected, chosen_f1, rejected_f1.
    """
    # Need at least two candidates to form any preference pair.
    if len(generations) < 2:
        return []

    ranked = rank_outputs_by_f1(generations, ground_truth)
    total = len(ranked)
    pairs = []

    for hi in range(total):
        chosen_text, chosen_f1, _ = ranked[hi]
        for lo in range(hi + 1, total):
            rejected_text, rejected_f1, _ = ranked[lo]

            # Skip near-ties: a tiny gap is a weak preference signal.
            if chosen_f1 - rejected_f1 < min_f1_difference:
                continue

            pairs.append({
                "prompt": prompt,
                "chosen": chosen_text,
                "rejected": rejected_text,
                "chosen_f1": chosen_f1,
                "rejected_f1": rejected_f1,
            })

    return pairs
|
|
|
|
|
|
|
|
def convert_sft_to_dpo_with_sampling(
    sft_jsonl_path: str,
    output_jsonl_path: str,
    model_inference_fn,
    num_samples: int = 4,
    min_f1_difference: float = 0.1,
    temperature: float = 0.8
):
    """
    Build a DPO dataset from an SFT dataset by sampling several outputs per
    prompt and ranking them by file-level F1 against the reference output.

    Args:
        sft_jsonl_path: Path to the input SFT JSONL file ("input"/"output" keys).
        output_jsonl_path: Path the DPO JSONL pairs are written to.
        model_inference_fn: Callable (prompt, num_samples, temperature) -> List[str].
        num_samples: Number of candidate outputs to sample per prompt.
        min_f1_difference: Minimum F1 gap required to keep a pair.
        temperature: Sampling temperature forwarded to the model.
    """
    pairs_created = 0

    with open(sft_jsonl_path, 'r') as f_in, open(output_jsonl_path, 'w') as f_out:
        for raw_line in f_in:
            record = json.loads(raw_line)

            prompt = record.get("input", "")
            ground_truth = record.get("output", "")
            # Skip rows missing either side of the training example.
            if not (prompt and ground_truth):
                continue

            # Best-effort sampling: a failed generation skips this row
            # rather than aborting the whole conversion.
            try:
                candidates = model_inference_fn(prompt, num_samples, temperature)
            except Exception as e:
                print(f"Error generating outputs: {e}")
                continue

            new_pairs = create_dpo_pairs_from_generations(
                prompt, candidates, ground_truth, min_f1_difference
            )
            for pair in new_pairs:
                f_out.write(json.dumps(pair) + '\n')
            pairs_created += len(new_pairs)

    print(f"Created {pairs_created} DPO pairs from {sft_jsonl_path}")
|
|
|
|
|
|
|
|
def prepare_dpo_data_from_instruct(
    instruct_jsonl: str,
    output_dpo_jsonl: str,
):
    """
    Simple conversion from instruction data to DPO format.

    This is a demonstration stub: it only inspects the first record of the
    input file and explains what a real conversion needs. In practice you
    must (1) generate multiple outputs per input, (2) rank them by F1 score,
    and (3) create chosen/rejected pairs — see
    convert_sft_to_dpo_with_sampling() for the working pipeline.
    """
    print(f"Converting {instruct_jsonl} to DPO format...")
    print("Note: This requires generating multiple outputs per prompt.")
    print("Use convert_sft_to_dpo_with_sampling() with your model for actual conversion.")

    # Only the first record is examined, mirroring the demo intent.
    with open(instruct_jsonl, 'r') as handle:
        first_line = next(handle, None)

    if first_line is not None:
        record = json.loads(first_line)
        print(f"Input: {record.get('input', '')[:100]}...")
        print(f"Ground truth output available: {len(record.get('output', ''))} chars")
        print(" -> Need to generate multiple outputs and rank by F1 score")
        print()
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke-test demo: score two candidate outputs against a ground truth.
    print("F1 Score Utility for File-Level Ranking")
    print("=" * 50)

    # Selection lines use the documented format
    # action::file_path::kind::name. The previous demo omitted the leading
    # action field, so extract_files_from_selection picked up component
    # names (field 1) instead of file paths, contradicting the "file-level"
    # scoring the module documents.
    ground_truth = """
##OUTPUT
The webhook system requires subscription support.
##SELECT
modify::crates/common_enums/src/enums.rs::impl::EventClass
modify::crates/router/src/webhooks.rs::fn::process_webhook
<EOS>
"""

    # Identical to the ground truth: expect precision = recall = F1 = 1.0.
    prediction1 = """
##OUTPUT
The webhook system requires subscription support.
##SELECT
modify::crates/common_enums/src/enums.rs::impl::EventClass
modify::crates/router/src/webhooks.rs::fn::process_webhook
<EOS>
"""

    # One correct file and one wrong file: expect F1 = 0.5.
    prediction2 = """
##OUTPUT
The webhook system requires subscription support.
##SELECT
modify::crates/common_enums/src/enums.rs::impl::EventClass
modify::crates/router/src/handlers.rs::fn::handle_request
<EOS>
"""

    print("\nExample 1: Perfect match")
    metrics1 = compute_file_level_f1(prediction1, ground_truth)
    print(f"F1 Score: {metrics1['f1']:.3f}")
    print(f"Precision: {metrics1['precision']:.3f}, Recall: {metrics1['recall']:.3f}")

    print("\nExample 2: Partial match")
    metrics2 = compute_file_level_f1(prediction2, ground_truth)
    print(f"F1 Score: {metrics2['f1']:.3f}")
    print(f"Precision: {metrics2['precision']:.3f}, Recall: {metrics2['recall']:.3f}")

    print("\nExample 3: Ranking outputs")
    outputs = [prediction1, prediction2]
    ranked = rank_outputs_by_f1(outputs, ground_truth)
    print("Ranked outputs:")
    for i, (output, f1, metrics) in enumerate(ranked, 1):
        print(f"  {i}. F1={f1:.3f} - {metrics['true_positives']} correct files")
|
|