| """ |
| Utility for computing F1 scores at file level for ranking generated outputs. |
| This helps create preference pairs for DPO training. |
| """ |
|
|
| import json |
| import re |
| from typing import List, Set, Tuple, Dict |
| from pathlib import Path |
|
|
|
|
def extract_files_from_selection(output_text: str) -> Set[str]:
    """
    Extract file paths from the ##SELECT section of a model output.

    Two line formats are accepted:
      action-prefixed:  modify::crates/path/to/file.rs::impl::ComponentName
      bare path:        crates/path/to/file.rs::ComponentName

    Args:
        output_text: Full model output containing a ##SELECT section
            terminated by an <EOS> marker.

    Returns:
        Set of unique file paths; empty set when no ##SELECT ... <EOS>
        span is present.
    """
    files: Set[str] = set()

    # Capture everything between ##SELECT and <EOS>; DOTALL lets the
    # section span multiple lines, IGNORECASE tolerates casing drift.
    select_match = re.search(r'##SELECT\s*(.*?)<EOS>', output_text, re.DOTALL | re.IGNORECASE)
    if not select_match:
        return files

    select_section = select_match.group(1)

    for line in select_section.strip().split('\n'):
        line = line.strip()
        if not line:
            continue

        parts = line.split('::')
        if len(parts) >= 2:
            # An action-prefixed line puts the path in the second field; a
            # bare line puts it first.  Paths are recognizable by the '/'
            # separator, so this stays backward-compatible with the
            # `modify::path::...` format while also handling `path::Name`.
            file_path = parts[0] if '/' in parts[0] else parts[1]
            files.add(file_path)

    return files
|
|
|
|
def compute_file_level_f1(predicted: str, ground_truth: str) -> Dict[str, float]:
    """
    Compute precision/recall/F1 based on file-level predictions.

    Args:
        predicted: Model output with a ##SELECT section.
        ground_truth: Ground-truth output with a ##SELECT section.

    Returns:
        Dictionary with "precision", "recall", "f1", the raw counts
        ("true_positives", "false_positives", "false_negatives"), and the
        extracted file lists ("pred_files", "gt_files").  Every code path
        returns the same set of keys, so callers may read the count keys
        unconditionally.
    """
    pred_files = extract_files_from_selection(predicted)
    gt_files = extract_files_from_selection(ground_truth)

    true_positives = len(pred_files & gt_files)
    false_positives = len(pred_files - gt_files)
    false_negatives = len(gt_files - pred_files)

    if not gt_files:
        # No ground-truth files: recall is vacuously perfect; precision is
        # perfect only when the prediction is also empty.
        precision = 1.0 if not pred_files else 0.0
        recall = 1.0
        f1 = precision
    elif not pred_files:
        # Files were expected but none were predicted.
        precision = recall = f1 = 0.0
    else:
        # Both sets non-empty, so the denominators below are positive.
        precision = true_positives / (true_positives + false_positives)
        recall = true_positives / (true_positives + false_negatives)
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "true_positives": true_positives,
        "false_positives": false_positives,
        "false_negatives": false_negatives,
        "pred_files": list(pred_files),
        "gt_files": list(gt_files),
    }
|
|
|
|
def rank_outputs_by_f1(outputs: List[str], ground_truth: str) -> List[Tuple[str, float, Dict]]:
    """
    Order candidate outputs by file-level F1 against the ground truth.

    Args:
        outputs: Candidate model outputs to rank.
        ground_truth: Reference output used for scoring.

    Returns:
        List of (output, f1_score, metrics_dict) tuples, best F1 first.
        Ties keep the original input order (stable sort).
    """
    scored = [
        (candidate, compute_file_level_f1(candidate, ground_truth))
        for candidate in outputs
    ]
    return sorted(
        ((text, metrics["f1"], metrics) for text, metrics in scored),
        key=lambda entry: entry[1],
        reverse=True,
    )
|
|
|
|
def create_dpo_pairs_from_generations(
    prompt: str,
    generations: List[str],
    ground_truth: str,
    min_f1_difference: float = 0.1
) -> List[Dict[str, str]]:
    """
    Build DPO preference pairs from multiple generations for one prompt.

    Generations are ranked by file-level F1 against the ground truth, and
    every ordered pair whose F1 gap meets the threshold becomes one
    chosen/rejected training example.

    Args:
        prompt: Input prompt/task.
        generations: Candidate generated outputs.
        ground_truth: Reference output used for scoring.
        min_f1_difference: Minimum F1 gap required to emit a pair.

    Returns:
        List of dicts: {"prompt": str, "chosen": str, "rejected": str}
        plus "chosen_f1" / "rejected_f1" scores.
    """
    if len(generations) < 2:
        # A preference pair needs at least two candidates.
        return []

    ranked = rank_outputs_by_f1(generations, ground_truth)
    pairs: List[Dict[str, str]] = []

    for idx, (chosen_text, chosen_f1, _) in enumerate(ranked):
        # Every later entry scored no better than the current one.
        for rejected_text, rejected_f1, _ in ranked[idx + 1:]:
            if chosen_f1 - rejected_f1 < min_f1_difference:
                continue
            pairs.append({
                "prompt": prompt,
                "chosen": chosen_text,
                "rejected": rejected_text,
                "chosen_f1": chosen_f1,
                "rejected_f1": rejected_f1,
            })

    return pairs
|
|
|
|
def convert_sft_to_dpo_with_sampling(
    sft_jsonl_path: str,
    output_jsonl_path: str,
    model_inference_fn,
    num_samples: int = 4,
    min_f1_difference: float = 0.1,
    temperature: float = 0.8
):
    """
    Convert an SFT dataset to a DPO dataset by sampling multiple outputs
    per prompt and ranking them by file-level F1.

    Args:
        sft_jsonl_path: Path to the input SFT JSONL file; each record needs
            "input" and "output" fields.
        output_jsonl_path: Path to the output DPO JSONL file.
        model_inference_fn: Callable (prompt, num_samples, temperature) -> List[str].
        num_samples: Number of outputs to sample per prompt.
        min_f1_difference: Minimum F1 gap required to create a pair.
        temperature: Sampling temperature forwarded to the model.
    """
    pairs_created = 0

    with open(sft_jsonl_path, 'r', encoding='utf-8') as f_in, \
            open(output_jsonl_path, 'w', encoding='utf-8') as f_out:
        for line in f_in:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline, common in
                # JSONL files) — json.loads would raise on an empty string.
                continue
            data = json.loads(line)

            prompt = data.get("input", "")
            ground_truth = data.get("output", "")

            if not prompt or not ground_truth:
                continue

            # Sampling is best-effort: a failed generation skips this record
            # instead of aborting the whole conversion.
            try:
                generations = model_inference_fn(prompt, num_samples, temperature)
            except Exception as e:
                print(f"Error generating outputs: {e}")
                continue

            pairs = create_dpo_pairs_from_generations(
                prompt, generations, ground_truth, min_f1_difference
            )

            for pair in pairs:
                f_out.write(json.dumps(pair) + '\n')
                pairs_created += 1

    print(f"Created {pairs_created} DPO pairs from {sft_jsonl_path}")
|
|
|
|
def prepare_dpo_data_from_instruct(
    instruct_jsonl: str,
    output_dpo_jsonl: str,
):
    """
    Simple conversion from instruction data to DPO format.
    This assumes you already have multiple outputs per input or will generate them.

    For demonstration, this creates a basic structure. In practice, you need to:
    1. Generate multiple outputs for each input
    2. Rank them by F1 score
    3. Create chosen/rejected pairs
    """
    print(f"Converting {instruct_jsonl} to DPO format...")
    print("Note: This requires generating multiple outputs per prompt.")
    print("Use convert_sft_to_dpo_with_sampling() with your model for actual conversion.")

    # Peek at just the first record to illustrate what a conversion needs.
    with open(instruct_jsonl, 'r') as f:
        first_line = next(f, None)

    if first_line is not None:
        record = json.loads(first_line)
        print(f"Input: {record.get('input', '')[:100]}...")
        print(f"Ground truth output available: {len(record.get('output', ''))} chars")
        print(" -> Need to generate multiple outputs and rank by F1 score")
        print()
|
|
|
|
if __name__ == "__main__":
    # Demonstration of the file-level F1 utilities on hand-written samples.
    print("F1 Score Utility for File-Level Ranking")
    print("=" * 50)

    reference_output = """
##OUTPUT
The webhook system requires subscription support.
##SELECT
crates/common_enums/src/enums.rs::EventClass
crates/router/src/webhooks.rs::process_webhook
<EOS>
"""

    exact_prediction = """
##OUTPUT
The webhook system requires subscription support.
##SELECT
crates/common_enums/src/enums.rs::EventClass
crates/router/src/webhooks.rs::process_webhook
<EOS>
"""

    partial_prediction = """
##OUTPUT
The webhook system requires subscription support.
##SELECT
crates/common_enums/src/enums.rs::EventClass
crates/router/src/handlers.rs::handle_request
<EOS>
"""

    def _report(label: str, prediction: str) -> None:
        # Score one prediction against the reference and print the headline numbers.
        scores = compute_file_level_f1(prediction, reference_output)
        print(label)
        print(f"F1 Score: {scores['f1']:.3f}")
        print(f"Precision: {scores['precision']:.3f}, Recall: {scores['recall']:.3f}")

    _report("\nExample 1: Perfect match", exact_prediction)
    _report("\nExample 2: Partial match", partial_prediction)

    print("\nExample 3: Ranking outputs")
    leaderboard = rank_outputs_by_f1([exact_prediction, partial_prediction], reference_output)
    print("Ranked outputs:")
    for rank, (_, f1, scores) in enumerate(leaderboard, 1):
        print(f"  {rank}. F1={f1:.3f} - {scores['true_positives']} correct files")
|
|