Sarthak committed on
Commit
ea0b2a0
·
1 Parent(s): 0b74f1f

feat: created a cli to manage the complete generation process

Browse files
patches/model2vec.patch ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --- a/model2vec/train/base.py
2
+ +++ b/model2vec/train/base.py
3
+ @@ -35,7 +35,7 @@ class FinetunableStaticModel(nn.Module):
4
+ )
5
+ self.vectors = vectors.float()
6
+
7
+ - self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=False, padding_idx=pad_id)
8
+ + self.embeddings = nn.Embedding.from_pretrained(self.vectors.clone(), freeze=False, padding_idx=pad_id)
9
+ self.head = self.construct_head()
10
+ self.w = self.construct_weights()
11
+ self.tokenizer = tokenizer
12
+ --- a/model2vec/distill/distillation.py
13
+ +++ b/model2vec/distill/distillation.py
14
+ @@ -137,7 +137,10 @@ def distill_from_model(
15
+ # Get the language from the model card.
16
+ try:
17
+ info = model_info(model_name)
18
+ - language = info.cardData.get("language", None)
19
+ + if info is not None and hasattr(info, 'cardData') and info.cardData is not None:
20
+ + language = info.cardData.get("language", None)
21
+ + else:
22
+ + language = None
23
+ except RepositoryNotFoundError:
24
+ logger.info("No model info found for the model. Setting language to None.")
25
+ language = None
26
+ --- a/model2vec/distill/inference.py
27
+ +++ b/model2vec/distill/inference.py
28
+ @@ -109,5 +109,12 @@ def create_embeddings(
29
+ out_tokens.extend([Token(x, False) for x in tokens])
30
+ out_weights = np.stack(intermediate_weights)
31
+
32
+ + # Validate token-vector consistency to prevent failures
33
+ + if len(out_tokens) != out_weights.shape[0]:
34
+ + logger.warning(f"Token-vector mismatch: {len(out_tokens)} tokens vs {out_weights.shape[0]} vectors. Truncating to prevent failure.")
35
+ + min_count = min(len(out_tokens), out_weights.shape[0])
36
+ + out_tokens = out_tokens[:min_count]
37
+ + out_weights = out_weights[:min_count]
38
+ +
39
+ return out_tokens, out_weights
src/distiller/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """Model2Vec Distillation Pipeline for gte-Qwen2-7B-instruct."""
2
+
3
+ __version__ = "0.1.0"
4
+
5
+ from .distill import beam_code_distillation, code_specialized_distillation
6
+
7
+ __all__ = ["beam_code_distillation", "code_specialized_distillation"]
src/distiller/__main__.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Main entry point for the distiller package."""
2
+
3
+ import argparse
4
+ import sys
5
+
6
+
7
+ def main() -> None:
8
+ """Main entry point for the distiller package."""
9
+ parser = argparse.ArgumentParser(description="Model2Vec Code-Specialized Distillation Pipeline")
10
+ subparsers = parser.add_subparsers(dest="command", help="Available commands")
11
+
12
+ # Distillation command
13
+ distill_parser = subparsers.add_parser("distill", help="Run code-specialized model distillation")
14
+ distill_parser.add_argument("--model", default="Alibaba-NLP/gte-Qwen2-7B-instruct", help="Model to distill")
15
+ distill_parser.add_argument("--output-dir", default="gte_qwen2_m2v_code", help="Output directory")
16
+ distill_parser.add_argument("--pca-dims", type=int, default=512, help="PCA dimensions")
17
+ distill_parser.add_argument("--max-samples", type=int, default=50000, help="Max CodeSearchNet samples")
18
+ distill_parser.add_argument("--use-beam", action="store_true", help="Use Beam for cloud GPU distillation")
19
+
20
+ # Simplified distillation command
21
+ simple_parser = subparsers.add_parser("distill-simple", help="Run simplified Model2Vec distillation (local)")
22
+ simple_parser.add_argument(
23
+ "--teacher", default="sentence-transformers/all-MiniLM-L6-v2", help="Teacher model to distill from"
24
+ )
25
+ simple_parser.add_argument("--output-dir", default="gte_qwen2_m2v_code_simplified", help="Output directory")
26
+ simple_parser.add_argument("--pca-dims", type=int, default=256, help="PCA dimensions")
27
+
28
+ # CodeSearchNet evaluation command
29
+ evaluate_parser = subparsers.add_parser("evaluate", help="Run CodeSearchNet evaluation on all default models")
30
+ evaluate_parser.add_argument("--use-beam", action="store_true", help="Use Beam for cloud evaluation")
31
+
32
+ # CodeSearchNet evaluation command (simplified models only)
33
+ evaluate_simple_parser = subparsers.add_parser(
34
+ "evaluate-simple", help="Run CodeSearchNet evaluation on simplified models only"
35
+ )
36
+ evaluate_simple_parser.add_argument("--use-beam", action="store_true", help="Use Beam for cloud evaluation")
37
+
38
+ # Analysis command
39
+ analysis_parser = subparsers.add_parser("analyze", help="Generate CodeSearchNet analysis report")
40
+ analysis_parser.add_argument("--results-dir", default="code_evaluation_results", help="Results directory")
41
+ analysis_parser.add_argument("--results-file", help="Single results file to analyze")
42
+ analysis_parser.add_argument("--model-name", default="gte_qwen2_m2v_code", help="Model name for report")
43
+ analysis_parser.add_argument("--output", default="README.md", help="Output report file")
44
+ analysis_parser.add_argument("--export-csv", help="Export comparison results to CSV")
45
+ analysis_parser.add_argument("--use-beam", action="store_true", help="Use Beam for cloud analysis")
46
+
47
+ # Sync command
48
+ sync_parser = subparsers.add_parser("sync", help="Download files from Beam volume to local directory")
49
+ sync_parser.add_argument("--model-files", action="store_true", help="Download final model files")
50
+ sync_parser.add_argument(
51
+ "--analysis-files",
52
+ action="store_true",
53
+ help="Download analysis reports and charts",
54
+ )
55
+ sync_parser.add_argument("--all", action="store_true", help="Download all generated files")
56
+ sync_parser.add_argument("--output-dir", default=".", help="Local output directory")
57
+
58
+ # Benchmark command
59
+ benchmark_parser = subparsers.add_parser("benchmark", help="Run performance benchmarking on all default models")
60
+ benchmark_parser.add_argument("--use-beam", action="store_true", help="Use Beam for cloud benchmarking")
61
+
62
+ # Benchmark command (simplified models only)
63
+ benchmark_simple_parser = subparsers.add_parser(
64
+ "benchmark-simple", help="Run performance benchmarking on simplified models only"
65
+ )
66
+ benchmark_simple_parser.add_argument("--use-beam", action="store_true", help="Use Beam for cloud benchmarking")
67
+
68
+ args = parser.parse_args()
69
+
70
+ if args.command == "distill":
71
+ from .distill_simplified import run_local_distillation, beam_distill_all_teachers
72
+
73
+ if args.use_beam:
74
+ # Run on Beam
75
+ print("Running comprehensive teacher model distillation on Beam...")
76
+ results = beam_distill_all_teachers()
77
+ else:
78
+ # Run locally
79
+ print("Running comprehensive teacher model distillation locally...")
80
+ results = run_local_distillation()
81
+
82
+ print(f"✅ Distillation complete! Created {results['total_successful']} models")
83
+ print("📁 Models location: ./code_model2vec/final/")
84
+ print("\n✅ Created models:")
85
+ for model_name in results["successful_models"]:
86
+ model_info = results["all_results"][model_name]
87
+ print(f" • {model_name} (from {model_info['teacher_model']})")
88
+
89
+ elif args.command == "distill-simple":
90
+ from .distill_simplified import run_local_distillation
91
+
92
+ # Run simplified distillation for all teacher models locally
93
+ print("Running comprehensive teacher model distillation locally...")
94
+ results = run_local_distillation()
95
+ print(f"✅ Distillation complete! Created {results['total_successful']} models")
96
+ print("📁 Models location: ./code_model2vec/final/")
97
+ print("\n✅ Created models:")
98
+ for model_name in results["successful_models"]:
99
+ model_info = results["all_results"][model_name]
100
+ print(f" • {model_name} (from {model_info['teacher_model']})")
101
+
102
+ elif args.command == "evaluate":
103
+ from .evaluate import main as evaluate_main, run_local_evaluation
104
+
105
+ if args.use_beam:
106
+ # Run on Beam with all default models
107
+ print("Running comprehensive evaluation on Beam...")
108
+ evaluate_main()
109
+ else:
110
+ # Run locally with all default models
111
+ print("Running comprehensive evaluation locally...")
112
+ run_local_evaluation()
113
+
114
+ elif args.command == "evaluate-simple":
115
+ from .evaluate import evaluate_simplified_only, run_local_evaluation_simplified
116
+
117
+ if args.use_beam:
118
+ # Run on Beam with simplified models only
119
+ print("Running simplified model evaluation on Beam...")
120
+ evaluate_simplified_only()
121
+ else:
122
+ # Run locally with simplified models only
123
+ print("Running simplified model evaluation locally...")
124
+ run_local_evaluation_simplified()
125
+
126
+ elif args.command == "analyze":
127
+ from .analyze import main as analyze_main
128
+
129
+ # Run locally - Override sys.argv to pass arguments to the analyze script
130
+ sys.argv = ["analyze.py"]
131
+ if args.results_dir != "code_evaluation_results":
132
+ sys.argv.extend(["--results-dir", args.results_dir])
133
+ if args.results_file:
134
+ sys.argv.extend(["--results-file", args.results_file])
135
+ if args.model_name != "gte_qwen2_m2v_code":
136
+ sys.argv.extend(["--model-name", args.model_name])
137
+ if args.output != "README.md":
138
+ sys.argv.extend(["--output", args.output])
139
+ if args.export_csv:
140
+ sys.argv.extend(["--export-csv", args.export_csv])
141
+ analyze_main()
142
+
143
+ elif args.command == "sync":
144
+ from .sync import sync_files
145
+
146
+ # Run locally
147
+ sync_files(
148
+ model_files=args.model_files,
149
+ analysis_files=args.analysis_files,
150
+ all_files=args.all,
151
+ output_dir=args.output_dir,
152
+ )
153
+
154
+ elif args.command == "benchmark":
155
+ from .benchmark import main as benchmark_main, run_local_benchmark
156
+
157
+ if args.use_beam:
158
+ # Run on Beam with all default models
159
+ print("Running comprehensive benchmarking on Beam...")
160
+ benchmark_main()
161
+ else:
162
+ # Run locally with all default models
163
+ print("Running comprehensive benchmarking locally...")
164
+ run_local_benchmark()
165
+
166
+ elif args.command == "benchmark-simple":
167
+ from .benchmark import benchmark_simplified_only, run_local_benchmark_simplified
168
+
169
+ if args.use_beam:
170
+ # Run on Beam with simplified models only
171
+ print("Running simplified model benchmarking on Beam...")
172
+ benchmark_simplified_only()
173
+ else:
174
+ # Run locally with simplified models only
175
+ print("Running simplified model benchmarking locally...")
176
+ run_local_benchmark_simplified()
177
+
178
+ else:
179
+ parser.print_help()
180
+
181
+
182
+ if __name__ == "__main__":
183
+ main()
src/distiller/analyze.py ADDED
@@ -0,0 +1,1495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Comprehensive CodeSearchNet Analysis and Reporting Script.
3
+
4
+ This script provides a complete CodeSearchNet evaluation pipeline that includes:
5
+ 1. Model evaluation results analysis
6
+ 2. Peer model comparison analysis
7
+ 3. Advanced visualizations and charts
8
+ 4. Leaderboard comparison and ranking analysis
9
+ 5. Comprehensive README report generation
10
+ 6. Performance efficiency analysis
11
+ 7. Language-specific performance analysis
12
+
13
+ Features:
14
+ - CodeSearchNet-style scoring (NDCG@10, MRR, Recall metrics)
15
+ - Comparison with peer code-specialized models
16
+ - Model efficiency metrics (performance per parameter)
17
+ - Interactive visualizations with Plotly and Matplotlib
18
+ - Professional charts for README integration
19
+ - Statistical analysis of results across programming languages
20
+
21
+ Usage:
22
+ python analyze.py --results-dir results/ --model-name my_model
23
+ distiller analyze --results-dir evaluation_results
24
+ """
25
+
26
+ import argparse
27
+ import json
28
+ import logging
29
+ import time
30
+ from pathlib import Path
31
+ from typing import Any
32
+
33
+ import matplotlib.pyplot as plt
34
+ import numpy as np
35
+ import pandas as pd
36
+ import seaborn as sns
37
+
38
+ # Optional Plotly import with fallback
39
+ PLOTLY_AVAILABLE = True
40
+ try:
41
+ import plotly.graph_objects as go
42
+ except ImportError:
43
+ PLOTLY_AVAILABLE = False
44
+
45
+ # Set plotting style
46
+ try:
47
+ plt.style.use("seaborn-v0_8")
48
+ except OSError:
49
+ plt.style.use("seaborn") # Fallback for older matplotlib versions
50
+ sns.set_palette("husl")
51
+
52
+ # =============================================================================
53
+ # CONFIGURATION
54
+ # =============================================================================
55
+
56
+ # Constants
57
+ MIN_SCORES_FOR_STATS = 2
58
+ HIGH_PERFORMANCE_THRESHOLD = 0.3
59
+ MEDIUM_PERFORMANCE_THRESHOLD = 0.2
60
+
61
+ # Model Configuration
62
+ MODEL_NAME = "code_model2vec_analysis" # Generic name for multi-model analysis
63
+ ORIGINAL_MODEL_NAME = "Alibaba-NLP/gte-Qwen2-7B-instruct"
64
+ OUTPUT_DIR = Path("analysis_results")
65
+ IMAGES_DIR = Path("analysis_charts")
66
+ REPORT_FILE = Path("REPORT.md") # Changed from README.md
67
+
68
+ # Local directories for results - updated for new structure
69
+ DEFAULT_EVALUATION_DIR = "code_model2vec/evaluation_results"
70
+ DEFAULT_BENCHMARK_DIR = "code_model2vec/benchmark_results"
71
+
72
+ # CodeSearchNet Languages
73
+ CODE_LANGUAGES = ["python", "javascript", "java", "php", "ruby", "go"]
74
+
75
+ # Model name mapping from the default models in evaluate.py and benchmark.py
76
+ MODEL_NAME_MAPPING = {
77
+ # File names to display names
78
+ "gte_qwen2_m2v_code": "gte_qwen2_m2v_code (Ours)",
79
+ "all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
80
+ "codebert-base": "microsoft/codebert-base",
81
+ "graphcodebert-base": "microsoft/graphcodebert-base",
82
+ "CodeBERTa-small-v1": "huggingface/CodeBERTa-small-v1",
83
+ "all-mpnet-base-v2": "sentence-transformers/all-mpnet-base-v2",
84
+ "all-MiniLM-L12-v2": "sentence-transformers/all-MiniLM-L12-v2",
85
+ "potion-base-8M": "minishlab/potion-base-8M",
86
+ "potion-retrieval-32M": "minishlab/potion-retrieval-32M",
87
+ "codet5-base": "Salesforce/codet5-base",
88
+ }
89
+
90
+ # Reverse mapping for lookups
91
+ DISPLAY_NAME_TO_FILE = {v: k for k, v in MODEL_NAME_MAPPING.items()}
92
+
93
+ # Peer models for comparison (code-specialized models)
94
+ PEER_MODELS = {
95
+ "sentence-transformers/all-MiniLM-L6-v2": {"overall_ndcg": 0.25, "type": "General"},
96
+ "microsoft/codebert-base": {"overall_ndcg": 0.32, "type": "Code-Specific"},
97
+ "microsoft/graphcodebert-base": {"overall_ndcg": 0.35, "type": "Code-Specific"},
98
+ "huggingface/CodeBERTa-small-v1": {"overall_ndcg": 0.28, "type": "Code-Specific"},
99
+ "sentence-transformers/all-mpnet-base-v2": {"overall_ndcg": 0.27, "type": "General"},
100
+ }
101
+
102
+ # Model specifications for efficiency analysis
103
+ MODEL_SPECS = {
104
+ "sentence-transformers/all-MiniLM-L6-v2": {"parameters": 22.7, "size_mb": 90},
105
+ "microsoft/codebert-base": {"parameters": 125.0, "size_mb": 500},
106
+ "microsoft/graphcodebert-base": {"parameters": 125.0, "size_mb": 500},
107
+ "huggingface/CodeBERTa-small-v1": {"parameters": 84.0, "size_mb": 340},
108
+ "sentence-transformers/all-mpnet-base-v2": {"parameters": 109.0, "size_mb": 440},
109
+ "Alibaba-NLP/gte-Qwen2-7B-instruct": {"parameters": 7000.0, "size_mb": 13000},
110
+ }
111
+
112
+ # Distilled model specifications
113
+ DISTILLED_MODEL_SPECS = {
114
+ "parameters": 39.0, # Model2Vec parameters
115
+ "size_mb": 149.0, # Actual model size
116
+ "dimensions": 256, # Model2Vec dimensions
117
+ "original_dimensions": 3584,
118
+ "distillation_method": "Model2Vec",
119
+ "training_dataset": "CodeSearchNet",
120
+ }
121
+
122
+ # =============================================================================
123
+ # UTILITY FUNCTIONS
124
+ # =============================================================================
125
+
126
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
127
+ logger = logging.getLogger(__name__)
128
+
129
+
130
+ def setup_directories(base_path: Path | None = None) -> tuple[Path, Path, Path]:
131
+ """Create necessary directories and return their paths."""
132
+ if base_path:
133
+ output_dir = base_path / "analysis_results"
134
+ images_dir = base_path / "analysis_results" / "charts"
135
+ reports_dir = base_path / "analysis_results" / "reports"
136
+ else:
137
+ output_dir = OUTPUT_DIR
138
+ images_dir = IMAGES_DIR
139
+ reports_dir = OUTPUT_DIR / "reports"
140
+
141
+ output_dir.mkdir(parents=True, exist_ok=True)
142
+ images_dir.mkdir(parents=True, exist_ok=True)
143
+ reports_dir.mkdir(parents=True, exist_ok=True)
144
+
145
+ return output_dir, images_dir, reports_dir
146
+
147
+
148
+ def extract_model_name_from_filename(filename: str) -> str:
149
+ """Extract and map model name from filename."""
150
+ # Remove prefixes and extensions
151
+ name = filename.replace("codesearchnet_eval_", "").replace("benchmark_", "").replace(".json", "")
152
+
153
+ # Check if it's in our mapping
154
+ if name in MODEL_NAME_MAPPING:
155
+ return MODEL_NAME_MAPPING[name]
156
+
157
+ # Try to find partial matches
158
+ for file_key, display_name in MODEL_NAME_MAPPING.items():
159
+ if file_key in name or name in file_key:
160
+ return display_name
161
+
162
+ # If no mapping found, return the cleaned name
163
+ return name
164
+
165
+
166
+ class CodeSearchNetAnalyzer:
167
+ """Analyzer for CodeSearchNet evaluation results and performance benchmarks."""
168
+
169
+ def __init__(
170
+ self,
171
+ results_dir: str | None = None,
172
+ benchmark_dir: str | None = None,
173
+ images_dir: Path | None = None,
174
+ ) -> None:
175
+ """Initialize analyzer with results directories."""
176
+ self.results_dir = Path(results_dir) if results_dir else Path(DEFAULT_EVALUATION_DIR)
177
+ self.benchmark_dir = Path(benchmark_dir) if benchmark_dir else Path(DEFAULT_BENCHMARK_DIR)
178
+ self.images_dir = images_dir or IMAGES_DIR
179
+ self.results: list[dict[str, Any]] = []
180
+ self.benchmark_results: list[dict[str, Any]] = []
181
+ self.comparison_df: pd.DataFrame | None = None
182
+ self.benchmark_df: pd.DataFrame | None = None
183
+
184
+ def load_benchmark_results(self) -> None:
185
+ """Load benchmark results from local directory."""
186
+ logger.info("📊 Loading benchmark results...")
187
+
188
+ if not self.benchmark_dir.exists():
189
+ logger.warning(f"Benchmark directory not found: {self.benchmark_dir}")
190
+ return
191
+
192
+ logger.info(f"🔍 Searching for benchmark files in: {self.benchmark_dir}")
193
+ benchmark_files = list(self.benchmark_dir.glob("benchmark_*.json"))
194
+ logger.info(f"📁 Found {len(benchmark_files)} benchmark files")
195
+
196
+ for benchmark_file_path in benchmark_files:
197
+ try:
198
+ logger.info(f"📖 Loading: {benchmark_file_path.name}")
199
+ with benchmark_file_path.open() as f:
200
+ data = json.load(f)
201
+ if data is not None:
202
+ # Update model name with proper mapping
203
+ original_name = data.get("model_name", "Unknown")
204
+ mapped_name = extract_model_name_from_filename(benchmark_file_path.stem)
205
+ data["model_name"] = mapped_name
206
+ data["original_model_name"] = original_name
207
+
208
+ self.benchmark_results.append(data)
209
+ logger.info(f"✅ Successfully loaded: {mapped_name}")
210
+ except (json.JSONDecodeError, KeyError) as e:
211
+ logger.warning(f"❌ Failed to load {benchmark_file_path}: {e}")
212
+
213
+ logger.info(f"📊 Total benchmark results loaded: {len(self.benchmark_results)}")
214
+ if self.benchmark_results:
215
+ model_names = [r.get("model_name", "Unknown") for r in self.benchmark_results]
216
+ logger.info(f"🎯 Benchmark models found: {', '.join(model_names)}")
217
+
218
+ self._create_benchmark_dataframe()
219
+
220
+ def _create_benchmark_dataframe(self) -> None:
221
+ """Create benchmark comparison DataFrame from results."""
222
+ if not self.benchmark_results:
223
+ return
224
+
225
+ benchmark_data = []
226
+ for result in self.benchmark_results:
227
+ model_name = result.get("model_name", "Unknown")
228
+ size_metrics = result.get("size_metrics", {})
229
+ speed_benchmarks = result.get("speed_benchmarks", {})
230
+ memory_benchmarks = result.get("memory_benchmarks", {})
231
+ cpu_vs_gpu = result.get("cpu_vs_gpu", {})
232
+
233
+ # Extract key metrics
234
+ row = {
235
+ "Model": model_name,
236
+ "Disk_Size_MB": size_metrics.get("disk_size_mb", 0),
237
+ "Parameters_M": size_metrics.get("parameters_millions", 0),
238
+ "Embedding_Dim": size_metrics.get("embedding_dim", 0),
239
+ "RAM_Usage_MB": size_metrics.get("ram_usage_mb", 0),
240
+ "GPU_Memory_MB": size_metrics.get("gpu_memory_mb", 0),
241
+ }
242
+
243
+ # Speed metrics (medium texts, batch 32)
244
+ if "medium" in speed_benchmarks and "batch_32" in speed_benchmarks["medium"]:
245
+ batch_32 = speed_benchmarks["medium"]["batch_32"]
246
+ row.update(
247
+ {
248
+ "Throughput_TextsPerSec": batch_32.get("texts_per_second", 0),
249
+ "Latency_MsPerText": batch_32.get("time_per_text_ms", 0),
250
+ "TokenSpeed_TokensPerSec": batch_32.get("tokens_per_second", 0),
251
+ }
252
+ )
253
+
254
+ # Memory scaling (batch 32)
255
+ if "batch_32" in memory_benchmarks:
256
+ batch_32_mem = memory_benchmarks["batch_32"]
257
+ if not batch_32_mem.get("oom", False) and "error" not in batch_32_mem:
258
+ row.update(
259
+ {
260
+ "Memory_Used_MB": batch_32_mem.get("memory_used_mb", 0),
261
+ "Memory_Per_Text_MB": batch_32_mem.get("memory_per_text_mb", 0),
262
+ }
263
+ )
264
+
265
+ # CPU vs GPU comparison
266
+ for device in ["cpu", "cuda"]:
267
+ if device in cpu_vs_gpu and "error" not in cpu_vs_gpu[device]:
268
+ device_key = f"{device.upper()}_TextsPerSec"
269
+ row[device_key] = cpu_vs_gpu[device].get("texts_per_second", 0)
270
+
271
+ benchmark_data.append(row)
272
+
273
+ self.benchmark_df = pd.DataFrame(benchmark_data)
274
+
275
+ def load_results(self) -> None:
276
+ """Load evaluation results from local directory."""
277
+ logger.info("🔍 Loading evaluation results...")
278
+
279
+ if not self.results_dir.exists():
280
+ logger.warning(f"Evaluation directory not found: {self.results_dir}")
281
+ return
282
+
283
+ logger.info(f"🔍 Searching for evaluation files in: {self.results_dir}")
284
+ json_files = list(self.results_dir.glob("codesearchnet_eval_*.json"))
285
+ logger.info(f"📁 Found {len(json_files)} evaluation files")
286
+
287
+ for json_file in json_files:
288
+ try:
289
+ logger.info(f"📖 Loading: {json_file.name}")
290
+ with json_file.open() as f:
291
+ data = json.load(f)
292
+ if data is not None:
293
+ # Update model name with proper mapping
294
+ original_name = data.get("model_name", "Unknown")
295
+ mapped_name = extract_model_name_from_filename(json_file.stem)
296
+ data["model_name"] = mapped_name
297
+ data["original_model_name"] = original_name
298
+
299
+ self.results.append(data)
300
+ logger.info(f"✅ Successfully loaded: {mapped_name}")
301
+ except (json.JSONDecodeError, KeyError) as e:
302
+ logger.warning(f"❌ Failed to load {json_file}: {e}")
303
+
304
+ logger.info(f"📊 Total loaded: {len(self.results)} model results")
305
+ if self.results:
306
+ model_names = [r.get("model_name", "Unknown") for r in self.results]
307
+ logger.info(f"🎯 Models found: {', '.join(model_names)}")
308
+
309
+ self._create_comparison_dataframe()
310
+
311
+ # Also load benchmark results
312
+ self.load_benchmark_results()
313
+
314
+ def _create_comparison_dataframe(self) -> None:
315
+ """Create comparison DataFrame from results."""
316
+ if not self.results:
317
+ return
318
+
319
+ comparison_data = []
320
+ for result in self.results:
321
+ overall = result.get("overall", {})
322
+ row = {
323
+ "Model": result["model_name"],
324
+ "MRR": overall.get("mrr", 0),
325
+ "NDCG@1": overall.get("ndcg@1", 0),
326
+ "NDCG@5": overall.get("ndcg@5", 0),
327
+ "NDCG@10": overall.get("ndcg@10", 0),
328
+ "Recall@1": overall.get("recall@1", 0),
329
+ "Recall@5": overall.get("recall@5", 0),
330
+ "Recall@10": overall.get("recall@10", 0),
331
+ "Mean_Rank": overall.get("mean_rank", 0),
332
+ "Median_Rank": overall.get("median_rank", 0),
333
+ }
334
+ comparison_data.append(row)
335
+
336
+ self.comparison_df = pd.DataFrame(comparison_data)
337
+ if not self.comparison_df.empty:
338
+ self.comparison_df = self.comparison_df.sort_values("NDCG@10", ascending=False)
339
+
340
+ def print_summary(self) -> None:
341
+ """Print summary of results."""
342
+ if not self.results:
343
+ logger.warning("No results to summarize")
344
+ return
345
+
346
+ print(f"\n{'=' * 60}")
347
+ print("CodeSearchNet Evaluation Summary")
348
+ print(f"{'=' * 60}")
349
+ print(f"Total models evaluated: {len(self.results)}")
350
+
351
+ if self.comparison_df is not None and not self.comparison_df.empty:
352
+ print(f"\nTop performing model: {self.comparison_df.iloc[0]['Model']}")
353
+ print(f"Best NDCG@10: {self.comparison_df.iloc[0]['NDCG@10']:.4f}")
354
+ print(f"Best MRR: {self.comparison_df['MRR'].max():.4f}")
355
+
356
+ print(f"\nEvaluated languages: {', '.join(CODE_LANGUAGES)}")
357
+
358
+ # Also print benchmark summary if available
359
+ if self.benchmark_results:
360
+ print(f"\n{'=' * 60}")
361
+ print("Performance Benchmark Summary")
362
+ print(f"{'=' * 60}")
363
+ print(f"Total models benchmarked: {len(self.benchmark_results)}")
364
+
365
+ if self.benchmark_df is not None and not self.benchmark_df.empty:
366
+ # Safely get fastest and smallest models
367
+ fastest_model = "N/A"
368
+ smallest_model = "N/A"
369
+
370
+ if "Throughput_TextsPerSec" in self.benchmark_df.columns:
371
+ fastest_idx = self.benchmark_df["Throughput_TextsPerSec"].idxmax()
372
+ fastest_model = str(self.benchmark_df.loc[fastest_idx, "Model"])
373
+
374
+ if "Disk_Size_MB" in self.benchmark_df.columns:
375
+ smallest_idx = self.benchmark_df["Disk_Size_MB"].idxmin()
376
+ smallest_model = str(self.benchmark_df.loc[smallest_idx, "Model"])
377
+
378
+ print(f"\nFastest model: {fastest_model}")
379
+ print(f"Smallest model: {smallest_model}")
380
+
381
+ def analyze_language_performance(self) -> None:
382
+ """Analyze performance across programming languages."""
383
+ if not self.results:
384
+ return
385
+
386
+ print(f"\n{'=' * 60}")
387
+ print("Language-Specific Performance Analysis")
388
+ print(f"{'=' * 60}")
389
+
390
+ for result in self.results:
391
+ model_name = result["model_name"]
392
+ print(f"\nModel: {model_name}")
393
+ print("-" * 40)
394
+
395
+ languages = result.get("languages", {})
396
+ lang_data = []
397
+
398
+ for lang, lang_results in languages.items():
399
+ metrics = lang_results.get("metrics", {})
400
+ lang_data.append(
401
+ {
402
+ "Language": lang,
403
+ "NDCG@10": metrics.get("ndcg@10", 0),
404
+ "MRR": metrics.get("mrr", 0),
405
+ "Recall@5": metrics.get("recall@5", 0),
406
+ "Queries": lang_results.get("num_queries", 0),
407
+ }
408
+ )
409
+
410
+ if lang_data:
411
+ lang_df = pd.DataFrame(lang_data)
412
+ print(lang_df.to_string(index=False, float_format="%.4f"))
413
+ print(f"\nBest language: {lang_df.loc[lang_df['NDCG@10'].idxmax(), 'Language']}")
414
+ print(f"Average NDCG@10: {lang_df['NDCG@10'].mean():.4f}")
415
+ print(f"Average queries per language: {lang_df['Queries'].mean():.0f}")
416
+
417
    def analyze_benchmark_performance(self) -> None:
        """Print a human-readable summary of the collected benchmark results.

        Iterates over ``self.benchmark_results`` (one dict per model) and
        prints, for each model, its size, speed, CPU-vs-GPU and memory
        metrics when present.  Missing sections are silently skipped; a
        warning is logged when there are no results at all.
        """
        if not self.benchmark_results:
            logger.warning("No benchmark results to analyze")
            return

        print(f"\n{'=' * 60}")
        print("Performance Benchmark Analysis")
        print(f"{'=' * 60}")

        for result in self.benchmark_results:
            model_name = result.get("model_name", "Unknown")
            print(f"\nModel: {model_name}")
            print("-" * 40)

            # Size metrics
            size_metrics = result.get("size_metrics", {})
            if size_metrics:
                print("📏 Model Size:")
                print(f"  Disk Size: {size_metrics.get('disk_size_mb', 0):.1f} MB")
                if "parameters_millions" in size_metrics:
                    print(f"  Parameters: {size_metrics['parameters_millions']:.1f}M")
                if "embedding_dim" in size_metrics:
                    print(f"  Embedding Dimension: {size_metrics['embedding_dim']}")

            # Speed metrics — only the canonical batch-32 / medium-text
            # configuration is reported, to keep the summary short.
            speed_benchmarks = result.get("speed_benchmarks", {})
            if "medium" in speed_benchmarks and "batch_32" in speed_benchmarks["medium"]:
                batch_32 = speed_benchmarks["medium"]["batch_32"]
                print("⚡ Performance (Batch 32, Medium Texts):")
                print(f"  Throughput: {batch_32.get('texts_per_second', 0):.1f} texts/sec")
                print(f"  Latency: {batch_32.get('time_per_text_ms', 0):.1f} ms/text")
                print(f"  Token Speed: {batch_32.get('tokens_per_second', 0):.0f} tokens/sec")

            # CPU vs GPU — device entries that recorded an "error" are skipped.
            cpu_vs_gpu = result.get("cpu_vs_gpu", {})
            if cpu_vs_gpu:
                print("🖥️ CPU vs GPU:")
                for device, metrics in cpu_vs_gpu.items():
                    if "error" not in metrics:
                        print(f"  {device.upper()}: {metrics.get('texts_per_second', 0):.1f} texts/sec")

            # Memory efficiency — only shown when the batch-32 run neither
            # went out-of-memory nor errored.
            memory_benchmarks = result.get("memory_benchmarks", {})
            if "batch_32" in memory_benchmarks:
                batch_32_mem = memory_benchmarks["batch_32"]
                if not batch_32_mem.get("oom", False) and "error" not in batch_32_mem:
                    print("💾 Memory Usage (Batch 32):")
                    print(f"  Total: {batch_32_mem.get('memory_used_mb', 0):.1f} MB")
                    print(f"  Per Text: {batch_32_mem.get('memory_per_text_mb', 0):.2f} MB")
467
+
468
+ def create_performance_radar_chart(self, model_name: str, language_scores: dict[str, float]) -> str:
469
+ """Create radar chart showing performance across languages."""
470
+ if not PLOTLY_AVAILABLE:
471
+ logger.warning("Plotly not available, skipping radar chart")
472
+ return ""
473
+
474
+ languages = list(language_scores.keys())
475
+ scores = list(language_scores.values())
476
+
477
+ if not languages:
478
+ return ""
479
+
480
+ # Close the radar chart
481
+ languages_closed = [*languages, languages[0]]
482
+ scores_closed = [*scores, scores[0]]
483
+
484
+ fig = go.Figure()
485
+
486
+ fig.add_trace(
487
+ go.Scatterpolar(
488
+ r=scores_closed,
489
+ theta=languages_closed,
490
+ fill="toself",
491
+ name=model_name,
492
+ line_color="rgb(67, 147, 195)",
493
+ fillcolor="rgba(67, 147, 195, 0.3)",
494
+ )
495
+ )
496
+
497
+ fig.update_layout(
498
+ polar={"radialaxis": {"visible": True, "range": [0, max(scores) * 1.1]}},
499
+ showlegend=True,
500
+ title=f"CodeSearchNet Performance by Language: {model_name}",
501
+ width=800,
502
+ height=600,
503
+ )
504
+
505
+ static_path = self.images_dir / "code_performance_radar.png"
506
+ try:
507
+ fig.write_image(str(static_path), width=800, height=600, scale=2)
508
+ return str(static_path)
509
+ except Exception as e:
510
+ logger.warning(f"Could not create static image: {e}")
511
+ return ""
512
+
513
    def create_comparative_radar_chart(self, simplified_models: list, peer_models: list) -> str:
        """Create comparative radar chart between best distilled model and top peer models.

        Overlays the best simplified (distilled) model — solid, filled trace —
        with the top three peer models (dashed traces), ranked by overall
        NDCG@10, on a single polar chart.  Returns the written PNG path, or
        "" when plotly is unavailable, there are no simplified models, or the
        static image export fails.
        """
        if not PLOTLY_AVAILABLE:
            logger.warning("Plotly not available, skipping comparative radar chart")
            return ""

        if not simplified_models:
            return ""

        # Get the best simplified model
        best_simplified = max(simplified_models, key=lambda x: x.get("overall", {}).get("ndcg@10", 0))

        # Get top 3 peer models by performance
        peer_models_sorted = sorted(peer_models, key=lambda x: x.get("overall", {}).get("ndcg@10", 0), reverse=True)
        top_peers = peer_models_sorted[:3]

        # Index 0 is always the best simplified model; it gets the emphasized styling.
        models_to_compare = [best_simplified, *top_peers]

        fig = go.Figure()

        # Define colors for each model
        colors = ["rgb(255, 99, 132)", "rgb(54, 162, 235)", "rgb(255, 205, 86)", "rgb(75, 192, 192)"]

        for i, model_result in enumerate(models_to_compare):
            model_name = model_result["model_name"]
            languages = model_result.get("languages", {})

            # Calculate per-language NDCG@10 scores (title-cased axis labels)
            language_scores = {}
            for lang, lang_data in languages.items():
                metrics = lang_data.get("metrics", {})
                language_scores[lang.title()] = metrics.get("ndcg@10", 0)

            if language_scores:
                languages_list = list(language_scores.keys())
                scores_list = list(language_scores.values())

                # Close the radar chart by repeating the first vertex
                languages_closed = [*languages_list, languages_list[0]]
                scores_closed = [*scores_list, scores_list[0]]

                # Determine line style - solid for best distilled, dash for peers
                line_dash = "solid" if i == 0 else "dash"
                line_width = 3 if i == 0 else 2

                fig.add_trace(
                    go.Scatterpolar(
                        r=scores_closed,
                        theta=languages_closed,
                        fill="toself" if i == 0 else "none",
                        name=model_name,
                        line={"color": colors[i % len(colors)], "dash": line_dash, "width": line_width},
                        # Slice trick: '[3:-1]' turns "rgb(r, g, b)" into "(r, g, b",
                        # which is then rebuilt as "rgba(r, g, b, 0.2)".
                        fillcolor=f"rgba{colors[i % len(colors)][3:-1]}, 0.2)" if i == 0 else None,
                    )
                )

        fig.update_layout(
            polar={"radialaxis": {"visible": True, "range": [0, 0.5]}},  # Adjust max range as needed
            showlegend=True,
            title="Model Comparison: Best Distilled vs Top Peer Models",
            width=900,
            height=700,
        )

        static_path = self.images_dir / "comparative_radar.png"
        try:
            fig.write_image(str(static_path), width=900, height=700, scale=2)
            return str(static_path)
        except Exception as e:
            logger.warning(f"Could not create comparative radar chart: {e}")
            return ""
584
+
585
+ def create_individual_radar_charts(self, simplified_models: list) -> dict[str, str]:
586
+ """Create individual radar charts for all simplified models."""
587
+ radar_charts = {}
588
+
589
+ for result in simplified_models:
590
+ model_name = result["model_name"]
591
+ model_languages = result.get("languages", {})
592
+ model_language_scores = {}
593
+ for lang, lang_data in model_languages.items():
594
+ metrics = lang_data.get("metrics", {})
595
+ model_language_scores[lang.title()] = metrics.get("ndcg@10", 0)
596
+
597
+ if model_language_scores:
598
+ # Create unique filename for each model
599
+ safe_model_name = "".join(c for c in model_name if c.isalnum() or c in ("-", "_")).rstrip()
600
+ radar_chart_path = self.create_performance_radar_chart_individual(
601
+ model_name, model_language_scores, safe_model_name
602
+ )
603
+ if radar_chart_path:
604
+ radar_charts[model_name] = radar_chart_path
605
+
606
+ return radar_charts
607
+
608
+ def create_performance_radar_chart_individual(
609
+ self, model_name: str, language_scores: dict[str, float], filename_suffix: str
610
+ ) -> str:
611
+ """Create radar chart for individual model with unique filename."""
612
+ if not PLOTLY_AVAILABLE:
613
+ logger.warning("Plotly not available, skipping radar chart")
614
+ return ""
615
+
616
+ languages = list(language_scores.keys())
617
+ scores = list(language_scores.values())
618
+
619
+ if not languages:
620
+ return ""
621
+
622
+ # Close the radar chart
623
+ languages_closed = [*languages, languages[0]]
624
+ scores_closed = [*scores, scores[0]]
625
+
626
+ fig = go.Figure()
627
+
628
+ fig.add_trace(
629
+ go.Scatterpolar(
630
+ r=scores_closed,
631
+ theta=languages_closed,
632
+ fill="toself",
633
+ name=model_name,
634
+ line_color="rgb(67, 147, 195)",
635
+ fillcolor="rgba(67, 147, 195, 0.3)",
636
+ )
637
+ )
638
+
639
+ fig.update_layout(
640
+ polar={"radialaxis": {"visible": True, "range": [0, max(scores) * 1.1]}},
641
+ showlegend=True,
642
+ title=f"CodeSearchNet Performance by Language: {model_name}",
643
+ width=800,
644
+ height=600,
645
+ )
646
+
647
+ static_path = self.images_dir / f"radar_{filename_suffix}.png"
648
+ try:
649
+ fig.write_image(str(static_path), width=800, height=600, scale=2)
650
+ return str(static_path)
651
+ except Exception as e:
652
+ logger.warning(f"Could not create static image for {model_name}: {e}")
653
+ return ""
654
+
655
+ def plot_model_comparison(self, save_path: str | None = None) -> str:
656
+ """Create comparison plots for models."""
657
+ if self.comparison_df is None or self.comparison_df.empty:
658
+ logger.warning("No comparison data available for plotting")
659
+ return ""
660
+
661
+ fig, axes = plt.subplots(2, 2, figsize=(15, 12))
662
+ fig.suptitle("CodeSearchNet Model Comparison", fontsize=16, fontweight="bold")
663
+
664
+ # NDCG@10 comparison
665
+ axes[0, 0].barh(self.comparison_df["Model"], self.comparison_df["NDCG@10"])
666
+ axes[0, 0].set_title("NDCG@10 Comparison")
667
+ axes[0, 0].set_xlabel("NDCG@10")
668
+
669
+ # MRR comparison
670
+ axes[0, 1].barh(self.comparison_df["Model"], self.comparison_df["MRR"])
671
+ axes[0, 1].set_title("Mean Reciprocal Rank (MRR)")
672
+ axes[0, 1].set_xlabel("MRR")
673
+
674
+ # Recall@5 comparison
675
+ axes[1, 0].barh(self.comparison_df["Model"], self.comparison_df["Recall@5"])
676
+ axes[1, 0].set_title("Recall@5")
677
+ axes[1, 0].set_xlabel("Recall@5")
678
+
679
+ # Mean Rank comparison (lower is better)
680
+ axes[1, 1].barh(self.comparison_df["Model"], self.comparison_df["Mean_Rank"])
681
+ axes[1, 1].set_title("Mean Rank (lower is better)")
682
+ axes[1, 1].set_xlabel("Mean Rank")
683
+
684
+ plt.tight_layout()
685
+
686
+ output_path = save_path or str(self.images_dir / "model_comparison.png")
687
+ plt.savefig(output_path, dpi=300, bbox_inches="tight")
688
+ plt.close()
689
+
690
+ return output_path
691
+
692
    def plot_language_heatmap(self, save_path: str | None = None) -> str:
        """Create a heatmap of performance across languages.

        Builds a model x language matrix of NDCG@10 scores from
        ``self.results`` and renders it with seaborn.  Languages missing for
        a model are charted as 0.  Returns the saved PNG path, or "" when
        there are no results.
        """
        if not self.results:
            return ""

        # Prepare data for heatmap: one row per model, one column per
        # language in CODE_LANGUAGES.
        heatmap_data = []
        for result in self.results:
            model_name = result["model_name"]
            languages = result.get("languages", {})

            row = {"Model": model_name}
            for lang in CODE_LANGUAGES:
                if lang in languages:
                    metrics = languages[lang].get("metrics", {})
                    row[lang.title()] = metrics.get("ndcg@10", 0)
                else:
                    # Missing language: plot as zero rather than NaN.
                    row[lang.title()] = 0
            heatmap_data.append(row)

        if not heatmap_data:
            return ""

        df = pd.DataFrame(heatmap_data).set_index("Model")

        plt.figure(figsize=(12, 8))
        sns.heatmap(
            df,
            annot=True,
            fmt=".3f",
            cmap="RdYlBu_r",
            center=0.2,
            vmin=0,
            vmax=df.to_numpy().max(),
            cbar_kws={"label": "NDCG@10 Score"},
        )

        plt.title(
            "CodeSearchNet Performance Heatmap by Language",
            fontsize=16,
            fontweight="bold",
        )
        plt.xlabel("Programming Language", fontsize=12)
        plt.ylabel("Model", fontsize=12)
        plt.tight_layout()

        output_path = save_path or str(self.images_dir / "language_heatmap.png")
        plt.savefig(output_path, dpi=300, bbox_inches="tight")
        plt.close()

        return output_path
743
+
744
    def plot_benchmark_performance(self, save_path: str | None = None) -> str:
        """Create comprehensive benchmark performance plots.

        Renders up to six panels (disk size, throughput, memory, latency,
        CPU vs GPU, parameter efficiency) from ``self.benchmark_df``.
        Panels whose columns are absent are simply left empty.  Returns the
        saved PNG path, or "" when no benchmark results are loaded.
        """
        if not self.benchmark_results:
            logger.warning("No benchmark data available for plotting")
            return ""

        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        fig.suptitle("Performance Benchmark Analysis", fontsize=16, fontweight="bold")

        # 1. Model Size Comparison
        if self.benchmark_df is not None and "Disk_Size_MB" in self.benchmark_df.columns:
            axes[0, 0].barh(self.benchmark_df["Model"], self.benchmark_df["Disk_Size_MB"])
            axes[0, 0].set_title("Model Size (MB)")
            axes[0, 0].set_xlabel("Size (MB)")

        # 2. Inference Throughput
        if self.benchmark_df is not None and "Throughput_TextsPerSec" in self.benchmark_df.columns:
            axes[0, 1].barh(self.benchmark_df["Model"], self.benchmark_df["Throughput_TextsPerSec"])
            axes[0, 1].set_title("Inference Throughput")
            axes[0, 1].set_xlabel("Texts/Second")

        # 3. Memory Usage
        if self.benchmark_df is not None and "Memory_Used_MB" in self.benchmark_df.columns:
            axes[0, 2].barh(self.benchmark_df["Model"], self.benchmark_df["Memory_Used_MB"])
            axes[0, 2].set_title("Memory Usage (Batch 32)")
            axes[0, 2].set_xlabel("Memory (MB)")

        # 4. Latency Comparison
        if self.benchmark_df is not None and "Latency_MsPerText" in self.benchmark_df.columns:
            axes[1, 0].barh(self.benchmark_df["Model"], self.benchmark_df["Latency_MsPerText"])
            axes[1, 0].set_title("Inference Latency")
            axes[1, 0].set_xlabel("Milliseconds/Text")

        # 5. CPU vs GPU Performance — grouped bars, one CPU/GPU pair per model.
        if self.benchmark_df is not None:
            cpu_col = "CPU_TextsPerSec"
            gpu_col = "CUDA_TextsPerSec"
            if cpu_col in self.benchmark_df.columns and gpu_col in self.benchmark_df.columns:
                x = np.arange(len(self.benchmark_df))
                width = 0.35
                axes[1, 1].bar(x - width / 2, self.benchmark_df[cpu_col], width, label="CPU", alpha=0.7)
                axes[1, 1].bar(x + width / 2, self.benchmark_df[gpu_col], width, label="GPU", alpha=0.7)
                axes[1, 1].set_title("CPU vs GPU Performance")
                axes[1, 1].set_ylabel("Texts/Second")
                axes[1, 1].set_xticks(x)
                axes[1, 1].set_xticklabels(self.benchmark_df["Model"], rotation=45, ha="right")
                axes[1, 1].legend()

        # 6. Parameter Efficiency
        if (
            self.benchmark_df is not None
            and "Parameters_M" in self.benchmark_df.columns
            and "Throughput_TextsPerSec" in self.benchmark_df.columns
        ):
            # Efficiency = Throughput / Parameters (higher is better);
            # the epsilon guards against division by zero.
            efficiency = self.benchmark_df["Throughput_TextsPerSec"] / (self.benchmark_df["Parameters_M"] + 1e-6)
            axes[1, 2].barh(self.benchmark_df["Model"], efficiency)
            axes[1, 2].set_title("Parameter Efficiency")
            axes[1, 2].set_xlabel("Texts/Sec per Million Parameters")

        plt.tight_layout()

        output_path = save_path or str(self.images_dir / "benchmark_performance.png")
        plt.savefig(output_path, dpi=300, bbox_inches="tight")
        plt.close()

        return output_path
811
+
812
+ def plot_batch_size_scaling(self, save_path: str | None = None) -> str:
813
+ """Create batch size scaling analysis plot."""
814
+ if not self.benchmark_results:
815
+ return ""
816
+
817
+ plt.figure(figsize=(12, 8))
818
+
819
+ for result in self.benchmark_results:
820
+ model_name = result.get("model_name", "Unknown")
821
+ speed_benchmarks = result.get("speed_benchmarks", {})
822
+
823
+ # Extract batch size performance for medium texts
824
+ if "medium" in speed_benchmarks:
825
+ batch_sizes = []
826
+ throughputs = []
827
+
828
+ for batch_key, metrics in speed_benchmarks["medium"].items():
829
+ if batch_key.startswith("batch_"):
830
+ batch_size = int(batch_key.split("_")[1])
831
+ throughput = metrics.get("texts_per_second", 0)
832
+ batch_sizes.append(batch_size)
833
+ throughputs.append(throughput)
834
+
835
+ if batch_sizes:
836
+ plt.plot(batch_sizes, throughputs, marker="o", label=model_name, linewidth=2)
837
+
838
+ plt.xlabel("Batch Size", fontsize=12)
839
+ plt.ylabel("Throughput (Texts/Second)", fontsize=12)
840
+ plt.title("Batch Size Scaling Performance", fontsize=16, fontweight="bold")
841
+ plt.legend()
842
+ plt.grid(visible=True, alpha=0.3)
843
+ plt.xscale("log", base=2)
844
+
845
+ output_path = save_path or str(self.images_dir / "batch_size_scaling.png")
846
+ plt.savefig(output_path, dpi=300, bbox_inches="tight")
847
+ plt.close()
848
+
849
+ return output_path
850
+
851
+ def plot_memory_scaling(self, save_path: str | None = None) -> str:
852
+ """Create memory scaling analysis plot."""
853
+ if not self.benchmark_results:
854
+ return ""
855
+
856
+ plt.figure(figsize=(12, 8))
857
+
858
+ for result in self.benchmark_results:
859
+ model_name = result.get("model_name", "Unknown")
860
+ memory_benchmarks = result.get("memory_benchmarks", {})
861
+
862
+ batch_sizes = []
863
+ memory_usage = []
864
+
865
+ for batch_key, metrics in memory_benchmarks.items():
866
+ if batch_key.startswith("batch_") and not metrics.get("oom", False) and "error" not in metrics:
867
+ batch_size = int(batch_key.split("_")[1])
868
+ memory_mb = metrics.get("memory_used_mb", 0)
869
+ batch_sizes.append(batch_size)
870
+ memory_usage.append(memory_mb)
871
+
872
+ if batch_sizes:
873
+ plt.plot(batch_sizes, memory_usage, marker="s", label=model_name, linewidth=2)
874
+
875
+ plt.xlabel("Batch Size", fontsize=12)
876
+ plt.ylabel("Memory Usage (MB)", fontsize=12)
877
+ plt.title("Memory Scaling by Batch Size", fontsize=16, fontweight="bold")
878
+ plt.legend()
879
+ plt.grid(visible=True, alpha=0.3)
880
+ plt.xscale("log", base=2)
881
+
882
+ output_path = save_path or str(self.images_dir / "memory_scaling.png")
883
+ plt.savefig(output_path, dpi=300, bbox_inches="tight")
884
+ plt.close()
885
+
886
+ return output_path
887
+
888
    def create_peer_comparison_chart(self, model_name: str) -> str:
        """Create comparison chart using actual evaluation results.

        Draws a horizontal bar chart of NDCG@10 for every model in
        ``self.comparison_df`` (ascending order), highlighting the user's
        model in red with a black outline.  Returns the saved PNG path, or
        "" when no comparison data exists.
        """
        if self.comparison_df is None or self.comparison_df.empty:
            logger.warning("No comparison data available for peer comparison chart")
            return ""

        # Use actual evaluation results instead of hardcoded scores
        df_sorted = self.comparison_df.sort_values("NDCG@10", ascending=True)

        plt.figure(figsize=(12, 8))

        # Color models differently - highlight the user's model.
        # NOTE(review): the "gte_qwen2_m2v_code" substring looks like a
        # hard-coded fallback for one specific distilled model — confirm.
        colors = []
        for model in df_sorted["Model"]:
            if model_name.lower() in model.lower() or "gte_qwen2_m2v_code" in model.lower():
                colors.append("red")  # User's model
            else:
                colors.append("skyblue")  # Peer models

        bars = plt.barh(df_sorted["Model"], df_sorted["NDCG@10"], color=colors)

        # Highlight current model with special formatting (outline + alpha)
        for i, model in enumerate(df_sorted["Model"]):
            if model_name.lower() in model.lower() or "gte_qwen2_m2v_code" in model.lower():
                bars[i].set_alpha(0.8)
                bars[i].set_edgecolor("black")
                bars[i].set_linewidth(2)

        plt.xlabel("NDCG@10 Score", fontsize=12)
        plt.title(
            "CodeSearchNet Model Comparison (Actual Results)",
            fontsize=16,
            fontweight="bold",
        )
        plt.grid(axis="x", alpha=0.3)

        # Add score labels just to the right of each bar
        for i, score in enumerate(df_sorted["NDCG@10"]):
            plt.text(score + 0.005, i, f"{score:.3f}", va="center")

        plt.tight_layout()

        output_path = self.images_dir / "peer_comparison.png"
        plt.savefig(output_path, dpi=300, bbox_inches="tight")
        plt.close()

        return str(output_path)
935
+
936
    def create_efficiency_analysis(self, model_name: str) -> str:
        """Create efficiency analysis chart using actual evaluation results.

        Scatter-plots NDCG@10 against parameter count (log-x) for every model
        in ``self.comparison_df``: peer models as blue dots, the user's
        distilled model as a red star.  Parameter counts come from
        ``MODEL_SPECS`` (peers) and ``DISTILLED_MODEL_SPECS`` (user's model);
        peers without a spec entry are dropped.  Returns the saved PNG path,
        or "" when there is no comparison data or no model has a spec.
        """
        if self.comparison_df is None or self.comparison_df.empty:
            logger.warning("No comparison data available for efficiency analysis")
            return ""

        # Parallel lists: one entry per plotted model.
        models = []
        scores = []
        params = []
        is_user_model = []

        # Process all evaluated models
        for _, row in self.comparison_df.iterrows():
            model_display_name = row["Model"]
            current_model_score = row["NDCG@10"]

            # Determine if this is the user's model
            # NOTE(review): "gte_qwen2_m2v_code" looks like a hard-coded
            # fallback for one specific distilled model — confirm.
            is_users = (
                model_name.lower() in model_display_name.lower() or "gte_qwen2_m2v_code" in model_display_name.lower()
            )

            if is_users:
                # User's distilled model
                models.append(model_display_name)
                # Safe conversion to float for pandas values
                score_value = pd.to_numeric(current_model_score, errors="coerce")
                scores.append(float(score_value) if not pd.isna(score_value) else 0.0)
                # Safe conversion for DISTILLED_MODEL_SPECS parameters
                param_value = DISTILLED_MODEL_SPECS.get("parameters", 39)
                params.append(float(param_value) if isinstance(param_value, (int, float)) else 39.0)
                is_user_model.append(True)
            else:
                # Find corresponding peer model specs by matching the short
                # (post-slash) spec name as a substring of the display name.
                model_key = None
                for peer_key in MODEL_SPECS:
                    peer_short_name = peer_key.split("/")[-1].lower()
                    if peer_short_name in model_display_name.lower():
                        model_key = peer_key
                        break

                if model_key and model_key in MODEL_SPECS:
                    models.append(model_display_name.split("/")[-1])  # Short name
                    # Safe conversion to float for pandas values
                    score_value = pd.to_numeric(current_model_score, errors="coerce")
                    scores.append(float(score_value) if not pd.isna(score_value) else 0.0)
                    params.append(float(MODEL_SPECS[model_key].get("parameters", 100)))
                    is_user_model.append(False)

        if not models:
            logger.warning("No models with parameter specifications found")
            return ""

        plt.figure(figsize=(12, 8))

        # Plot peer models
        peer_models = [m for i, m in enumerate(models) if not is_user_model[i]]
        peer_params = [p for i, p in enumerate(params) if not is_user_model[i]]
        peer_scores = [s for i, s in enumerate(scores) if not is_user_model[i]]

        if peer_models:
            plt.scatter(
                peer_params,
                peer_scores,
                s=100,
                alpha=0.6,
                label="Peer Models",
                color="skyblue",
            )

        # Plot user's model
        user_models = [m for i, m in enumerate(models) if is_user_model[i]]
        user_params = [p for i, p in enumerate(params) if is_user_model[i]]
        user_scores = [s for i, s in enumerate(scores) if is_user_model[i]]

        if user_models:
            plt.scatter(
                user_params,
                user_scores,
                s=200,
                color="red",
                alpha=0.8,
                label=f"{user_models[0]} (Distilled)",
                marker="*",
            )

        # Add model labels; the user's model gets bold red emphasis.
        for i, (model, param, score) in enumerate(zip(models, params, scores, strict=False)):
            if is_user_model[i]:
                plt.annotate(
                    model,
                    (param, score),
                    xytext=(10, 10),
                    textcoords="offset points",
                    fontweight="bold",
                    color="red",
                )
            else:
                plt.annotate(
                    model,
                    (param, score),
                    xytext=(5, 5),
                    textcoords="offset points",
                    fontsize=9,
                )

        plt.xlabel("Model Size (Million Parameters)", fontsize=12)
        plt.ylabel("NDCG@10 Score", fontsize=12)
        plt.title(
            "Model Efficiency: Performance vs Size (Actual Results)",
            fontsize=16,
            fontweight="bold",
        )
        plt.legend()
        plt.grid(visible=True, alpha=0.3)
        plt.xscale("log")

        plt.tight_layout()

        output_path = self.images_dir / "efficiency_analysis.png"
        plt.savefig(output_path, dpi=300, bbox_inches="tight")
        plt.close()

        return str(output_path)
1059
+
1060
+ def generate_comprehensive_report(self, model_name: str = "Simplified Distillation Models") -> str:
1061
+ """Generate comprehensive markdown report for all evaluated models."""
1062
+ if not self.results:
1063
+ logger.error("No results to analyze")
1064
+ return ""
1065
+
1066
+ # Find all simplified distillation models
1067
+ simplified_models = []
1068
+ peer_models = []
1069
+
1070
+ for result in self.results:
1071
+ result_model_name = result["model_name"]
1072
+ if (
1073
+ "code_model2vec" in result_model_name.lower()
1074
+ or "distilled" in result_model_name.lower()
1075
+ or "(ours)" in result_model_name.lower()
1076
+ ):
1077
+ simplified_models.append(result)
1078
+ else:
1079
+ peer_models.append(result)
1080
+
1081
+ # Get the best performing simplified model for main analysis
1082
+ if simplified_models:
1083
+ main_result = max(simplified_models, key=lambda x: x.get("overall", {}).get("ndcg@10", 0))
1084
+ main_model_name = main_result["model_name"]
1085
+ else:
1086
+ # Fallback to first result if no simplified models found
1087
+ main_result = self.results[0]
1088
+ main_model_name = main_result["model_name"]
1089
+
1090
+ overall = main_result.get("overall", {})
1091
+ languages = main_result.get("languages", {})
1092
+
1093
+ # Calculate language scores for radar chart
1094
+ language_scores = {}
1095
+ for lang, lang_data in languages.items():
1096
+ metrics = lang_data.get("metrics", {})
1097
+ language_scores[lang.title()] = metrics.get("ndcg@10", 0)
1098
+
1099
+ # Create visualizations
1100
+ logger.info("Generating visualizations...")
1101
+ setup_directories()
1102
+
1103
+ self.create_performance_radar_chart(main_model_name, language_scores)
1104
+ comparison_chart = self.plot_model_comparison()
1105
+ heatmap_chart = self.plot_language_heatmap()
1106
+ peer_chart = self.create_peer_comparison_chart(main_model_name)
1107
+ efficiency_chart = self.create_efficiency_analysis(main_model_name)
1108
+
1109
+ # Generate individual radar charts for all simplified models
1110
+ individual_radar_charts = self.create_individual_radar_charts(simplified_models)
1111
+
1112
+ # Create comparative radar chart (best distilled vs top peer models)
1113
+ comparative_radar_chart = self.create_comparative_radar_chart(simplified_models, peer_models)
1114
+
1115
+ # Create benchmark visualizations
1116
+ benchmark_chart = ""
1117
+ batch_scaling_chart = ""
1118
+ memory_scaling_chart = ""
1119
+ if self.benchmark_results:
1120
+ benchmark_chart = self.plot_benchmark_performance()
1121
+ batch_scaling_chart = self.plot_batch_size_scaling()
1122
+ memory_scaling_chart = self.plot_memory_scaling()
1123
+
1124
+ # Generate report
1125
+ report = f"""# Code-Specialized Model2Vec Distillation Analysis
1126
+
1127
+ ## 🎯 Executive Summary
1128
+
1129
+ This report presents a comprehensive analysis of Model2Vec distillation experiments using different teacher models for code-specialized embedding generation.
1130
+
1131
+ ### Evaluated Models Overview
1132
+
1133
+ **Simplified Distillation Models:** {len(simplified_models)}
1134
+ **Peer Comparison Models:** {len(peer_models)}
1135
+ **Total Models Analyzed:** {len(self.results)}
1136
+
1137
+ ### Best Performing Simplified Model: {main_model_name}
1138
+
1139
+ **Overall CodeSearchNet Performance:**
1140
+ - **NDCG@10**: {overall.get("ndcg@10", 0):.4f}
1141
+ - **Mean Reciprocal Rank (MRR)**: {overall.get("mrr", 0):.4f}
1142
+ - **Recall@5**: {overall.get("recall@5", 0):.4f}
1143
+ - **Mean Rank**: {overall.get("mean_rank", 0):.1f}
1144
+
1145
+ ## πŸ“Š Comprehensive Model Comparison
1146
+
1147
+ ### All Simplified Distillation Models Performance
1148
+
1149
+ """
1150
+
1151
+ # Add table of all simplified models
1152
+ if simplified_models:
1153
+ report += "| Model | Teacher | NDCG@10 | MRR | Recall@5 | Status |\n"
1154
+ report += "|-------|---------|---------|-----|----------|--------|\n"
1155
+
1156
+ # Sort by performance
1157
+ simplified_models_sorted = sorted(
1158
+ simplified_models, key=lambda x: x.get("overall", {}).get("ndcg@10", 0), reverse=True
1159
+ )
1160
+
1161
+ for rank, result in enumerate(simplified_models_sorted, 1):
1162
+ model_display = result["model_name"]
1163
+ overall_metrics = result.get("overall", {})
1164
+
1165
+ # Extract teacher model name from model name
1166
+ teacher = "Unknown"
1167
+ if "all_MiniLM_L6_v2" in model_display:
1168
+ teacher = "all-MiniLM-L6-v2"
1169
+ elif "codebert_base" in model_display:
1170
+ teacher = "codebert-base"
1171
+ elif "graphcodebert_base" in model_display:
1172
+ teacher = "graphcodebert-base"
1173
+ elif "gte_Qwen2_7B_instruct" in model_display:
1174
+ teacher = "gte-Qwen2-7B-instruct"
1175
+ elif "all_mpnet_base_v2" in model_display:
1176
+ teacher = "all-mpnet-base-v2"
1177
+
1178
+ status = "πŸ₯‡ Best" if rank == 1 else "πŸ₯ˆ 2nd" if rank == 2 else "πŸ₯‰ 3rd" if rank == 3 else f"#{rank}"
1179
+
1180
+ report += f"| {model_display} | {teacher} | {overall_metrics.get('ndcg@10', 0):.4f} | {overall_metrics.get('mrr', 0):.4f} | {overall_metrics.get('recall@5', 0):.4f} | {status} |\n"
1181
+
1182
+ report += """
1183
+
1184
+ ### Key Findings
1185
+
1186
+ """
1187
+
1188
+ if simplified_models and len(simplified_models) > 1:
1189
+ best_model = simplified_models_sorted[0]
1190
+ worst_model = simplified_models_sorted[-1]
1191
+ best_score = best_model.get("overall", {}).get("ndcg@10", 0)
1192
+ worst_score = worst_model.get("overall", {}).get("ndcg@10", 0)
1193
+
1194
+ report += f"""
1195
+ - **Best Teacher Model**: {best_model["model_name"]} (NDCG@10: {best_score:.4f})
1196
+ - **Least Effective Teacher**: {worst_model["model_name"]} (NDCG@10: {worst_score:.4f})
1197
+ - **Performance Range**: {((best_score - worst_score) / best_score * 100):.1f}% difference between best and worst
1198
+ - **Average Performance**: {sum(r.get("overall", {}).get("ndcg@10", 0) for r in simplified_models) / len(simplified_models):.4f} NDCG@10
1199
+ """
1200
+
1201
+ # Add radar charts section
1202
+ report += """
1203
+
1204
+ ## 🎯 Language Performance Radar Charts
1205
+
1206
+ ### Best Model vs Peer Models Comparison
1207
+
1208
+ """
1209
+ if comparative_radar_chart:
1210
+ report += f"![Comparative Radar Chart]({comparative_radar_chart})\n\n"
1211
+ report += "*Comparative view showing how the best simplified distillation model performs against top peer models across programming languages.*\n\n"
1212
+
1213
+ # Add individual radar charts for all simplified models
1214
+ if individual_radar_charts:
1215
+ report += "### Individual Model Performance by Language\n\n"
1216
+ for chart_model_name, chart_path in individual_radar_charts.items():
1217
+ # Extract teacher name for cleaner display
1218
+ teacher = "Unknown"
1219
+ if "all_MiniLM_L6_v2" in chart_model_name:
1220
+ teacher = "all-MiniLM-L6-v2"
1221
+ elif "codebert_base" in chart_model_name:
1222
+ teacher = "codebert-base"
1223
+ elif "graphcodebert_base" in chart_model_name:
1224
+ teacher = "graphcodebert-base"
1225
+ elif "gte_Qwen2_7B_instruct" in chart_model_name:
1226
+ teacher = "gte-Qwen2-7B-instruct"
1227
+ elif "all_mpnet_base_v2" in chart_model_name:
1228
+ teacher = "all-mpnet-base-v2"
1229
+
1230
+ report += f"#### {chart_model_name} (Teacher: {teacher})\n\n"
1231
+ report += f"![{chart_model_name} Radar Chart]({chart_path})\n\n"
1232
+
1233
+ report += f"""
1234
+
1235
+ ## πŸ† Peer Model Comparison
1236
+
1237
+ ![Peer Comparison]({peer_chart})
1238
+
1239
+ *Comparison with established code-specialized embedding models using actual evaluation results.*
1240
+
1241
+ ### Complete Model Ranking
1242
+
1243
+ """
1244
+
1245
+ # Add comprehensive ranking table
1246
+ if self.comparison_df is not None and len(self.comparison_df) > 0:
1247
+ report += "| Rank | Model | Type | NDCG@10 | MRR | Recall@5 |\n"
1248
+ report += "|------|-------|------|---------|-----|----------|\n"
1249
+
1250
+ for rank in range(len(self.comparison_df)):
1251
+ row_data = self.comparison_df.iloc[rank]
1252
+ model_name_display = str(row_data["Model"])
1253
+
1254
+ # Determine model type
1255
+ if (
1256
+ "code_model2vec" in model_name_display.lower()
1257
+ or "distilled" in model_name_display.lower()
1258
+ or "(ours)" in model_name_display.lower()
1259
+ ):
1260
+ model_type = "**πŸ”₯ Simplified Distillation**"
1261
+ elif any(code_term in model_name_display.lower() for code_term in ["codebert", "graphcode", "codet5"]):
1262
+ model_type = "Code-Specific"
1263
+ elif "potion" in model_name_display.lower():
1264
+ model_type = "Model2Vec"
1265
+ else:
1266
+ model_type = "General"
1267
+
1268
+ report += f"| {rank + 1} | {model_name_display} | {model_type} | {row_data['NDCG@10']:.4f} | {row_data['MRR']:.4f} | {row_data['Recall@5']:.4f} |\n"
1269
+
1270
+ report += f"""
1271
+
1272
+ ## πŸ“ˆ Performance Analysis
1273
+
1274
+ ### Multi-Model Comparison Charts
1275
+
1276
+ ![Model Comparison]({comparison_chart})
1277
+
1278
+ *Comprehensive comparison across all evaluation metrics.*
1279
+
1280
+ ### Language Performance Analysis
1281
+
1282
+ ![Language Heatmap]({heatmap_chart})
1283
+
1284
+ *Performance heatmap showing how different models perform across programming languages.*
1285
+
1286
+ ### Efficiency Analysis
1287
+
1288
+ ![Efficiency Analysis]({efficiency_chart})
1289
+
1290
+ *Performance vs model size analysis showing the efficiency benefits of distillation.*
1291
+
1292
+ """
1293
+
1294
+ # Add benchmark analysis if available
1295
+ if self.benchmark_results:
1296
+ report += f"""
1297
+
1298
+ ## ⚑ Operational Performance Analysis
1299
+
1300
+ ![Benchmark Performance]({benchmark_chart})
1301
+
1302
+ *Comprehensive performance benchmarking across multiple operational metrics.*
1303
+
1304
+ ### Performance Scaling Analysis
1305
+
1306
+ ![Batch Size Scaling]({batch_scaling_chart})
1307
+
1308
+ *How performance scales with different batch sizes for optimal throughput.*
1309
+
1310
+ ![Memory Scaling]({memory_scaling_chart})
1311
+
1312
+ *Memory usage patterns across different batch sizes.*
1313
+
1314
+ """
1315
+
1316
+ # Add detailed language analysis
1317
+ report += """
1318
+
1319
+ ## πŸ” Language-Specific Analysis
1320
+
1321
+ ### Performance by Programming Language
1322
+
1323
+ """
1324
+
1325
+ if language_scores:
1326
+ report += "| Language | Best Model Performance | Average Performance | Language Difficulty |\n"
1327
+ report += "|----------|------------------------|--------------------|--------------------||\n"
1328
+
1329
+ for lang in sorted(language_scores.keys()):
1330
+ # Find best performance for this language across all models
1331
+ lang_performances = []
1332
+ for result in self.results:
1333
+ lang_data = result.get("languages", {}).get(lang.lower(), {})
1334
+ if lang_data:
1335
+ lang_performances.append(lang_data.get("metrics", {}).get("ndcg@10", 0))
1336
+
1337
+ if lang_performances:
1338
+ best_lang_perf = max(lang_performances)
1339
+ avg_lang_perf = sum(lang_performances) / len(lang_performances)
1340
+ difficulty = "Easy" if avg_lang_perf > 0.3 else "Medium" if avg_lang_perf > 0.2 else "Hard"
1341
+
1342
+ report += f"| {lang} | {best_lang_perf:.4f} | {avg_lang_perf:.4f} | {difficulty} |\n"
1343
+
1344
+ report += """
1345
+
1346
+ ## 🎯 Conclusions and Recommendations
1347
+
1348
+ ### Teacher Model Analysis
1349
+
1350
+ Based on the evaluation results across all simplified distillation models:
1351
+
1352
+ """
1353
+
1354
+ if simplified_models and len(simplified_models) > 1:
1355
+ # Analyze which teacher models work best
1356
+ teacher_performance = {}
1357
+ for result in simplified_models:
1358
+ model_name = result["model_name"]
1359
+ score = result.get("overall", {}).get("ndcg@10", 0)
1360
+
1361
+ if "all_MiniLM_L6_v2" in model_name:
1362
+ teacher_performance["all-MiniLM-L6-v2"] = score
1363
+ elif "codebert_base" in model_name:
1364
+ teacher_performance["codebert-base"] = score
1365
+ elif "graphcodebert_base" in model_name:
1366
+ teacher_performance["graphcodebert-base"] = score
1367
+ elif "gte_Qwen2_7B_instruct" in model_name:
1368
+ teacher_performance["gte-Qwen2-7B-instruct"] = score
1369
+ elif "all_mpnet_base_v2" in model_name:
1370
+ teacher_performance["all-mpnet-base-v2"] = score
1371
+
1372
+ if teacher_performance:
1373
+ best_teacher = max(teacher_performance.items(), key=lambda x: x[1])
1374
+ worst_teacher = min(teacher_performance.items(), key=lambda x: x[1])
1375
+
1376
+ report += f"""
1377
+ 1. **Best Teacher Model**: {best_teacher[0]} (NDCG@10: {best_teacher[1]:.4f})
1378
+ 2. **Least Effective Teacher**: {worst_teacher[0]} (NDCG@10: {worst_teacher[1]:.4f})
1379
+ 3. **Teacher Model Impact**: Choice of teacher model affects performance by {((best_teacher[1] - worst_teacher[1]) / best_teacher[1] * 100):.1f}%
1380
+
1381
+ ### Recommendations
1382
+
1383
+ - **For Production**: Use {best_teacher[0]} as teacher model for best performance
1384
+ - **For Efficiency**: Model2Vec distillation provides significant size reduction with competitive performance
1385
+ - **For Code Tasks**: Specialized models consistently outperform general-purpose models
1386
+ """
1387
+
1388
+ report += f"""
1389
+
1390
+ ## πŸ“„ Methodology
1391
+
1392
+ ### Evaluation Protocol
1393
+ - **Dataset**: CodeSearchNet test sets for 6 programming languages
1394
+ - **Metrics**: NDCG@k, MRR, Recall@k following CodeSearchNet methodology
1395
+ - **Query Format**: Natural language documentation strings
1396
+ - **Corpus Format**: Function code strings
1397
+ - **Evaluation**: Retrieval of correct code for each documentation query
1398
+
1399
+ ### Teacher Models Tested
1400
+ - sentence-transformers/all-MiniLM-L6-v2 (proven baseline)
1401
+ - microsoft/codebert-base (code-specialized)
1402
+ - microsoft/graphcodebert-base (graph-aware code model)
1403
+ - Alibaba-NLP/gte-Qwen2-7B-instruct (large instruction model)
1404
+ - sentence-transformers/all-mpnet-base-v2 (general purpose)
1405
+
1406
+ ### Distillation Method
1407
+ - **Technique**: Model2Vec static embedding generation
1408
+ - **Parameters**: PCA dims=256, SIF coefficient=1e-3, Zipf weighting=True
1409
+ - **Training Data**: CodeSearchNet comment-code pairs
1410
+ - **Languages**: Python, JavaScript, Java, PHP, Ruby, Go
1411
+
1412
+ ---
1413
+
1414
+ *Report generated on {time.strftime("%Y-%m-%d %H:%M:%S")} using automated analysis pipeline.*
1415
+ *For questions about methodology or results, please refer to the CodeSearchNet documentation.*
1416
+ """
1417
+
1418
+ return report
1419
+
1420
+ def export_results(self, output_file: str) -> None:
1421
+ """Export results to CSV format."""
1422
+ if self.comparison_df is not None:
1423
+ self.comparison_df.to_csv(output_file, index=False)
1424
+ logger.info(f"Results exported to {output_file}")
1425
+
1426
+
1427
def main() -> None:
    """Run the full CodeSearchNet analysis pipeline from the command line.

    Parses CLI arguments, loads evaluation (and optional benchmark) results,
    prints summaries, generates the comprehensive markdown report, and writes
    optional CSV exports.
    """
    # Imported locally so main() also works when this module is imported and
    # invoked programmatically; previously argparse was only imported under
    # the `if __name__ == "__main__":` guard, so such calls raised NameError.
    import argparse

    parser = argparse.ArgumentParser(description="Analyze CodeSearchNet evaluation results and performance benchmarks")
    parser.add_argument("--results-dir", default=DEFAULT_EVALUATION_DIR, help="Evaluation results directory")
    parser.add_argument("--benchmark-dir", default=DEFAULT_BENCHMARK_DIR, help="Benchmark results directory")
    parser.add_argument("--model-name", default="gte_qwen2_m2v_code (Ours)", help="Model name for report")
    parser.add_argument("--output", default="REPORT.md", help="Output report file")
    parser.add_argument("--export-csv", help="Export comparison results to CSV")

    args = parser.parse_args()

    logger.info("Starting CodeSearchNet Analysis with Benchmark Integration")
    logger.info("=" * 60)

    # Setup output directories
    output_dir, images_dir, reports_dir = setup_directories()

    # Initialize analyzer with local directories
    analyzer = CodeSearchNetAnalyzer(
        results_dir=args.results_dir,
        benchmark_dir=args.benchmark_dir,
        images_dir=images_dir,
    )

    # Load results (this will also load benchmark results)
    analyzer.load_results()

    if not analyzer.results:
        logger.error("No evaluation results found! Please run evaluation first.")
        return

    # Print summary (includes both evaluation and benchmark summaries)
    analyzer.print_summary()
    analyzer.analyze_language_performance()

    # Analyze benchmark performance if available
    if analyzer.benchmark_results:
        analyzer.analyze_benchmark_performance()
    else:
        logger.warning("No benchmark results found. Run benchmark.py first for complete analysis.")

    # Generate comprehensive report with benchmark integration
    logger.info("Generating comprehensive report with benchmark data...")
    report = analyzer.generate_comprehensive_report(args.model_name)

    # Save report
    report_path = Path(args.output)
    with report_path.open("w") as f:
        f.write(report)

    # Export CSV if requested
    if args.export_csv:
        analyzer.export_results(args.export_csv)

    # Export benchmark CSV if available
    if analyzer.benchmark_df is not None and not analyzer.benchmark_df.empty:
        benchmark_csv = report_path.parent / f"{args.model_name}_benchmark_comparison.csv"
        analyzer.benchmark_df.to_csv(benchmark_csv, index=False)
        logger.info(f"πŸ“Š Benchmark comparison saved to: {benchmark_csv}")

    logger.info("βœ… CodeSearchNet analysis with benchmarks complete!")
    logger.info(f"πŸ“Š Report saved to: {report_path}")
    logger.info(f"πŸ–ΌοΈ Charts saved to: {images_dir}")
1491
+
1492
if __name__ == "__main__":
    # CLI-only dependency: keep argparse out of module import time.
    import argparse

    main()
src/distiller/beam_utils.py ADDED
@@ -0,0 +1,753 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Beam Cloud Utilities for Model Distillation and Evaluation.
3
+
4
+ This module provides comprehensive utilities for managing Beam volumes, checkpoints,
5
+ and file operations across distillation, evaluation, and analysis workflows.
6
+
7
+ Features:
8
+ - Volume management (direct file operations when mounted)
9
+ - Checkpoint operations (save, load, cleanup, resume)
10
+ - File transfer utilities (copy, move, sync)
11
+ - Evaluation result management
12
+ - Model artifact handling
13
+ - Distributed storage optimization
14
+ """
15
+
16
+ import json
17
+ import logging
18
+ import shutil
19
+ import time
20
+ from pathlib import Path
21
+ from typing import Any
22
+
23
+ # Configure logging
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ class BeamVolumeManager:
28
+ """Manager for Beam distributed storage volumes using direct file operations."""
29
+
30
+ def __init__(self, volume_name: str, mount_path: str = "./data") -> None:
31
+ """
32
+ Initialize Beam Volume Manager.
33
+
34
+ Args:
35
+ volume_name: Name of the Beam volume
36
+ mount_path: Local mount path for the volume (should match Beam function mount path)
37
+ """
38
+ self.volume_name = volume_name
39
+ self.mount_path = Path(mount_path)
40
+ self.mount_path.mkdir(parents=True, exist_ok=True)
41
+
42
+ def exists(self) -> bool:
43
+ """Check if the volume mount path exists."""
44
+ return self.mount_path.exists()
45
+
46
+ def list_contents(self, subpath: str = "") -> list[dict[str, Any]]:
47
+ """List contents of the volume directory."""
48
+ try:
49
+ target_path = self.mount_path / subpath if subpath else self.mount_path
50
+ if not target_path.exists():
51
+ logger.warning(f"⚠️ Path does not exist: {target_path}")
52
+ return []
53
+
54
+ contents: list[dict[str, Any]] = []
55
+ for item in target_path.iterdir():
56
+ stat = item.stat()
57
+ contents.append(
58
+ {
59
+ "name": item.name,
60
+ "size": f"{stat.st_size / (1024 * 1024):.2f}MB" if item.is_file() else "0MB",
61
+ "modified": time.ctime(stat.st_mtime),
62
+ "is_dir": item.is_dir(),
63
+ "path": str(item.relative_to(self.mount_path)),
64
+ }
65
+ )
66
+
67
+ return sorted(contents, key=lambda x: (not x["is_dir"], x["name"]))
68
+
69
+ except Exception:
70
+ logger.exception("❌ Error listing contents")
71
+ return []
72
+
73
+ def copy_file(self, src: str | Path, dst: str | Path) -> bool:
74
+ """Copy a file within the volume."""
75
+ try:
76
+ src_path = self.mount_path / src if not Path(src).is_absolute() else Path(src)
77
+ dst_path = self.mount_path / dst if not Path(dst).is_absolute() else Path(dst)
78
+
79
+ dst_path.parent.mkdir(parents=True, exist_ok=True)
80
+ shutil.copy2(src_path, dst_path)
81
+
82
+ logger.info(f"πŸ“„ Copied {src_path.name} to {dst_path}")
83
+ return True
84
+
85
+ except Exception:
86
+ logger.exception("❌ Error copying file")
87
+ return False
88
+
89
+ def copy_directory(self, src: str | Path, dst: str | Path) -> bool:
90
+ """Copy a directory within the volume."""
91
+ try:
92
+ src_path = self.mount_path / src if not Path(src).is_absolute() else Path(src)
93
+ dst_path = self.mount_path / dst if not Path(dst).is_absolute() else Path(dst)
94
+
95
+ if dst_path.exists():
96
+ shutil.rmtree(dst_path)
97
+
98
+ shutil.copytree(src_path, dst_path)
99
+
100
+ logger.info(f"πŸ“ Copied directory {src_path.name} to {dst_path}")
101
+ return True
102
+
103
+ except Exception:
104
+ logger.exception("❌ Error copying directory")
105
+ return False
106
+
107
+ def move_file(self, src: str | Path, dst: str | Path) -> bool:
108
+ """Move a file within the volume."""
109
+ try:
110
+ src_path = self.mount_path / src if not Path(src).is_absolute() else Path(src)
111
+ dst_path = self.mount_path / dst if not Path(dst).is_absolute() else Path(dst)
112
+
113
+ dst_path.parent.mkdir(parents=True, exist_ok=True)
114
+ shutil.move(str(src_path), str(dst_path))
115
+
116
+ logger.info(f"➑️ Moved {src_path.name} to {dst_path}")
117
+ return True
118
+
119
+ except Exception:
120
+ logger.exception("❌ Error moving file")
121
+ return False
122
+
123
+ def remove_file(self, file_path: str | Path) -> bool:
124
+ """Remove a file from the volume."""
125
+ try:
126
+ target_path = self.mount_path / file_path if not Path(file_path).is_absolute() else Path(file_path)
127
+
128
+ if target_path.exists():
129
+ if target_path.is_file():
130
+ target_path.unlink()
131
+ logger.info(f"πŸ—‘οΈ Removed file: {target_path.name}")
132
+ else:
133
+ logger.warning(f"⚠️ Path is not a file: {target_path}")
134
+ return False
135
+ return True
136
+ logger.warning(f"⚠️ File does not exist: {target_path}")
137
+ return False
138
+
139
+ except Exception:
140
+ logger.exception("❌ Error removing file")
141
+ return False
142
+
143
+ def remove_directory(self, dir_path: str | Path) -> bool:
144
+ """Remove a directory from the volume."""
145
+ try:
146
+ target_path = self.mount_path / dir_path if not Path(dir_path).is_absolute() else Path(dir_path)
147
+
148
+ if target_path.exists() and target_path.is_dir():
149
+ shutil.rmtree(target_path)
150
+ logger.info(f"πŸ—‘οΈ Removed directory: {target_path.name}")
151
+ return True
152
+ logger.warning(f"⚠️ Directory does not exist: {target_path}")
153
+ return False
154
+
155
+ except Exception:
156
+ logger.exception("❌ Error removing directory")
157
+ return False
158
+
159
+ def cleanup_old_files(self, pattern: str = "*", older_than_days: int = 7, subpath: str = "") -> list[str]:
160
+ """Clean up old files in the volume based on age and pattern."""
161
+ try:
162
+ target_path = self.mount_path / subpath if subpath else self.mount_path
163
+ if not target_path.exists():
164
+ return []
165
+
166
+ cutoff_time = time.time() - (older_than_days * 24 * 3600)
167
+ removed_files: list[str] = []
168
+
169
+ for item in target_path.rglob(pattern):
170
+ if item.is_file() and item.stat().st_mtime < cutoff_time:
171
+ try:
172
+ item.unlink()
173
+ removed_files.append(str(item.relative_to(self.mount_path)))
174
+ logger.info(f"🧹 Removed old file: {item.name}")
175
+ except Exception as e:
176
+ logger.warning(f"⚠️ Could not remove {item.name}: {e}")
177
+
178
+ if removed_files:
179
+ logger.info(f"🧹 Cleaned up {len(removed_files)} old files")
180
+
181
+ return removed_files
182
+
183
+ except Exception:
184
+ logger.exception("❌ Error during cleanup")
185
+ return []
186
+
187
+ def get_size(self, subpath: str = "") -> int:
188
+ """Get total size of volume or subpath in bytes."""
189
+ try:
190
+ target_path = self.mount_path / subpath if subpath else self.mount_path
191
+ if not target_path.exists():
192
+ return 0
193
+
194
+ total_size = 0
195
+ for item in target_path.rglob("*"):
196
+ if item.is_file():
197
+ total_size += item.stat().st_size
198
+
199
+ return total_size
200
+
201
+ except Exception:
202
+ logger.exception("❌ Error calculating size")
203
+ return 0
204
+
205
+
206
class BeamCheckpointManager:
    """Manager for checkpoint operations on Beam volumes.

    Checkpoints are JSON files named
    ``{checkpoint_prefix}_{stage}_step_{step}.json`` under the checkpoint
    directory on the mounted volume.
    """

    def __init__(self, volume_manager: BeamVolumeManager, checkpoint_prefix: str = "checkpoints") -> None:
        """
        Initialize Checkpoint Manager.

        Args:
            volume_manager: BeamVolumeManager instance
            checkpoint_prefix: Prefix for checkpoint files
        """
        self.volume = volume_manager
        self.checkpoint_prefix = checkpoint_prefix
        self.checkpoint_dir = self.volume.mount_path / checkpoint_prefix
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    def _checkpoint_path(self, stage: str, step: int) -> Path:
        """Build the canonical path for a checkpoint file."""
        return self.checkpoint_dir / f"{self.checkpoint_prefix}_{stage}_step_{step}.json"

    def _parse_checkpoint_name(self, checkpoint_file: Path) -> tuple[str, int] | None:
        """Parse ``(stage, step)`` from a checkpoint filename.

        Fix: the previous ad-hoc ``split("_")`` / ``str.replace`` parsing
        mis-handled stage names (or prefixes) containing underscores; this
        anchors on the prefix and the last ``_step_`` marker instead.
        Returns None for filenames not following the naming scheme.
        """
        stem = checkpoint_file.stem
        head = f"{self.checkpoint_prefix}_"
        if not stem.startswith(head):
            return None
        stage, sep, step_str = stem[len(head):].rpartition("_step_")
        if not sep:
            return None
        try:
            return stage, int(step_str)
        except ValueError:
            return None

    def _stage_checkpoints(self, stage: str) -> list[tuple[int, Path]]:
        """Collect (step, path) pairs for all checkpoints of *stage*."""
        pattern = f"{self.checkpoint_prefix}_{stage}_step_*.json"
        checkpoints: list[tuple[int, Path]] = []
        for checkpoint_file in self.checkpoint_dir.glob(pattern):
            parsed = self._parse_checkpoint_name(checkpoint_file)
            if parsed is not None and parsed[0] == stage:
                checkpoints.append((parsed[1], checkpoint_file))
        return checkpoints

    def save_checkpoint(self, stage: str, data: dict[str, Any], step: int = 0) -> bool:
        """Save checkpoint to volume. Returns True on success."""
        try:
            checkpoint_path = self._checkpoint_path(stage, step)

            with checkpoint_path.open("w") as f:
                # default=str keeps non-JSON-native values serializable.
                json.dump(data, f, indent=2, default=str)

            logger.info(f"πŸ’Ύ Saved checkpoint: {stage} step {step}")
            return True

        except Exception:
            logger.exception("❌ Error saving checkpoint")
            return False

    def load_checkpoint(self, stage: str, step: int = 0) -> dict[str, Any] | None:
        """Load checkpoint from volume, or None when missing or on error."""
        try:
            checkpoint_path = self._checkpoint_path(stage, step)

            if checkpoint_path.exists():
                with checkpoint_path.open("r") as f:
                    data = json.load(f)
                logger.info(f"πŸ“‚ Loaded checkpoint: {stage} step {step}")
                return data

            logger.info(f"Info: No checkpoint found: {stage} step {step}")
            return None

        except Exception:
            logger.exception("❌ Error loading checkpoint")
            return None

    def get_latest_checkpoint(self, stage: str) -> tuple[int, dict[str, Any]] | None:
        """Get the latest (highest-step) checkpoint for a stage."""
        try:
            stage_checkpoints = self._stage_checkpoints(stage)

            if not stage_checkpoints:
                logger.info(f"Info: No checkpoints found for stage: {stage}")
                return None

            latest_step, latest_file = max(stage_checkpoints, key=lambda x: x[0])

            with latest_file.open("r") as f:
                data = json.load(f)
            logger.info(f"πŸ“‚ Found latest checkpoint: {stage} step {latest_step}")
            return latest_step, data

        except Exception:
            logger.exception("❌ Error finding latest checkpoint")
            return None

    def cleanup_old_checkpoints(self, stage: str, keep_latest: int = 3) -> list[str]:
        """Clean up old checkpoints, keeping only the latest N.

        Returns the filenames that were removed.
        """
        try:
            stage_checkpoints = self._stage_checkpoints(stage)
            # Newest first, so everything past keep_latest is stale.
            stage_checkpoints.sort(key=lambda x: x[0], reverse=True)

            removed_files: list[str] = []
            for _step, checkpoint_file in stage_checkpoints[keep_latest:]:
                try:
                    checkpoint_file.unlink()
                    removed_files.append(checkpoint_file.name)
                    logger.info(f"🧹 Removed old checkpoint: {checkpoint_file.name}")
                except Exception as e:
                    logger.warning(f"⚠️ Could not remove {checkpoint_file.name}: {e}")

            if removed_files:
                logger.info(f"🧹 Cleaned up {len(removed_files)} old checkpoints for {stage}")

            return removed_files

        except Exception:
            logger.exception("❌ Error cleaning up checkpoints")
            return []

    def list_checkpoints(self, stage: str | None = None) -> list[dict[str, Any]]:
        """List all checkpoints, optionally filtered by stage.

        Malformed filenames (previously reported with a bogus step of 0 and
        possibly the wrong stage) are now skipped entirely.
        """
        try:
            checkpoints: list[dict[str, Any]] = []
            pattern = f"{self.checkpoint_prefix}_*.json"

            for checkpoint_file in self.checkpoint_dir.glob(pattern):
                parsed = self._parse_checkpoint_name(checkpoint_file)
                if parsed is None:
                    continue
                checkpoint_stage, step = parsed

                if stage is None or checkpoint_stage == stage:
                    stat = checkpoint_file.stat()
                    checkpoints.append(
                        {
                            "stage": checkpoint_stage,
                            "step": step,
                            "filename": checkpoint_file.name,
                            "size": f"{stat.st_size / 1024:.1f}KB",
                            "modified": time.ctime(stat.st_mtime),
                        }
                    )

            return sorted(checkpoints, key=lambda x: (x["stage"], x["step"]))

        except Exception:
            logger.exception("❌ Error listing checkpoints")
            return []
361
+
362
+
363
class BeamModelManager:
    """Manager for model artifacts on Beam volumes.

    Models live as subdirectories of ``{mount_path}/{model_prefix}``.
    """

    def __init__(self, volume_manager: BeamVolumeManager, model_prefix: str = "models") -> None:
        """
        Initialize Model Manager.

        Args:
            volume_manager: BeamVolumeManager instance
            model_prefix: Prefix for model files
        """
        self.volume = volume_manager
        self.model_prefix = model_prefix
        self.model_dir = self.volume.mount_path / model_prefix
        self.model_dir.mkdir(parents=True, exist_ok=True)

    def save_model(self, model_name: str, local_model_path: str | Path) -> bool:
        """Copy a local model file or directory into the volume under *model_name*."""
        try:
            source = Path(local_model_path)
            if not source.exists():
                logger.error(f"❌ Model path does not exist: {source}")
                return False

            destination = self.model_dir / model_name

            if source.is_dir():
                # Replace any existing copy wholesale.
                if destination.exists():
                    shutil.rmtree(destination)
                shutil.copytree(source, destination)
                logger.info(f"πŸ’Ύ Saved model directory {model_name}")
                return True

            # Single file: place it inside a directory named after the model.
            destination.mkdir(exist_ok=True)
            shutil.copy2(source, destination / source.name)
            logger.info(f"πŸ’Ύ Saved model file {model_name}")
            return True

        except Exception:
            logger.exception("❌ Error saving model")
            return False

    def load_model(self, model_name: str, local_model_path: str | Path = "./models") -> bool:
        """Copy a stored model out of the volume into *local_model_path*."""
        try:
            target_root = Path(local_model_path)
            target_root.mkdir(parents=True, exist_ok=True)

            source = self.model_dir / model_name
            destination = target_root / model_name

            if not source.exists():
                logger.error(f"❌ Model does not exist: {model_name}")
                return False

            # Clear whatever currently occupies the destination.
            if destination.is_dir():
                shutil.rmtree(destination)
            elif destination.exists():
                destination.unlink()

            if source.is_dir():
                shutil.copytree(source, destination)
            else:
                shutil.copy2(source, destination)

            logger.info(f"πŸ“₯ Loaded model {model_name}")
            return True

        except Exception:
            logger.exception("❌ Error loading model")
            return False

    def list_models(self) -> list[dict[str, str]]:
        """List all models in the volume, sorted by name."""
        try:
            if not self.model_dir.exists():
                return []

            models = [
                {
                    "name": entry.name,
                    # Directory size is the sum of all contained file sizes.
                    "size": f"{sum(f.stat().st_size for f in entry.rglob('*') if f.is_file()) / (1024 * 1024):.1f}MB",
                    "modified": time.ctime(entry.stat().st_mtime),
                }
                for entry in self.model_dir.iterdir()
                if entry.is_dir()
            ]

            return sorted(models, key=lambda x: x["name"])

        except Exception:
            logger.exception("❌ Error listing models")
            return []

    def remove_model(self, model_name: str) -> bool:
        """Delete a stored model (file or directory) from the volume."""
        try:
            model_path = self.model_dir / model_name

            if not model_path.exists():
                logger.warning(f"⚠️ Model does not exist: {model_name}")
                return False

            if model_path.is_dir():
                shutil.rmtree(model_path)
            else:
                model_path.unlink()
            logger.info(f"πŸ—‘οΈ Removed model: {model_name}")
            return True

        except Exception:
            logger.exception("❌ Error removing model")
            return False
484
+
485
+
486
class BeamEvaluationManager:
    """Manager for evaluation results on Beam volumes.

    Results are JSON files named ``{eval_type}_eval_{model_name}.json``
    (with "/" in model names replaced by "_") under the results directory
    on the mounted volume.
    """

    def __init__(
        self,
        volume_manager: BeamVolumeManager,
        results_prefix: str = "evaluation_results",
    ) -> None:
        """
        Initialize Evaluation Manager.

        Args:
            volume_manager: BeamVolumeManager instance
            results_prefix: Prefix for evaluation result files
        """
        self.volume = volume_manager
        self.results_prefix = results_prefix
        self.results_dir = self.volume.mount_path / results_prefix
        self.results_dir.mkdir(parents=True, exist_ok=True)

    def _results_path(self, model_name: str, eval_type: str = "codesearchnet") -> Path:
        """Canonical result-file path for *model_name*/*eval_type*.

        Shared by save/load/remove so the naming scheme lives in one place.
        """
        return self.results_dir / f"{eval_type}_eval_{model_name.replace('/', '_')}.json"

    def save_evaluation_results(
        self, model_name: str, results: dict[str, Any], eval_type: str = "codesearchnet"
    ) -> bool:
        """Save evaluation results to Beam volume."""
        try:
            results_path = self._results_path(model_name, eval_type)

            with results_path.open("w") as f:
                json.dump(results, f, indent=2, default=str)

            logger.info(f"πŸ’Ύ Saved evaluation results for {model_name}")
            return True

        except Exception:
            logger.exception("❌ Error saving evaluation results")
            return False

    def load_evaluation_results(self, model_name: str, eval_type: str = "codesearchnet") -> dict[str, Any] | None:
        """Load evaluation results from Beam volume, or None when missing."""
        try:
            results_path = self._results_path(model_name, eval_type)

            if results_path.exists():
                with results_path.open("r") as f:
                    results = json.load(f)
                logger.info(f"πŸ“‚ Loaded evaluation results for {model_name}")
                return results

            logger.info(f"Info: No evaluation results found for {model_name}")
            return None

        except Exception:
            logger.exception("❌ Error loading evaluation results")
            return None

    def list_evaluation_results(self, eval_type: str | None = None) -> list[dict[str, str]]:
        """List all evaluation results, optionally filtered by eval type.

        Fix: model names are now recovered by splitting on the first
        "_eval_" marker instead of chained ``str.replace`` calls, which
        clobbered model names containing the eval type or "_eval_" as a
        substring. Files not matching the naming scheme are skipped.
        """
        try:
            results: list[dict[str, str]] = []

            if not self.results_dir.exists():
                return results

            for item in self.results_dir.glob("*.json"):
                if eval_type is not None and not item.name.startswith(f"{eval_type}_eval_"):
                    continue

                prefix, marker, model_name = item.stem.partition("_eval_")
                if not marker:
                    # Not produced by save_evaluation_results; skip quietly.
                    continue
                if eval_type is None:
                    # Unfiltered listings keep the eval-type prefix, matching
                    # the historical output format.
                    model_name = f"{prefix}_{model_name}"

                stat = item.stat()
                results.append(
                    {
                        "model_name": model_name,
                        "filename": item.name,
                        "size": f"{stat.st_size / 1024:.1f}KB",
                        "modified": time.ctime(stat.st_mtime),
                    }
                )

            return sorted(results, key=lambda x: x["model_name"])

        except Exception:
            logger.exception("❌ Error listing evaluation results")
            return []

    def remove_evaluation_results(self, model_name: str, eval_type: str = "codesearchnet") -> bool:
        """Remove evaluation results from volume."""
        try:
            results_path = self._results_path(model_name, eval_type)

            if results_path.exists():
                results_path.unlink()
                logger.info(f"πŸ—‘οΈ Removed evaluation results for {model_name}")
                return True
            logger.warning(f"⚠️ Evaluation results do not exist for {model_name}")
            return False

        except Exception:
            logger.exception("❌ Error removing evaluation results")
            return False
591
+
592
+
593
def create_beam_utilities(
    volume_name: str, mount_path: str = "./data"
) -> tuple[BeamVolumeManager, BeamCheckpointManager, BeamModelManager, BeamEvaluationManager]:
    """
    Create a complete set of Beam utilities.

    All three higher-level managers share one BeamVolumeManager instance.

    Args:
        volume_name: Name of the Beam volume
        mount_path: Local mount path for the volume

    Returns:
        Tuple of (volume_manager, checkpoint_manager, model_manager, evaluation_manager)
    """
    volume = BeamVolumeManager(volume_name, mount_path)
    return (
        volume,
        BeamCheckpointManager(volume),
        BeamModelManager(volume),
        BeamEvaluationManager(volume),
    )
612
+
613
+
614
def cleanup_beam_workspace(volume_name: str, mount_path: str = "./data", confirm: bool = False) -> bool:
    """
    Clean up entire Beam workspace including all data in the mounted volume.

    Args:
        volume_name: Name of the volume to clean up
        mount_path: Mount path of the volume
        confirm: If True, skip confirmation prompt

    Returns:
        True if cleanup successful, False otherwise
    """
    if not confirm:
        prompt = f"⚠️ This will delete all data in volume '{volume_name}' at '{mount_path}'. Continue? (y/N): "
        if input(prompt).lower() != "y":
            logger.info("Cleanup cancelled")
            return False

    try:
        volume_manager = BeamVolumeManager(volume_name, mount_path)

        if not volume_manager.exists():
            logger.info(f"Volume mount path does not exist: {mount_path}")
            return True

        # Report what is about to be deleted.
        contents = volume_manager.list_contents()
        logger.info(f"πŸ—‘οΈ Will delete {len(contents)} items from volume '{volume_name}'")

        # Best effort: one undeletable entry does not abort the rest.
        for entry in volume_manager.mount_path.iterdir():
            try:
                if entry.is_dir():
                    shutil.rmtree(entry)
                    logger.info(f"πŸ—‘οΈ Removed directory: {entry.name}")
                else:
                    entry.unlink()
                    logger.info(f"πŸ—‘οΈ Removed file: {entry.name}")
            except Exception as e:
                logger.warning(f"⚠️ Could not remove {entry.name}: {e}")

        logger.info(f"βœ… Successfully cleaned up Beam workspace: {volume_name}")
        return True

    except Exception:
        logger.exception("❌ Error during cleanup")
        return False
661
+
662
+
663
def get_workspace_info(volume_name: str, mount_path: str = "./data") -> dict[str, Any]:
    """
    Get information about the Beam workspace.

    Args:
        volume_name: Name of the volume
        mount_path: Mount path of the volume

    Returns:
        Dictionary with workspace information
    """
    try:
        manager = BeamVolumeManager(volume_name, mount_path)

        if not manager.exists():
            return {
                "volume_name": volume_name,
                "mount_path": mount_path,
                "exists": False,
                "size": 0,
                "contents": [],
            }

        entries = manager.list_contents()
        size_bytes = manager.get_size()

        return {
            "volume_name": volume_name,
            "mount_path": str(manager.mount_path),
            "exists": True,
            "size": size_bytes,
            "size_mb": f"{size_bytes / (1024 * 1024):.1f}MB",
            "num_items": len(entries),
            "contents": entries[:10],  # First 10 items
        }

    except Exception:
        logger.exception("❌ Error getting workspace info")
        return {
            "volume_name": volume_name,
            "mount_path": mount_path,
            "error": "Error occurred",
        }
706
+
707
+
708
+ # Example usage functions
709
def example_distillation_workflow() -> None:
    """Example of using Beam utilities for distillation workflow."""
    volume_name = "gte_qwen2_m2v_code"
    mount_path = "./gte_qwen2_m2v_code"  # Should match Beam function mount path

    # Build the full utility set (model manager unused in this example).
    volume_mgr, checkpoint_mgr, _model_mgr, eval_mgr = create_beam_utilities(volume_name, mount_path)

    # Bail out early if the volume is not mounted.
    if not volume_mgr.exists():
        logger.warning(f"Volume {volume_name} not found at {mount_path}")
        return
    logger.info(f"Volume {volume_name} is mounted at {mount_path}")

    # Save a checkpoint
    checkpoint_mgr.save_checkpoint(
        "training",
        {
            "epoch": 1,
            "loss": 0.25,
            "model_state": "dummy_state",
            "timestamp": time.time(),
        },
        step=1000,
    )

    # List checkpoints
    training_checkpoints = checkpoint_mgr.list_checkpoints("training")
    logger.info(f"Found {len(training_checkpoints)} training checkpoints")

    # Save evaluation results
    eval_mgr.save_evaluation_results(
        "gte_qwen2_m2v_code",
        {
            "model_name": "gte_qwen2_m2v_code",
            "overall": {"ndcg@10": 0.35, "mrr": 0.42},
            "timestamp": time.time(),
        },
    )

    # Get workspace info
    info = get_workspace_info(volume_name, mount_path)
    logger.info(f"Workspace info: {info}")
748
+
749
+
750
if __name__ == "__main__":
    # Demonstration entry point: run the example workflow with INFO logging.
    logging.basicConfig(level=logging.INFO)
    example_distillation_workflow()
src/distiller/benchmark.py ADDED
@@ -0,0 +1,1181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Operational Performance Benchmarking for Embedding Models.
3
+
4
+ This module benchmarks embedding models on operational metrics like:
5
+ - Inference speed (latency and throughput)
6
+ - Memory efficiency (RAM and GPU usage)
7
+ - Model size and storage requirements
8
+ - Scalability with batch size
9
+ - CPU vs GPU performance
10
+ """
11
+
12
+ import gc
13
+ import json
14
+ import logging
15
+ import os
16
+ import time
17
+ from pathlib import Path
18
+ from typing import Any
19
+
20
+ import pandas as pd
21
+ import psutil
22
+ import torch
23
+ from beam import GpuType, Image, Volume, function
24
+ from sentence_transformers import SentenceTransformer
25
+
26
+ from .beam_utils import (
27
+ BeamCheckpointManager,
28
+ BeamEvaluationManager,
29
+ create_beam_utilities,
30
+ )
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+ # =============================================================================
35
+ # BEAM CONFIGURATION
36
+ # =============================================================================
37
+
38
+ GPU_NAME = GpuType.A100_40
39
+ VOLUME_NAME = "gte_qwen2_m2v_code" # Same volume as distill.py and evaluate.py
40
+ VOLUME_PATH = "./gte_qwen2_m2v_code" # Same mount path as distill.py and evaluate.py
41
+ BENCHMARK_RESULTS_DIR = "benchmark_results" # Subdirectory within volume
42
+ BENCHMARK_CACHE_DIR = "benchmark_cache" # Cache for models
43
+
44
+ IMAGE = Image(python_version="python3.12").add_python_packages(
45
+ [
46
+ "torch>=2.7.0",
47
+ "transformers>=4.40.0",
48
+ "datasets>=3.2.0",
49
+ "sentence-transformers>=4.1.0",
50
+ "model2vec[train]>=0.5.0",
51
+ "numpy>=1.26.4",
52
+ "scikit-learn>=1.6.1",
53
+ "pandas>=2.0.0",
54
+ "tqdm>=4.65.0",
55
+ "psutil>=5.9.0",
56
+ ]
57
+ )
58
+
59
+ # =============================================================================
60
+ # CONFIGURATION
61
+ # =============================================================================
62
+
63
+ DEFAULT_OUTPUT_DIR = "benchmark_results" # Local fallback directory
64
+
65
+ # Default models to benchmark (can be overridden via command line)
66
+ DEFAULT_BENCHMARK_MODELS = [
67
+ # Your distilled model (local files in Beam volume root)
68
+ "gte_qwen2_m2v_code", # This will be resolved to VOLUME_PATH in Beam
69
+ # Established Code Models
70
+ "sentence-transformers/all-MiniLM-L6-v2",
71
+ "microsoft/codebert-base",
72
+ "microsoft/graphcodebert-base",
73
+ "huggingface/CodeBERTa-small-v1",
74
+ "sentence-transformers/all-mpnet-base-v2",
75
+ "sentence-transformers/all-MiniLM-L12-v2",
76
+ # Model2Vec & Efficiency Models (Direct Competitors)
77
+ "minishlab/potion-base-8M",
78
+ "minishlab/potion-retrieval-32M",
79
+ # Small Transformer-Based Code Models
80
+ "Salesforce/codet5-base",
81
+ ]
82
+
83
+ # =============================================================================
84
+ # CHECKPOINT CONFIGURATION
85
+ # =============================================================================
86
+
87
+ # Prevent conflicts with other modules by using unique prefixes
88
+ BENCHMARK_CHECKPOINT_PREFIX = "benchmark_checkpoints"
89
+ MODEL_CACHE_PREFIX = "model_cache"
90
+
91
+ # Sample texts for benchmarking (various lengths)
92
+ BENCHMARK_TEXTS = {
93
+ "short": [
94
+ "def add(a, b): return a + b",
95
+ "function multiply(x, y) { return x * y; }",
96
+ "class Calculator { public int subtract(int a, int b) { return a - b; } }",
97
+ ]
98
+ * 100, # 300 short texts
99
+ "medium": [
100
+ "def fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)",
101
+ "function quickSort(arr) {\n if (arr.length <= 1) return arr;\n const pivot = arr[arr.length - 1];\n const left = [], right = [];\n for (let i = 0; i < arr.length - 1; i++) {\n if (arr[i] < pivot) left.push(arr[i]);\n else right.push(arr[i]);\n }\n return [...quickSort(left), pivot, ...quickSort(right)];\n}",
102
+ ]
103
+ * 50, # 100 medium texts
104
+ "long": [
105
+ """
106
+ def complex_algorithm(data, config):
107
+ '''
108
+ Complex data processing algorithm with multiple steps.
109
+
110
+ Args:
111
+ data: Input data structure
112
+ config: Configuration parameters
113
+
114
+ Returns:
115
+ Processed results
116
+ '''
117
+ results = []
118
+
119
+ # Step 1: Data validation
120
+ if not isinstance(data, (list, tuple)):
121
+ raise ValueError("Data must be list or tuple")
122
+
123
+ # Step 2: Preprocessing
124
+ processed_data = []
125
+ for item in data:
126
+ if config.get('normalize', False):
127
+ item = normalize_item(item)
128
+ if config.get('filter', False):
129
+ if not filter_item(item, config['filter_criteria']):
130
+ continue
131
+ processed_data.append(item)
132
+
133
+ # Step 3: Main processing
134
+ for item in processed_data:
135
+ result = process_item(item, config)
136
+ if result is not None:
137
+ results.append(result)
138
+
139
+ # Step 4: Post-processing
140
+ if config.get('sort', False):
141
+ results.sort(key=lambda x: x.get('score', 0), reverse=True)
142
+
143
+ return results
144
+ """.strip(),
145
+ ]
146
+ * 20, # 20 long texts
147
+ }
148
+
149
+
150
class PerformanceBenchmark:
    """Comprehensive performance benchmarking for embedding models.

    Covers model size on disk and in memory, inference latency/throughput
    across batch sizes and text lengths, GPU memory scaling, and a
    CPU-vs-GPU throughput comparison.

    Fixes vs. the original implementation:
    - ``benchmark_cpu_vs_gpu`` now passes ``trust_remote_code=True`` so it is
      consistent with ``load_model`` (remote-code models no longer fail only
      in the device comparison).
    - ``benchmark_memory_scaling`` resets the CUDA peak-memory counter before
      each encode so ``max_memory_allocated`` reflects only the current batch,
      and a dead ``torch.cuda.memory_allocated()`` call (result discarded)
      was removed.
    """

    def __init__(
        self,
        model_path: str,
        model_name: str | None = None,
        checkpoint_manager: BeamCheckpointManager | None = None,
        eval_manager: BeamEvaluationManager | None = None,
    ) -> None:
        """Initialize benchmarker with model and optional Beam utilities.

        Args:
            model_path: Local directory/file or HuggingFace model id.
            model_name: Display name; defaults to the final path component.
            checkpoint_manager: Optional Beam checkpoint manager.
            eval_manager: Optional Beam evaluation manager.
        """
        self.model_path = model_path
        self.model_name = model_name or Path(model_path).name
        self.model: SentenceTransformer | None = None
        # Prefer GPU when available; individual benchmarks may override this.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.results: dict[str, Any] = {}
        self.checkpoint_manager = checkpoint_manager
        self.eval_manager = eval_manager

    def load_model(self) -> None:
        """Load the embedding model and record the load time in ``results``.

        Raises:
            Exception: Re-raised after logging if the model cannot be loaded.
        """
        logger.info(f"Loading model from {self.model_path}")
        start_time = time.time()

        try:
            self.model = SentenceTransformer(self.model_path, device=self.device, trust_remote_code=True)
            load_time = time.time() - start_time

            logger.info(f"βœ… Model loaded in {load_time:.2f}s on {self.device}")
            self.results["model_load_time"] = load_time

        except Exception:
            logger.exception("❌ Failed to load model")
            raise

    def measure_model_size(self) -> dict[str, float]:
        """Measure model size metrics (disk, parameters, embedding dim, memory).

        Returns:
            Dict of size metrics; also stored as ``results["size_metrics"]``.
        """
        logger.info("πŸ“ Measuring model size...")

        size_metrics = {}

        # Disk size - handle both local paths and HuggingFace model ids
        try:
            if Path(self.model_path).is_dir():
                # Local directory - only count model/config files, not caches
                model_extensions = {".safetensors", ".bin", ".json", ".txt", ".tokenizer"}
                total_size = 0
                model_dir = Path(self.model_path)

                for file_path in model_dir.rglob("*"):
                    if file_path.is_file() and (
                        file_path.suffix.lower() in model_extensions
                        or file_path.name.lower() in {"config.json", "tokenizer.json", "modules.json", "README.md"}
                    ):
                        total_size += file_path.stat().st_size

                size_metrics["disk_size_mb"] = total_size / (1024 * 1024)
            elif Path(self.model_path).is_file():
                # Single file
                size_metrics["disk_size_mb"] = Path(self.model_path).stat().st_size / (1024 * 1024)
            else:
                # HuggingFace model id - no local files, so estimate from config
                from transformers import AutoConfig

                try:
                    config = AutoConfig.from_pretrained(self.model_path)
                    # Very rough estimate from architecture hyperparameters
                    if hasattr(config, "hidden_size") and hasattr(config, "num_hidden_layers"):
                        estimated_params = config.hidden_size * config.num_hidden_layers * 1000  # Very rough
                        size_metrics["disk_size_mb"] = estimated_params * 4 / (1024 * 1024)  # 4 bytes per float32
                    else:
                        size_metrics["disk_size_mb"] = 0  # Unknown
                except Exception:
                    logger.warning(f"Could not determine disk size for HuggingFace model: {self.model_path}")
                    size_metrics["disk_size_mb"] = 0  # Unknown
        except Exception as e:
            logger.warning(f"Could not determine disk size: {e}")
            size_metrics["disk_size_mb"] = 0

        # Model parameters (if accessible)
        try:
            if self.model is not None and hasattr(self.model, "modules"):
                total_params = sum(p.numel() for p in self.model.parameters())
                size_metrics["parameters_millions"] = total_params / 1_000_000

                # Try to get the embedding dimension from the underlying HF config
                try:
                    # Use the public modules() method instead of private _modules
                    modules = list(self.model.modules())
                    if len(modules) > 1:  # modules[0] is the whole model; modules[1] is the first submodule
                        first_module = modules[1]
                        if hasattr(first_module, "auto_model") and hasattr(first_module.auto_model, "config"):
                            config = first_module.auto_model.config
                            if hasattr(config, "hidden_size"):
                                size_metrics["embedding_dim"] = config.hidden_size
                            elif hasattr(config, "model_dim"):
                                size_metrics["embedding_dim"] = config.model_dim
                except Exception as e:
                    # Silently continue if this method fails
                    logger.debug(f"Could not extract embedding dimension from model config: {e}")

            # For Model2Vec static models
            elif self.model is not None and hasattr(self.model, "embedding"):
                # Handle both tensor and numpy array embeddings
                embedding = self.model.embedding
                if hasattr(embedding, "shape"):
                    vocab_size, embedding_dim = embedding.shape  # type: ignore[misc]
                    size_metrics["parameters_millions"] = (vocab_size * embedding_dim) / 1_000_000
                    size_metrics["vocab_size"] = vocab_size
                    size_metrics["embedding_dim"] = embedding_dim
                else:
                    logger.warning("Could not determine embedding shape for Model2Vec model")

            # Fallback: infer embedding dimension from a test encoding
            if "embedding_dim" not in size_metrics and self.model is not None:
                try:
                    test_embedding = self.model.encode(["test"], convert_to_tensor=False)
                    if hasattr(test_embedding, "shape") and len(test_embedding.shape) >= 2:
                        size_metrics["embedding_dim"] = test_embedding.shape[1]
                    elif (
                        isinstance(test_embedding, (list, tuple))
                        and len(test_embedding) > 0
                        and hasattr(test_embedding[0], "__len__")
                    ):
                        size_metrics["embedding_dim"] = len(test_embedding[0])
                except Exception as e:
                    logger.warning(f"Could not determine embedding dimension: {e}")

        except Exception as e:
            logger.warning(f"Could not determine parameter count: {e}")

        # Memory footprint (GPU allocation after cache clear)
        if self.device == "cuda" and torch.cuda.is_available():
            torch.cuda.empty_cache()
            size_metrics["gpu_memory_mb"] = torch.cuda.memory_allocated() / (1024 * 1024)

        # RAM usage (approximate: whole-process RSS, not model-only)
        process = psutil.Process(os.getpid())
        size_metrics["ram_usage_mb"] = process.memory_info().rss / (1024 * 1024)

        self.results["size_metrics"] = size_metrics
        return size_metrics

    def benchmark_inference_speed(self, batch_sizes: list[int] | None = None) -> dict[str, Any]:
        """Benchmark inference speed across batch sizes and text lengths.

        Args:
            batch_sizes: Batch sizes to test; defaults to [1, 8, 16, 32, 64, 128].

        Returns:
            Nested dict keyed by text length, then ``batch_<n>``.
        """
        if batch_sizes is None:
            batch_sizes = [1, 8, 16, 32, 64, 128]
        logger.info("⚑ Benchmarking inference speed...")

        if self.model is None:
            self.load_model()

        if self.model is None:
            msg = "Failed to load model"
            raise RuntimeError(msg)

        speed_results = {}
        text_lengths = ["short", "medium", "long"]

        for text_length in text_lengths:
            logger.info(f"  πŸ“ Testing {text_length} texts...")
            texts = BENCHMARK_TEXTS[text_length]

            length_results = {}

            for batch_size in batch_sizes:
                if batch_size > len(texts):
                    continue

                logger.info(f"    πŸ”„ Batch size: {batch_size}")

                batch_texts = texts[:batch_size]

                # Warmup run so CUDA kernel launches don't skew the timing
                if self.device == "cuda":
                    torch.cuda.synchronize()
                _ = self.model.encode(batch_texts[: min(2, batch_size)], convert_to_tensor=False)

                # Clear cache before the timed run
                if self.device == "cuda":
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize()

                # Timed run (synchronize so pending GPU work is included)
                start_time = time.perf_counter()

                embeddings = self.model.encode(batch_texts, convert_to_tensor=False, show_progress_bar=False)

                if self.device == "cuda":
                    torch.cuda.synchronize()

                end_time = time.perf_counter()

                # Derived metrics
                total_time = end_time - start_time
                time_per_text = total_time / batch_size
                texts_per_second = batch_size / total_time

                # Rough token estimate: whitespace-split word count
                avg_tokens = sum(len(text.split()) for text in batch_texts) / batch_size
                total_tokens = avg_tokens * batch_size
                tokens_per_second = total_tokens / total_time

                length_results[f"batch_{batch_size}"] = {
                    "total_time_ms": total_time * 1000,
                    "time_per_text_ms": time_per_text * 1000,
                    "texts_per_second": texts_per_second,
                    "tokens_per_second": tokens_per_second,
                    "avg_tokens_per_text": avg_tokens,
                    "embedding_shape": embeddings.shape
                    if hasattr(embeddings, "shape")
                    else f"({len(embeddings)}, {len(embeddings[0]) if embeddings else 0})",
                }

            speed_results[text_length] = length_results

        self.results["speed_benchmarks"] = speed_results
        return speed_results

    def benchmark_memory_scaling(self, batch_sizes: list[int] | None = None) -> dict[str, Any]:
        """Benchmark GPU memory usage across batch sizes (medium-length texts).

        Args:
            batch_sizes: Batch sizes to test; defaults to [1, 8, ..., 256].

        Returns:
            Dict keyed by ``batch_<n>`` with memory metrics, or ``oom``/``error``
            markers when a batch fails.
        """
        if batch_sizes is None:
            batch_sizes = [1, 8, 16, 32, 64, 128, 256]
        logger.info("πŸ’Ύ Benchmarking memory scaling...")

        if self.model is None:
            self.load_model()

        if self.model is None:
            msg = "Failed to load model"
            raise RuntimeError(msg)

        memory_results: dict[str, Any] = {}
        texts = BENCHMARK_TEXTS["medium"]

        # Baseline: allocation with the model loaded but nothing in flight.
        baseline_memory = 0
        if self.device == "cuda":
            torch.cuda.empty_cache()
            baseline_memory = torch.cuda.memory_allocated()

        for batch_size in batch_sizes:
            if batch_size > len(texts):
                continue

            logger.info(f"  πŸ“Š Testing batch size: {batch_size}")

            # Clear cache and reset the peak counter so max_memory_allocated()
            # below reflects only the upcoming encode (fix: the original never
            # reset before the first batch, so its "peak" included load spikes).
            if self.device == "cuda":
                torch.cuda.empty_cache()
                torch.cuda.reset_peak_memory_stats()
                gc.collect()

            batch_texts = texts[:batch_size]

            # Run inference
            try:
                embeddings = self.model.encode(
                    batch_texts,
                    convert_to_tensor=self.device == "cuda",
                    show_progress_bar=False,
                )

                # Peak allocation reached during this batch's encode
                memory_after = 0
                if self.device == "cuda":
                    memory_after = torch.cuda.max_memory_allocated()

                memory_used_mb = (memory_after - baseline_memory) / (1024 * 1024)
                memory_per_text_mb = memory_used_mb / batch_size if batch_size > 0 else 0

                memory_results[f"batch_{batch_size}"] = {
                    "memory_used_mb": memory_used_mb,
                    "memory_per_text_mb": memory_per_text_mb,
                    "baseline_memory_mb": baseline_memory / (1024 * 1024),
                    "peak_memory_mb": memory_after / (1024 * 1024),
                }

                # Free this batch's output before the next (larger) run
                del embeddings

            except torch.cuda.OutOfMemoryError:
                # Larger batches would also OOM, so stop scaling here.
                logger.warning(f"❌ OOM at batch size {batch_size}")
                memory_results[f"batch_{batch_size}"] = {"oom": True}
                break
            except Exception as e:
                logger.warning(f"❌ Error at batch size {batch_size}: {e}")
                memory_results[f"batch_{batch_size}"] = {"error": str(e)}

        self.results["memory_benchmarks"] = memory_results
        return memory_results

    def benchmark_cpu_vs_gpu(self) -> dict[str, Any]:
        """Compare throughput on CPU vs GPU using a fixed 32-text medium batch.

        Returns:
            Dict keyed by device name with timing metrics or an ``error`` entry.
        """
        logger.info("πŸ–₯️ Benchmarking CPU vs GPU performance...")

        comparison_results = {}
        test_texts = BENCHMARK_TEXTS["medium"][:32]  # Fixed batch size

        devices = ["cpu"]
        if torch.cuda.is_available():
            devices.append("cuda")

        for device in devices:
            logger.info(f"  πŸ”„ Testing on {device}")

            # Load a fresh model on the target device. trust_remote_code=True
            # matches load_model (fix: original omitted it here, so remote-code
            # models failed only in this comparison).
            try:
                model = SentenceTransformer(self.model_path, device=device, trust_remote_code=True)

                # Warmup
                _ = model.encode(test_texts[:2], convert_to_tensor=False)

                # Benchmark
                start_time = time.perf_counter()
                embeddings = model.encode(test_texts, convert_to_tensor=False, show_progress_bar=False)
                end_time = time.perf_counter()

                total_time = end_time - start_time

                comparison_results[device] = {
                    "total_time_ms": total_time * 1000,
                    "texts_per_second": len(test_texts) / total_time,
                    "time_per_text_ms": (total_time / len(test_texts)) * 1000,
                    "embedding_shape": embeddings.shape
                    if hasattr(embeddings, "shape")
                    else f"({len(embeddings)}, {len(embeddings[0]) if embeddings else 0})",
                }

                del model
                if device == "cuda":
                    torch.cuda.empty_cache()

            except Exception as e:
                logger.warning(f"❌ Failed on {device}: {e}")
                comparison_results[device] = {"error": str(e)}

        self.results["cpu_vs_gpu"] = comparison_results
        return comparison_results

    def run_comprehensive_benchmark(self) -> dict[str, Any]:
        """Run all benchmarks and attach model/system metadata to results.

        Returns:
            The accumulated ``results`` dict.
        """
        logger.info(f"πŸš€ Starting comprehensive benchmark for {self.model_name}")

        # Load model
        self.load_model()

        # Run all benchmarks
        self.measure_model_size()
        self.benchmark_inference_speed()
        self.benchmark_memory_scaling()
        self.benchmark_cpu_vs_gpu()

        # Metadata for later comparison/reporting
        self.results["model_name"] = self.model_name
        self.results["model_path"] = self.model_path
        self.results["device"] = self.device
        self.results["torch_version"] = torch.__version__
        self.results["cuda_available"] = torch.cuda.is_available()

        if torch.cuda.is_available():
            self.results["gpu_name"] = torch.cuda.get_device_name(0)
            self.results["gpu_memory_gb"] = torch.cuda.get_device_properties(0).total_memory / (1024**3)

        # System info
        self.results["cpu_count"] = psutil.cpu_count()
        self.results["ram_gb"] = psutil.virtual_memory().total / (1024**3)

        logger.info("βœ… Comprehensive benchmark completed!")
        return self.results

    def save_results(self, output_file: str) -> None:
        """Save benchmark results to a JSON file, creating parent directories.

        Args:
            output_file: Destination path for the JSON dump.
        """
        output_path = Path(output_file)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # default=str handles non-JSON types (e.g. tensor shapes) by stringifying
        with output_path.open("w") as f:
            json.dump(self.results, f, indent=2, default=str)

        logger.info(f"πŸ“„ Results saved to {output_path}")

    def print_summary(self) -> None:
        """Print a human-readable summary of collected benchmark results."""
        if not self.results:
            logger.warning("No results to summarize")
            return

        print(f"\n{'=' * 60}")
        print(f"Performance Benchmark Summary: {self.model_name}")
        print(f"{'=' * 60}")

        # Model size
        if "size_metrics" in self.results:
            size = self.results["size_metrics"]
            print("\nπŸ“ Model Size:")
            print(f"  Disk Size: {size.get('disk_size_mb', 0):.1f} MB")
            if "parameters_millions" in size:
                print(f"  Parameters: {size['parameters_millions']:.1f}M")
            if "embedding_dim" in size:
                print(f"  Embedding Dim: {size['embedding_dim']}")

        # Speed summary
        if "speed_benchmarks" in self.results:
            speed = self.results["speed_benchmarks"]
            print("\n⚑ Speed (medium texts, batch 32):")
            if "medium" in speed and "batch_32" in speed["medium"]:
                batch_32 = speed["medium"]["batch_32"]
                print(f"  Throughput: {batch_32['texts_per_second']:.1f} texts/sec")
                print(f"  Latency: {batch_32['time_per_text_ms']:.1f} ms/text")
                print(f"  Token Speed: {batch_32['tokens_per_second']:.0f} tokens/sec")

        # CPU vs GPU
        if "cpu_vs_gpu" in self.results:
            comparison = self.results["cpu_vs_gpu"]
            print("\nπŸ–₯️ CPU vs GPU:")
            for device, metrics in comparison.items():
                if "error" not in metrics:
                    print(f"  {device.upper()}: {metrics['texts_per_second']:.1f} texts/sec")

        print()
578
+
579
+
580
def run_benchmark(
    model_path: str | list[str],
    model_name: str | None = None,
    output: str = "benchmark_results.json",
    quick: bool = False,
    compare_models: list[str] | None = None,
) -> None:
    """Run benchmarks for one or more models and write per-model + comparison JSON.

    Args:
        model_path: A single model path/id, or a list of them.
        model_name: Optional display name for the first model; other models
            use their path's final component.
        output: Output file or directory; a path with a suffix is treated as a
            file (its parent becomes the results directory), otherwise as a
            directory.
        quick: If True, run only the fast subset (size + small-batch speed).
        compare_models: Extra models to benchmark alongside for comparison.
    """
    # Copy the caller's list: the original aliased it and then extend() below
    # mutated the caller's argument in place (fix).
    models_to_benchmark = [model_path] if isinstance(model_path, str) else list(model_path)

    if compare_models:
        models_to_benchmark.extend(compare_models)

    all_results = []

    for i, model in enumerate(models_to_benchmark):
        # Fall back to the path's final component when no explicit name was
        # given (fix: the original printed "None" for the first model then).
        current_model_name = (model_name if i == 0 else None) or Path(model).name

        print(f"\n{'=' * 60}")
        print(f"Benchmarking Model {i + 1}/{len(models_to_benchmark)}: {current_model_name}")
        print(f"{'=' * 60}")

        try:
            benchmarker = PerformanceBenchmark(model, current_model_name)

            if quick:
                # Quick benchmark: size metrics plus a few small batch sizes
                benchmarker.load_model()
                benchmarker.measure_model_size()
                benchmarker.benchmark_inference_speed([1, 16, 32])
            else:
                # Comprehensive benchmark
                benchmarker.run_comprehensive_benchmark()

            all_results.append(benchmarker.results)
            benchmarker.print_summary()

        except Exception:
            logger.exception(f"❌ Failed to benchmark {current_model_name}")
            continue

    # Treat `output` as a file path when it has a suffix, else as a directory.
    output_dir = Path(output).parent if Path(output).suffix else Path(output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save individual results with filesystem-safe names.
    for results in all_results:
        model_name_safe = "".join(c for c in results["model_name"] if c.isalnum() or c in ("-", "_", "."))
        output_path = output_dir / f"benchmark_{model_name_safe}.json"

        with output_path.open("w") as f:
            json.dump(results, f, indent=2, default=str)

        logger.info(f"πŸ“„ Results saved to {output_path}")

    # Create comparison report only when there is something to compare.
    if len(all_results) > 1:
        create_benchmark_comparison(all_results, str(output_dir / "benchmark_comparison.json"))

    print(f"\nβœ… Benchmark complete! Results saved to {output_dir}")
640
+
641
+
642
def create_benchmark_comparison(all_results: list[dict[str, Any]], output_path: str) -> None:
    """Build, print, and persist a cross-model comparison of benchmark results.

    Args:
        all_results: One result dict per model, as produced by
            ``PerformanceBenchmark``.
        output_path: Destination JSON file for the comparison summary.
    """
    banner = "=" * 80
    print(f"\n{banner}")
    print("Performance Benchmark Comparison")
    print(f"{banner}")

    rows: list[dict[str, Any]] = []

    for result in all_results:
        sizes = result.get("size_metrics", {})
        speeds = result.get("speed_benchmarks", {})
        device_runs = result.get("cpu_vs_gpu", {})

        # Core size columns are always present.
        row: dict[str, Any] = {
            "Model": result.get("model_name", "Unknown"),
            "Disk Size (MB)": sizes.get("disk_size_mb", 0),
            "Parameters (M)": sizes.get("parameters_millions", 0),
            "Embedding Dim": sizes.get("embedding_dim", 0),
        }

        # Headline speed columns: medium-length texts at batch size 32.
        batch_metrics = speeds.get("medium", {}).get("batch_32")
        if batch_metrics is not None:
            row["Throughput (texts/sec)"] = batch_metrics.get("texts_per_second", 0)
            row["Latency (ms/text)"] = batch_metrics.get("time_per_text_ms", 0)
            row["Token Speed (tokens/sec)"] = batch_metrics.get("tokens_per_second", 0)

        # Per-device throughput, skipping failed runs.
        for device in ("cpu", "cuda"):
            metrics = device_runs.get(device)
            if metrics is not None and "error" not in metrics:
                row[f"{device.upper()} Speed (texts/sec)"] = metrics.get("texts_per_second", 0)

        rows.append(row)

    table = pd.DataFrame(rows)

    # Fastest models first, when throughput was measured.
    if "Throughput (texts/sec)" in table.columns:
        table = table.sort_values("Throughput (texts/sec)", ascending=False)

    print(table.to_string(index=False, float_format="%.2f"))

    has_rows = len(table) > 0
    comparison_summary = {
        "comparison_table": table.to_dict(orient="records"),
        "summary": {
            "fastest_model": table.iloc[0]["Model"] if has_rows else None,
            "smallest_model": table.loc[table["Disk Size (MB)"].idxmin()]["Model"] if has_rows else None,
            "most_efficient": (
                table.loc[table["Throughput (texts/sec)"].idxmax()]["Model"]
                if "Throughput (texts/sec)" in table.columns and has_rows
                else None
            ),
        },
        "timestamp": time.time(),
    }

    with Path(output_path).open("w") as f:
        json.dump(comparison_summary, f, indent=2, default=str)

    print(f"\nπŸ“Š Comparison saved to {output_path}")
709
+
710
+
711
def save_benchmark_results(
    results: dict[str, Any],
    output_dir: str,
    model_name: str,
    volume_results_dir: Path | None = None,
) -> None:
    """Persist benchmark results as JSON to the Beam volume (if given) and locally.

    Args:
        results: Benchmark results to serialize.
        output_dir: Local directory for the backup copy (created if missing).
        model_name: Model identifier used to build output filenames.
        volume_results_dir: Optional Beam-volume directory for the primary copy.
    """
    # Primary copy on the Beam volume; failures here are non-fatal.
    if volume_results_dir:
        volume_output_path = volume_results_dir / f"benchmark_{model_name}.json"
        try:
            with volume_output_path.open("w") as f:
                json.dump(results, f, indent=2, default=str)
            logger.info(f"πŸ’Ύ Results saved to Beam volume: {volume_output_path}")
        except Exception as e:
            logger.warning(f"⚠️ Failed to save to Beam volume: {e}")

    # Local backup is always written, with a filesystem-safe filename.
    local_dir = Path(output_dir)
    local_dir.mkdir(parents=True, exist_ok=True)

    safe_name = "".join(c for c in model_name if c.isalnum() or c in ("-", "_", "."))
    filepath = local_dir / f"benchmark_{safe_name}.json"

    with filepath.open("w") as f:
        json.dump(results, f, indent=2, default=str)

    logger.info(f"πŸ“„ Local backup saved to {filepath}")
741
+
742
+
743
+ def beam_benchmark_models(
744
+ models: list[str],
745
+ quick: bool = False,
746
+ output_dir: str = DEFAULT_OUTPUT_DIR,
747
+ volume_name: str = VOLUME_NAME,
748
+ mount_path: str = VOLUME_PATH,
749
+ ) -> list[dict[str, Any]]:
750
+ """Main benchmarking function for Beam execution with checkpoint support."""
751
+ logger.info("πŸš€ Starting Beam-powered performance benchmarking")
752
+ logger.info(f"πŸ“Š Benchmarking {len(models)} models")
753
+
754
+ # Initialize Beam utilities
755
+ volume_mgr, checkpoint_mgr, model_mgr, eval_mgr = create_beam_utilities(volume_name, mount_path)
756
+
757
+ # Create benchmark results directory in volume
758
+ results_dir = Path(mount_path) / BENCHMARK_RESULTS_DIR
759
+ results_dir.mkdir(parents=True, exist_ok=True)
760
+
761
+ logger.info(f"πŸ“ Using Beam volume: {volume_name} at {mount_path}")
762
+ logger.info(f"πŸ’Ύ Benchmark results directory: {results_dir}")
763
+
764
+ all_results = []
765
+ skipped_models = []
766
+
767
+ for model_path in models:
768
+ model_name = Path(model_path).name if model_path != str(Path(mount_path)) else "gte_qwen2_m2v_code"
769
+
770
+ # Check if this model has already been benchmarked (except for trained model)
771
+ is_trained_model = model_path == str(Path(mount_path)) or model_name == "gte_qwen2_m2v_code"
772
+
773
+ if not is_trained_model:
774
+ # Check for existing benchmark results
775
+ existing_result_file = results_dir / f"benchmark_{model_name}.json"
776
+ if existing_result_file.exists():
777
+ logger.info(f"βœ… Model {model_name} already benchmarked - loading existing results")
778
+ try:
779
+ with existing_result_file.open("r") as f:
780
+ existing_results = json.load(f)
781
+ all_results.append(existing_results)
782
+ skipped_models.append(model_name)
783
+ continue
784
+ except Exception as e:
785
+ logger.warning(f"⚠️ Failed to load existing results for {model_name}: {e}")
786
+ # Continue with benchmarking if loading fails
787
+
788
+ logger.info(f"\n{'=' * 60}")
789
+ logger.info(f"πŸ” Benchmarking model: {model_name}")
790
+ logger.info(f"πŸ“‚ Path: {model_path}")
791
+ if is_trained_model:
792
+ logger.info("🎯 Trained model - always re-benchmark")
793
+ logger.info(f"{'=' * 60}")
794
+
795
+ try:
796
+ # Distinguish between local paths and HuggingFace model names
797
+ is_huggingface_model = (
798
+ "/" in model_path and not model_path.startswith("/") and not Path(model_path).exists()
799
+ )
800
+
801
+ if is_huggingface_model:
802
+ # This is a HuggingFace model name - pass directly to benchmarker
803
+ logger.info(f"πŸ“₯ Loading HuggingFace model: {model_path}")
804
+ benchmarker = PerformanceBenchmark(
805
+ model_path,
806
+ model_name,
807
+ checkpoint_manager=checkpoint_mgr,
808
+ eval_manager=eval_mgr,
809
+ )
810
+ else:
811
+ # This is a local path - check if it exists in Beam volume
812
+ actual_model_path = model_path # Default to original path
813
+ if not Path(model_path).exists() and not model_path.startswith("/"):
814
+ # Try to load from Beam volume
815
+ local_model_path = Path(mount_path) / model_name
816
+ logger.info(f"πŸ” Trying to load {model_name} from Beam volume: {local_model_path}")
817
+ if local_model_path.exists():
818
+ actual_model_path = str(local_model_path)
819
+ logger.info(f"βœ… Found model in Beam volume: {actual_model_path}")
820
+ else:
821
+ # Try in root of volume (for your trained model)
822
+ root_model_path = Path(mount_path)
823
+ if (root_model_path / "config.json").exists():
824
+ actual_model_path = str(root_model_path)
825
+ logger.info(f"βœ… Found model in Beam volume root: {actual_model_path}")
826
+ else:
827
+ logger.warning(f"⚠️ Model not found locally or in Beam volume: {model_name}")
828
+ continue
829
+
830
+ benchmarker = PerformanceBenchmark(
831
+ actual_model_path,
832
+ model_name,
833
+ checkpoint_manager=checkpoint_mgr,
834
+ eval_manager=eval_mgr,
835
+ )
836
+
837
+ # Run benchmarking
838
+ if quick:
839
+ # Quick benchmark
840
+ benchmarker.load_model()
841
+ benchmarker.measure_model_size()
842
+ benchmarker.benchmark_inference_speed([1, 16, 32])
843
+ else:
844
+ # Comprehensive benchmark
845
+ benchmarker.run_comprehensive_benchmark()
846
+
847
+ # Save results with Beam support
848
+ save_benchmark_results(benchmarker.results, output_dir, model_name, results_dir)
849
+
850
+ # Print summary
851
+ benchmarker.print_summary()
852
+
853
+ all_results.append(benchmarker.results)
854
+
855
+ except Exception:
856
+ logger.exception(f"❌ Failed to benchmark {model_name}")
857
+ continue
858
+
859
+ # Create comparison report in Beam volume
860
+ if len(all_results) > 1:
861
+ comparison_dir = results_dir / "comparisons"
862
+ comparison_dir.mkdir(parents=True, exist_ok=True)
863
+ create_benchmark_comparison(all_results, str(comparison_dir / "benchmark_comparison.json"))
864
+ logger.info(f"πŸ“Š Comparison report saved to Beam volume: {comparison_dir}")
865
+
866
+ # Log summary of what was done
867
+ newly_benchmarked = len(all_results) - len(skipped_models)
868
+ logger.info("\nβœ… Beam benchmarking complete!")
869
+ logger.info(f"πŸ“Š Newly benchmarked: {newly_benchmarked} models")
870
+ logger.info(f"⏭️ Skipped (already done): {len(skipped_models)} models")
871
+ logger.info(f"πŸ“ Total results: {len(all_results)} models")
872
+ logger.info(f"πŸ’Ύ Results available in Beam volume: {volume_name}")
873
+
874
+ if skipped_models:
875
+ logger.info(f"⏭️ Skipped models: {', '.join(skipped_models)}")
876
+
877
+ return all_results
878
+
879
+
880
@function(
    gpu=GPU_NAME,
    volumes=[Volume(name=VOLUME_NAME, mount_path=VOLUME_PATH)],
    image=IMAGE,
    secrets=["HF_ACCESS_TOKEN"],
    env={
        "TOKENIZERS_PARALLELISM": "false",
        "CUDA_LAUNCH_BLOCKING": "0",
    },
    timeout=3600 * 4,  # 4 hours for benchmarking all models
)
def main() -> None:
    """Main benchmarking function - runs all default models on Beam.

    Builds the model list from DEFAULT_BENCHMARK_MODELS (substituting the Beam
    volume root for the "gte_qwen2_m2v_code" placeholder), appends any
    discovered simplified distillation models, then delegates to
    beam_benchmark_models() in quick mode. Results are written to
    BENCHMARK_RESULTS_DIR on the Beam volume.
    """
    logger.info("πŸš€ Starting comprehensive performance benchmarking on Beam")

    # Use default models but replace the local model path with Beam volume path.
    # Copy so the module-level default list is never mutated.
    models = DEFAULT_BENCHMARK_MODELS.copy()

    # Replace "gte_qwen2_m2v_code" with actual Beam volume path
    for i, model in enumerate(models):
        if model == "gte_qwen2_m2v_code":
            models[i] = str(Path(VOLUME_PATH))  # Use the Beam volume root
            logger.info(f"🎯 Using trained model from Beam volume: {models[i]}")

    # Discover simplified distillation models.
    # NOTE(review): searches relative to the container's working directory —
    # confirm the simplified models are shipped into the Beam image/volume.
    logger.info("πŸ” Discovering simplified distillation models...")
    discovered_models = discover_simplified_models(".")

    # Add discovered models
    if discovered_models:
        logger.info(f"βœ… Found {len(discovered_models)} simplified models:")
        for model_path in discovered_models:
            models.append(model_path)
            logger.info(f"  πŸ“ {model_path}")
    else:
        logger.warning("⚠️ No simplified distillation models found")

    logger.info(f"πŸ“Š Benchmarking {len(models)} models:")
    for i, model in enumerate(models, 1):
        logger.info(f"  {i}. {model}")

    logger.info("\nπŸ’‘ Checkpoint Info:")
    logger.info("  - Already benchmarked models will be skipped")
    logger.info("  - Your trained model will always be re-benchmarked")
    logger.info("  - Results are saved persistently to Beam volume")

    # Run comprehensive benchmark using Beam utilities
    results = beam_benchmark_models(
        models=models,
        quick=True,  # Use quick benchmark for efficiency
        output_dir=str(Path(VOLUME_PATH) / BENCHMARK_RESULTS_DIR),
        volume_name=VOLUME_NAME,
        mount_path=VOLUME_PATH,
    )

    # Print final summary (plain print: user-facing CLI output, not log noise)
    print("\n🎯 Benchmarking Summary:")
    print(f"πŸ“Š Total models processed: {len(results)}")
    print(f"πŸ’Ύ Results saved to Beam volume: {VOLUME_NAME}")
    print(f"πŸ“ Directory: {BENCHMARK_RESULTS_DIR}")
    print("\nπŸ” To view analysis:")
    print("  beam run src.distiller.analyze:beam_analysis")
    print("\nπŸ“ˆ To run benchmarks again:")
    print("  distiller benchmark (will skip already completed models)")
946
def discover_simplified_models(base_path: str = ".") -> list[str]:
    """
    Discover all simplified distillation models in the correct directory.

    Looks for directories matching the pattern: ./code_model2vec/final/code_model2vec_*

    Args:
        base_path: Root directory to search under.

    Returns:
        Alphabetically sorted list of model directory paths (as strings).
        Empty when the search root is missing or holds no valid models.
    """
    search_root = Path(base_path) / "code_model2vec" / "final"

    if not search_root.exists():
        logger.warning(f"Models directory not found: {search_root}")
        return []

    # A candidate counts as a model only if it is a directory containing a
    # config.json (the marker written by the distillation step).
    found: list[str] = []
    for candidate in search_root.glob("code_model2vec_*"):
        if candidate.is_dir() and (candidate / "config.json").exists():
            found.append(str(candidate))
            logger.info(f"πŸ” Discovered simplified model: {candidate}")

    # Sort alphabetically for consistent ordering across runs.
    return sorted(found)
972
+
973
+
974
@function(
    gpu=GPU_NAME,
    volumes=[Volume(name=VOLUME_NAME, mount_path=VOLUME_PATH)],
    image=IMAGE,
    secrets=["HF_ACCESS_TOKEN"],
    env={
        "TOKENIZERS_PARALLELISM": "false",
        "CUDA_LAUNCH_BLOCKING": "0",
    },
    timeout=3600 * 3,  # 3 hours for simplified models only
)
def benchmark_simplified_only() -> None:
    """Benchmark only simplified distillation models, skipping 3rd party models.

    Discovers models under ./code_model2vec/final and delegates to
    beam_benchmark_models() in quick mode. Returns early (with an error log)
    when no simplified models exist yet.
    """
    logger.info("πŸš€ Starting simplified distillation models benchmarking on Beam")
    logger.info("⏭️ Skipping 3rd party models - benchmarking only simplified distillation models")

    # Discover simplified distillation models.
    # NOTE(review): searches the container's working directory — confirm the
    # models are present in the Beam image/volume at runtime.
    logger.info("πŸ” Discovering simplified distillation models...")
    discovered_models = discover_simplified_models(".")

    if not discovered_models:
        logger.error("❌ No simplified distillation models found! Run distill-simple first.")
        return

    logger.info(f"βœ… Found {len(discovered_models)} simplified models:")
    for model_path in discovered_models:
        logger.info(f"  πŸ“ {model_path}")

    logger.info("\nπŸ’‘ Checkpoint Info:")
    logger.info("  - Already benchmarked models will be skipped")
    logger.info("  - Results are saved persistently to Beam volume")

    # Run comprehensive benchmark using Beam utilities
    results = beam_benchmark_models(
        models=discovered_models,
        quick=True,  # Use quick benchmark for efficiency
        output_dir=str(Path(VOLUME_PATH) / BENCHMARK_RESULTS_DIR),
        volume_name=VOLUME_NAME,
        mount_path=VOLUME_PATH,
    )

    # Print final summary (plain print: user-facing CLI output)
    print("\n🎯 Simplified Benchmarking Summary:")
    print(f"πŸ“Š Total simplified models processed: {len(results)}")
    print(f"πŸ’Ύ Results saved to Beam volume: {VOLUME_NAME}")
    print(f"πŸ“ Directory: {BENCHMARK_RESULTS_DIR}")
    print("⏭️ 3rd party models were skipped")
    print("\nπŸ” To view analysis:")
    print("  distiller analyze")
    print("\nπŸ“ˆ To run full benchmarks (including 3rd party):")
    print("  distiller benchmark")
1025
+
1026
+
1027
def run_local_benchmark(
    models: list[str] | None = None,
    quick: bool = False,
    output_dir: str = DEFAULT_OUTPUT_DIR,
) -> list[dict[str, Any]]:
    """
    Main benchmarking function for local execution without Beam utilities.

    Args:
        models: Model paths/names to benchmark; defaults to
            DEFAULT_BENCHMARK_MODELS. The caller's list is never mutated.
        quick: If True, run only the fast subset (load, size, inference speed).
        output_dir: Local directory for per-model result JSON files.

    Returns:
        List of result dictionaries (newly benchmarked plus previously cached).
    """
    logger.info("πŸ–₯️ Running performance benchmarking locally")

    # Work on a copy so the caller's list is not mutated, and resolve the
    # trained-model placeholder into a NEW list. (The previous version called
    # models.pop(i) inside `for i, model in enumerate(models)`, which mutates
    # the list being iterated and silently skips the element after the
    # removed one.)
    candidates = DEFAULT_BENCHMARK_MODELS.copy() if models is None else list(models)

    resolved_models: list[str] = []
    for model in candidates:
        if model == "gte_qwen2_m2v_code":
            # Look for the locally trained model in the usual output locations.
            local_model_paths = [
                "./gte_qwen2_m2v_code",
                "./models/gte_qwen2_m2v_code",
                "./output/gte_qwen2_m2v_code",
            ]
            for local_path in local_model_paths:
                if Path(local_path).exists():
                    logger.info(f"🎯 Found local trained model: {local_path}")
                    resolved_models.append(local_path)
                    break
            else:
                logger.warning("⚠️ Local trained model not found, skipping")
        else:
            resolved_models.append(model)
    models = resolved_models

    # Discover simplified distillation models
    logger.info("πŸ” Discovering simplified distillation models...")
    discovered_models = discover_simplified_models(".")

    # Add discovered models
    if discovered_models:
        logger.info(f"βœ… Found {len(discovered_models)} simplified models:")
        for model_path in discovered_models:
            models.append(model_path)
            logger.info(f"  πŸ“ {model_path}")
    else:
        logger.warning("⚠️ No simplified distillation models found")

    logger.info(f"πŸ“Š Benchmarking {len(models)} models")
    logger.info(f"πŸ“ Using local output directory: {output_dir}")

    # Create local output directory
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    all_results = []
    skipped_models = []

    for model_path in models:
        model_name = Path(model_path).name

        # Checkpoint-style resume: skip models that already have a local
        # result file. safe_name mirrors save_benchmark_results' sanitizing.
        safe_name = "".join(c for c in model_name if c.isalnum() or c in ("-", "_", "."))
        result_file = output_path / f"benchmark_{safe_name}.json"

        if result_file.exists():
            logger.info(f"βœ… Model {model_name} already benchmarked - loading existing results")
            try:
                with result_file.open("r") as f:
                    existing_results = json.load(f)
                all_results.append(existing_results)
                skipped_models.append(model_name)
                continue
            except Exception as e:
                # Fall through and re-benchmark when the cached file is unreadable.
                logger.warning(f"⚠️ Failed to load existing results for {model_name}: {e}")

        logger.info(f"\n{'=' * 60}")
        logger.info(f"πŸ” Benchmarking model: {model_name}")
        logger.info(f"πŸ“‚ Path: {model_path}")
        logger.info(f"{'=' * 60}")

        try:
            # Create benchmarker without Beam utilities
            benchmarker = PerformanceBenchmark(
                model_path,
                model_name,
                checkpoint_manager=None,  # No checkpointing for local benchmarking
                eval_manager=None,
            )

            # Run benchmarking
            if quick:
                # Quick benchmark: model load, size, and inference speed only
                benchmarker.load_model()
                benchmarker.measure_model_size()
                benchmarker.benchmark_inference_speed([1, 16, 32])
            else:
                # Comprehensive benchmark
                benchmarker.run_comprehensive_benchmark()

            # Save results locally only (no Beam volume)
            save_benchmark_results(benchmarker.results, output_dir, model_name, volume_results_dir=None)

            # Print summary
            benchmarker.print_summary()

            all_results.append(benchmarker.results)

        except Exception:
            # Best-effort: one broken model must not abort the whole run.
            logger.exception(f"❌ Failed to benchmark {model_name}")
            continue

    # Create comparison report locally
    if len(all_results) > 1:
        create_benchmark_comparison(all_results, str(output_path / "benchmark_comparison.json"))
        logger.info(f"πŸ“Š Comparison report saved locally: {output_dir}")

    # Log summary
    newly_benchmarked = len(all_results) - len(skipped_models)
    logger.info("\nβœ… Local benchmarking complete!")
    logger.info(f"πŸ“Š Newly benchmarked: {newly_benchmarked} models")
    logger.info(f"⏭️ Skipped (already done): {len(skipped_models)} models")
    logger.info(f"πŸ“ Total results: {len(all_results)} models")
    logger.info(f"πŸ’Ύ Results available locally: {output_dir}")

    if skipped_models:
        logger.info(f"⏭️ Skipped models: {', '.join(skipped_models)}")

    return all_results
1152
+
1153
+
1154
def run_local_benchmark_simplified(
    quick: bool = False,
    output_dir: str = DEFAULT_OUTPUT_DIR,
) -> list[dict[str, Any]]:
    """Local benchmarking function for simplified models only.

    Discovers simplified distillation models and forwards them to
    run_local_benchmark(); returns an empty list when none exist.
    """
    logger.info("πŸ–₯️ Running simplified model benchmarking locally")

    # Only simplified distillation models are considered here.
    logger.info("πŸ” Discovering simplified distillation models...")
    simplified = discover_simplified_models(".")

    if not simplified:
        logger.error("❌ No simplified distillation models found! Run 'distiller distiller distill-simple' first.".replace("distiller distiller", "distiller"))
        return []

    logger.info(f"βœ… Found {len(simplified)} simplified models:")
    for model_dir in simplified:
        logger.info(f"  πŸ“ {model_dir}")

    return run_local_benchmark(
        models=simplified,
        quick=quick,
        output_dir=output_dir,
    )
1178
+
1179
+
1180
+ if __name__ == "__main__":
1181
+ main()
src/distiller/distill.py ADDED
@@ -0,0 +1,1306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Code-Specialized Model2Vec Distillation Script with Checkpoint Support.
3
+
4
+ This script implements a focused approach for creating code-specialized embeddings
5
+ using Model2Vec distillation with one additional training round on code-specific tasks.
6
+
7
+ Features:
8
+ - Incremental checkpoint saving
9
+ - Resume from previous progress
10
+ - Persistent storage of embeddings and models
11
+ - Robust error handling and recovery
12
+ - Smart checkpoint validation for parameter compatibility
13
+
14
+ Approach:
15
+ 1. Basic Model2Vec distillation with optimized parameters
16
+ 2. Single code specialization round using sentence-transformers/codesearchnet dataset
17
+ """
18
+
19
+ import json
20
+ import logging
21
+ import os
22
+ import time
23
+ from pathlib import Path
24
+ from typing import Any
25
+
26
+ import numpy as np
27
+ import torch
28
+ from beam import GpuType, Image, Volume, function
29
+ from datasets import load_dataset
30
+ from model2vec.distill import distill
31
+ from model2vec.train.base import FinetunableStaticModel, TextDataset
32
+ from sentence_transformers import SentenceTransformer
33
+ from sklearn.model_selection import train_test_split
34
+ from torch import nn, optim
35
+
36
+ from .beam_utils import (
37
+ BeamCheckpointManager,
38
+ BeamModelManager,
39
+ create_beam_utilities,
40
+ )
41
+
42
+ # =============================================================================
43
+ # CODE-FOCUSED CONFIGURATION
44
+ # =============================================================================
45
+
46
+ # Model Configuration
47
+ MODEL_NAME = "Alibaba-NLP/gte-Qwen2-7B-instruct"
48
+ OUTPUT_DIR = "gte_qwen2_m2v_code"
49
+ CHECKPOINT_DIR = "gte_qwen2_m2v_code/checkpoints"
50
+
51
+ # Code-optimized parameters
52
+ PCA_DIMS = 512 # Higher dims for code complexity
53
+ TRAINING_EPOCHS = 2
54
+ LEARNING_RATE = 1e-4
55
+ BATCH_SIZE = 32
56
+ REGULARIZATION_WEIGHT = 0.01
57
+
58
+ # CodeSearchNet dataset configuration
59
+ CODESEARCHNET_DATASET = "sentence-transformers/codesearchnet"
60
+ MAX_TRAINING_SAMPLES = 50000 # Limit for manageable training time
61
+
62
+ # Checkpoint configuration
63
+ CHECKPOINT_INTERVAL = 1000 # Save every N samples
64
+ EMBEDDINGS_BATCH_SIZE = 100 # Save embeddings in smaller batches
65
+
66
+ # OPTIMIZED TEACHER MODEL CONFIGURATION FOR 40GB VRAM
67
+ TEACHER_MODEL_CONFIG: dict[str, Any] = {
68
+ "batch_size": 12, # Slightly reduced due to float32 memory usage
69
+ "precision": "float32", # Use float32 for quality preservation
70
+ "max_seq_length": 8192, # Reduce from 32k default for better performance
71
+ "device_map": "auto", # Automatic device placement
72
+ "torch_dtype": torch.float32, # Use float32 for quality preservation
73
+ "trust_remote_code": True,
74
+ "use_flash_attention": True, # Try to enable flash attention if available
75
+ "attn_implementation": "flash_attention_2", # Use flash attention 2 if available
76
+ }
77
+
78
+ # =============================================================================
79
+ # BEAM CONFIGURATION
80
+ # =============================================================================
81
+
82
+ GPU_NAME = GpuType.A100_40
83
+ VOLUME_NAME = "gte_qwen2_m2v_code"
84
+ VOLUME_PATH = "./gte_qwen2_m2v_code"
85
+ IMAGE = Image(python_version="python3.12").add_python_packages(
86
+ [
87
+ "torch>=2.7.0", # Install torch first
88
+ "transformers>=4.40.0", # Latest transformers with flash attention support
89
+ "accelerate>=1.7.0",
90
+ "datasets>=3.2.0",
91
+ "model2vec[train]>=0.5.0",
92
+ "numpy>=1.26.4",
93
+ "scikit-learn>=1.6.1",
94
+ "sentence-transformers>=4.1.0",
95
+ ]
96
+ )
97
+
98
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
99
+ logger = logging.getLogger(__name__)
100
+
101
+
102
+ def get_current_config_hash() -> str:
103
+ """Generate a hash of current configuration parameters for checkpoint validation."""
104
+ import hashlib
105
+
106
+ config_params = {
107
+ "model_name": MODEL_NAME,
108
+ "pca_dims": PCA_DIMS,
109
+ "precision": TEACHER_MODEL_CONFIG["precision"],
110
+ "torch_dtype": str(TEACHER_MODEL_CONFIG["torch_dtype"]),
111
+ "max_samples": MAX_TRAINING_SAMPLES,
112
+ "codesearchnet_dataset": CODESEARCHNET_DATASET,
113
+ }
114
+
115
+ config_str = str(sorted(config_params.items()))
116
+ return hashlib.md5(config_str.encode()).hexdigest()[:12] # noqa: S324
117
+
118
+
119
+ def validate_checkpoint_compatibility(checkpoint_data: dict[str, Any]) -> bool:
120
+ """
121
+ Validate if checkpoint is compatible with current configuration.
122
+
123
+ Args:
124
+ checkpoint_data: Checkpoint data dictionary
125
+
126
+ Returns:
127
+ True if compatible, False otherwise
128
+ """
129
+ current_hash = get_current_config_hash()
130
+ checkpoint_hash = checkpoint_data.get("config_hash", "")
131
+
132
+ if checkpoint_hash != current_hash:
133
+ logger.warning(f"Configuration mismatch: current={current_hash}, checkpoint={checkpoint_hash}")
134
+ return False
135
+
136
+ # Additional validation checks
137
+ checkpoint_config = checkpoint_data.get("config", {})
138
+
139
+ # Check critical parameters
140
+ if checkpoint_config.get("pca_dims") != PCA_DIMS:
141
+ logger.warning(f"PCA dimensions mismatch: current={PCA_DIMS}, checkpoint={checkpoint_config.get('pca_dims')}")
142
+ return False
143
+
144
+ if checkpoint_config.get("precision") != TEACHER_MODEL_CONFIG["precision"]:
145
+ logger.warning(
146
+ f"Precision mismatch: current={TEACHER_MODEL_CONFIG['precision']}, checkpoint={checkpoint_config.get('precision')}"
147
+ )
148
+ return False
149
+
150
+ if checkpoint_config.get("max_samples") != MAX_TRAINING_SAMPLES:
151
+ logger.warning(
152
+ f"Max samples mismatch: current={MAX_TRAINING_SAMPLES}, checkpoint={checkpoint_config.get('max_samples')}"
153
+ )
154
+ return False
155
+
156
+ logger.info("βœ… Checkpoint configuration is compatible")
157
+ return True
158
+
159
+
160
+ def create_checkpoint_data(stage: str, data: dict[str, Any], step: int = 0) -> dict[str, Any]:
161
+ """
162
+ Create checkpoint data with configuration metadata.
163
+
164
+ Args:
165
+ stage: Checkpoint stage name
166
+ data: Core checkpoint data
167
+ step: Step number
168
+
169
+ Returns:
170
+ Enhanced checkpoint data with configuration
171
+ """
172
+ return {
173
+ "config_hash": get_current_config_hash(),
174
+ "config": {
175
+ "model_name": MODEL_NAME,
176
+ "pca_dims": PCA_DIMS,
177
+ "precision": TEACHER_MODEL_CONFIG["precision"],
178
+ "torch_dtype": str(TEACHER_MODEL_CONFIG["torch_dtype"]),
179
+ "max_samples": MAX_TRAINING_SAMPLES,
180
+ "codesearchnet_dataset": CODESEARCHNET_DATASET,
181
+ },
182
+ "stage": stage,
183
+ "step": step,
184
+ "timestamp": time.time(),
185
+ "data": data,
186
+ }
187
+
188
+
189
+ def load_codesearchnet_dataset_with_resume(
190
+ max_samples: int = MAX_TRAINING_SAMPLES,
191
+ checkpoint_manager: BeamCheckpointManager | None = None,
192
+ ) -> list[str]:
193
+ """Load and format the sentence-transformers/codesearchnet dataset with resume capability."""
194
+ logger.info(f"Loading CodeSearchNet dataset from {CODESEARCHNET_DATASET}")
195
+ logger.info(f"Limiting to {max_samples} samples for training efficiency")
196
+
197
+ # Check for existing dataset checkpoint with validation
198
+ if checkpoint_manager:
199
+ checkpoint_data = checkpoint_manager.load_checkpoint("dataset", 0)
200
+ if checkpoint_data:
201
+ if validate_checkpoint_compatibility(checkpoint_data):
202
+ texts = checkpoint_data.get("data", {}).get("texts", [])
203
+ if len(texts) >= max_samples:
204
+ logger.info(f"βœ… Resumed dataset loading: {len(texts)} texts from checkpoint")
205
+ return texts[:max_samples]
206
+ logger.info(f"πŸ“‹ Partial dataset found: {len(texts)} texts, continuing from there")
207
+ start_from = len(texts)
208
+ else:
209
+ logger.warning("πŸ”„ Incompatible dataset checkpoint found, starting fresh")
210
+ # Clean up incompatible checkpoint
211
+ checkpoint_manager.cleanup_old_checkpoints("dataset", keep_latest=0)
212
+ texts = []
213
+ start_from = 0
214
+ else:
215
+ texts = []
216
+ start_from = 0
217
+ else:
218
+ texts = []
219
+ start_from = 0
220
+
221
+ try:
222
+ # Load the dataset
223
+ dataset = load_dataset(CODESEARCHNET_DATASET, split="train", streaming=True)
224
+
225
+ # Skip to where we left off
226
+ dataset_iter = iter(dataset)
227
+ for _ in range(start_from):
228
+ try:
229
+ next(dataset_iter)
230
+ except StopIteration:
231
+ break
232
+
233
+ for i, example in enumerate(dataset_iter, start=start_from):
234
+ if len(texts) >= max_samples:
235
+ break
236
+
237
+ comment = example.get("comment", "").strip()
238
+ code = example.get("code", "").strip()
239
+
240
+ if comment and code and len(comment) > 10 and len(code) > 50:
241
+ # Format as comment-code pair for training
242
+ text = f"Comment: {comment}\nCode:\n{code}"
243
+
244
+ # Ensure reasonable length
245
+ if len(text) <= 2048: # Reasonable limit for embedding models
246
+ texts.append(text)
247
+
248
+ # Save checkpoint periodically
249
+ if checkpoint_manager and (i + 1) % CHECKPOINT_INTERVAL == 0:
250
+ checkpoint_data = create_checkpoint_data("dataset", {"texts": texts}, 0)
251
+ checkpoint_manager.save_checkpoint("dataset", checkpoint_data, 0)
252
+ logger.info(f"πŸ’Ύ Saved dataset checkpoint: {len(texts)} texts collected")
253
+
254
+ if (i + 1) % 10000 == 0:
255
+ logger.info(f"Processed {i + 1} examples, collected {len(texts)} valid pairs")
256
+
257
+ # Final checkpoint save
258
+ if checkpoint_manager:
259
+ checkpoint_data = create_checkpoint_data("dataset", {"texts": texts}, 0)
260
+ checkpoint_manager.save_checkpoint("dataset", checkpoint_data, 0)
261
+
262
+ logger.info(f"Successfully loaded {len(texts)} code-comment pairs from CodeSearchNet")
263
+ return texts
264
+
265
+ except Exception:
266
+ logger.exception("Error loading CodeSearchNet dataset")
267
+ return texts # Return what we have so far
268
+
269
+
270
+ def generate_teacher_embeddings_with_checkpoints(
271
+ teacher_model: SentenceTransformer,
272
+ texts: list[str],
273
+ checkpoint_manager: BeamCheckpointManager | None = None,
274
+ ) -> torch.Tensor:
275
+ """Generate teacher embeddings for code training with checkpoint support."""
276
+ logger.info(f"Generating teacher embeddings for {len(texts)} texts...")
277
+
278
+ # Check for existing embeddings checkpoint using torch.save format
279
+ final_embeddings = None
280
+
281
+ if checkpoint_manager:
282
+ # Try to load complete embeddings tensor directly
283
+ embeddings_path = Path(VOLUME_PATH) / "embeddings_cache.pt"
284
+ config_path = Path(VOLUME_PATH) / "embeddings_config.json"
285
+
286
+ if embeddings_path.exists() and config_path.exists():
287
+ try:
288
+ # Load config first to validate compatibility
289
+ with config_path.open("r") as f:
290
+ config_data = json.load(f)
291
+
292
+ # Create a dummy checkpoint data structure for validation
293
+ checkpoint_data = {
294
+ "config_hash": config_data.get("config_hash"),
295
+ "config": config_data.get("config", {}),
296
+ }
297
+
298
+ if validate_checkpoint_compatibility(checkpoint_data):
299
+ # Load the embeddings tensor
300
+ final_embeddings = torch.load(embeddings_path, map_location="cpu")
301
+ num_expected = config_data.get("num_texts", len(texts))
302
+
303
+ if final_embeddings.shape[0] >= num_expected:
304
+ logger.info(
305
+ f"βœ… Loaded complete embeddings from cache ({final_embeddings.shape[0]} embeddings)"
306
+ )
307
+ return final_embeddings[: len(texts)] # Return only the needed amount
308
+ logger.info(
309
+ f"⚠️ Cached embeddings incomplete ({final_embeddings.shape[0]}/{num_expected}), regenerating"
310
+ )
311
+ final_embeddings = None
312
+ else:
313
+ logger.warning("πŸ”„ Incompatible embeddings cache found, regenerating")
314
+ final_embeddings = None
315
+ except Exception as e:
316
+ logger.warning(f"Failed to load embeddings cache: {e}, regenerating...")
317
+ final_embeddings = None
318
+
319
+ # If we have complete embeddings, return them
320
+ if final_embeddings is not None:
321
+ return final_embeddings
322
+
323
+ # Generate embeddings from scratch
324
+ logger.info("Generating fresh teacher embeddings...")
325
+
326
+ # Use optimized batch size for large models with proper type casting
327
+ batch_size_raw = TEACHER_MODEL_CONFIG["batch_size"]
328
+ current_batch_size: int = batch_size_raw if isinstance(batch_size_raw, int) else 16
329
+ logger.info(f"Using optimized batch size: {current_batch_size} for 40GB VRAM (7B model)")
330
+
331
+ embeddings_list = []
332
+
333
+ for i in range(0, len(texts), current_batch_size):
334
+ batch_texts = texts[i : i + current_batch_size]
335
+
336
+ try:
337
+ # Use optimized encoding with convert_to_tensor=True for efficiency
338
+ batch_embeddings = teacher_model.encode(
339
+ batch_texts,
340
+ convert_to_tensor=True,
341
+ batch_size=current_batch_size,
342
+ show_progress_bar=False, # Reduce overhead
343
+ normalize_embeddings=True, # Pre-normalize for efficiency
344
+ )
345
+ embeddings_list.append(batch_embeddings)
346
+
347
+ if i % (current_batch_size * 10) == 0:
348
+ logger.info(f"Generated embeddings for {i + len(batch_texts)}/{len(texts)} texts")
349
+
350
+ except torch.cuda.OutOfMemoryError:
351
+ logger.warning(
352
+ f"GPU OOM with batch size {current_batch_size}, reducing to {max(1, current_batch_size // 2)}"
353
+ )
354
+
355
+ # Clear cache and reduce batch size
356
+ if torch.cuda.is_available():
357
+ torch.cuda.empty_cache()
358
+
359
+ current_batch_size = max(1, current_batch_size // 2)
360
+
361
+ # Retry with smaller batch size
362
+ batch_texts = texts[i : i + current_batch_size]
363
+ batch_embeddings = teacher_model.encode(
364
+ batch_texts,
365
+ convert_to_tensor=True,
366
+ batch_size=current_batch_size,
367
+ show_progress_bar=False,
368
+ normalize_embeddings=True,
369
+ )
370
+ embeddings_list.append(batch_embeddings)
371
+
372
+ logger.info(f"Successfully processed batch with reduced size {current_batch_size}")
373
+
374
+ # Combine all embeddings and force fp32 precision
375
+ teacher_embeddings = torch.cat(embeddings_list, dim=0)
376
+
377
+ # Ensure teacher embeddings are in fp32 for maximum quality
378
+ if teacher_embeddings.dtype != torch.float32:
379
+ logger.info(f"Converting teacher embeddings from {teacher_embeddings.dtype} to fp32")
380
+ teacher_embeddings = teacher_embeddings.to(torch.float32)
381
+
382
+ logger.info(f"Generated {teacher_embeddings.shape[0]} teacher embeddings in {teacher_embeddings.dtype}")
383
+
384
+ # Save embeddings cache using torch.save for future runs
385
+ if checkpoint_manager:
386
+ try:
387
+ embeddings_path = Path(VOLUME_PATH) / "embeddings_cache.pt"
388
+ config_path = Path(VOLUME_PATH) / "embeddings_config.json"
389
+
390
+ # Save embeddings tensor
391
+ torch.save(teacher_embeddings, embeddings_path)
392
+
393
+ # Save configuration
394
+ config_data = {
395
+ "config_hash": get_current_config_hash(),
396
+ "config": {
397
+ "model_name": MODEL_NAME,
398
+ "pca_dims": PCA_DIMS,
399
+ "precision": TEACHER_MODEL_CONFIG["precision"],
400
+ "torch_dtype": str(TEACHER_MODEL_CONFIG["torch_dtype"]),
401
+ "max_samples": MAX_TRAINING_SAMPLES,
402
+ "codesearchnet_dataset": CODESEARCHNET_DATASET,
403
+ },
404
+ "num_texts": len(texts),
405
+ "embedding_shape": list(teacher_embeddings.shape),
406
+ "timestamp": time.time(),
407
+ }
408
+
409
+ with config_path.open("w") as f:
410
+ json.dump(config_data, f, indent=2)
411
+
412
+ logger.info("πŸ’Ύ Saved embeddings cache for future runs")
413
+
414
+ except Exception as e:
415
+ logger.warning(f"Failed to save embeddings cache: {e}")
416
+
417
+ return teacher_embeddings
418
+
419
+
420
+ def refine_with_code_training(
421
+ student_model: Any,
422
+ training_texts: list[str],
423
+ teacher_embeddings: torch.Tensor,
424
+ epochs: int = 2,
425
+ checkpoint_manager: BeamCheckpointManager | None = None,
426
+ model_manager: BeamModelManager | None = None,
427
+ ) -> Any:
428
+ """Refine the student model with code-specific training."""
429
+ logger.info(f"Starting code specialization training for {epochs} epochs...")
430
+
431
+ # Validate input parameters
432
+ if student_model is None:
433
+ logger.error("student_model is None - cannot proceed with code training")
434
+ msg = "student_model cannot be None"
435
+ raise ValueError(msg)
436
+
437
+ if not hasattr(student_model, "embedding"):
438
+ logger.error(f"student_model of type {type(student_model)} does not have 'embedding' attribute")
439
+ msg = f"student_model must have 'embedding' attribute, got {type(student_model)}"
440
+ raise ValueError(msg)
441
+
442
+ logger.info(f"Student model type: {type(student_model)}")
443
+ logger.info(f"Student model embedding shape: {student_model.embedding.shape}")
444
+
445
+ try:
446
+ # Force fp32 precision throughout for maximum quality
447
+ target_dtype = torch.float32
448
+ logger.info("🎯 Enforcing fp32 precision throughout for maximum quality")
449
+
450
+ # Detect student model dtype for logging purposes
451
+ student_dtype = student_model.embedding.dtype
452
+ logger.info(f"Student model original embedding dtype: {student_dtype}")
453
+
454
+ # Force teacher embeddings to fp32 if not already
455
+ if teacher_embeddings.dtype != target_dtype:
456
+ logger.info(f"Converting teacher embeddings from {teacher_embeddings.dtype} to {target_dtype}")
457
+ teacher_embeddings = teacher_embeddings.to(target_dtype)
458
+
459
+ # Get dimensions
460
+ student_embedding_dim = student_model.embedding.shape[1]
461
+ teacher_embedding_dim = teacher_embeddings.shape[1]
462
+
463
+ logger.info(f"Student dims: {student_embedding_dim}, Teacher dims: {teacher_embedding_dim}")
464
+
465
+ # Project teacher embeddings if needed with high-precision PCA
466
+ if teacher_embedding_dim != student_embedding_dim:
467
+ from sklearn.decomposition import PCA
468
+
469
+ logger.info("Performing high-precision PCA projection for quality preservation...")
470
+ pca = PCA(n_components=student_embedding_dim)
471
+
472
+ # Use float64 for PCA computation to maximize precision
473
+ teacher_embeddings_np = teacher_embeddings.cpu().numpy().astype(np.float64)
474
+ teacher_embeddings_projected = pca.fit_transform(teacher_embeddings_np)
475
+
476
+ # Convert back to fp32 (always use fp32, never fp16)
477
+ teacher_embeddings = torch.tensor(
478
+ teacher_embeddings_projected.astype(np.float32),
479
+ dtype=target_dtype,
480
+ )
481
+ logger.info(f"PCA projection completed: {teacher_embeddings.shape} with dtype {target_dtype}")
482
+ logger.info(
483
+ f"PCA preserved variance ratio: {pca.explained_variance_ratio_[:5].sum():.4f} (first 5 components)"
484
+ )
485
+
486
+ # Create trainable model
487
+ trainable_model = FinetunableStaticModel.from_static_model(
488
+ model=student_model,
489
+ out_dim=student_embedding_dim,
490
+ )
491
+
492
+ # Force ALL model parameters to fp32 to ensure no precision loss
493
+ trainable_model = trainable_model.float()
494
+
495
+ # Additional explicit conversion of embedding weights to fp32
496
+ if hasattr(trainable_model, "embeddings") and hasattr(trainable_model.embeddings, "weight"):
497
+ trainable_model.embeddings.weight.data = trainable_model.embeddings.weight.data.to(target_dtype)
498
+
499
+ # Verify final model dtype after model2vec patch fix
500
+ actual_model_dtype = None
501
+ for param in trainable_model.parameters():
502
+ actual_model_dtype = param.dtype
503
+ break
504
+
505
+ logger.info(f"Model parameter dtype: {actual_model_dtype}")
506
+ logger.info(f"Embedding weight dtype: {trainable_model.embeddings.weight.dtype}")
507
+
508
+ # Ensure teacher embeddings are definitely in fp32
509
+ teacher_embeddings = teacher_embeddings.to(target_dtype)
510
+ logger.info(f"Final teacher embeddings dtype: {teacher_embeddings.dtype}")
511
+ logger.info(f"Final model parameter dtype: {actual_model_dtype}")
512
+
513
+ # Verify we're using fp32 throughout
514
+ if teacher_embeddings.dtype != target_dtype:
515
+ logger.warning(f"⚠️ Teacher embeddings not in {target_dtype}: {teacher_embeddings.dtype}")
516
+ if actual_model_dtype != target_dtype:
517
+ logger.warning(f"⚠️ Model parameters not in {target_dtype}: {actual_model_dtype}")
518
+
519
+ logger.info("βœ… Confirmed fp32 precision throughout the training pipeline")
520
+
521
+ # Tokenize texts
522
+ tokenized_texts = []
523
+ for text in training_texts:
524
+ tokens = trainable_model.tokenize([text])
525
+ if tokens.shape[1] > 0:
526
+ tokenized_texts.append(tokens[0].tolist())
527
+
528
+ # Prepare training data with explicit fp32 casting
529
+ targets = teacher_embeddings[: len(tokenized_texts)]
530
+
531
+ # Force targets to fp32 to maintain maximum precision
532
+ targets = targets.to(target_dtype)
533
+ logger.info(f"Cast targets to fp32: {targets.dtype}")
534
+
535
+ train_texts, val_texts, train_targets, val_targets = train_test_split(
536
+ tokenized_texts, targets, test_size=0.2, random_state=42
537
+ )
538
+
539
+ logger.info(f"Train targets dtype: {train_targets.dtype}")
540
+ logger.info(f"Val targets dtype: {val_targets.dtype}")
541
+
542
+ # Training setup
543
+ train_dataset = TextDataset(train_texts, train_targets)
544
+ val_dataset = TextDataset(val_texts, val_targets)
545
+
546
+ optimizer = optim.Adam(trainable_model.parameters(), lr=LEARNING_RATE)
547
+ mse_loss = nn.MSELoss()
548
+
549
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
550
+
551
+ try:
552
+ trainable_model = trainable_model.to(device)
553
+ logger.info(f"Training on {device}")
554
+ except torch.cuda.OutOfMemoryError:
555
+ logger.warning("GPU OOM loading training model, using CPU")
556
+ device = torch.device("cpu")
557
+ trainable_model = trainable_model.to(device)
558
+ if torch.cuda.is_available():
559
+ torch.cuda.empty_cache()
560
+
561
+ # Adaptive batch size for training
562
+ adaptive_batch_size = BATCH_SIZE
563
+
564
+ # Quality monitoring: compute embedding similarity before training
565
+ logger.info("πŸ” Quality monitoring: Computing pre-training teacher-student similarity...")
566
+ trainable_model.eval()
567
+ with torch.no_grad():
568
+ # Take a small sample of texts for quality measurement
569
+ sample_texts = training_texts[: min(5, len(training_texts))]
570
+ sample_tokens = trainable_model.tokenize(sample_texts)
571
+ sample_tokens = sample_tokens.to(device)
572
+
573
+ _, student_embeddings_before = trainable_model(sample_tokens)
574
+ sample_teacher_embeddings = targets[: len(sample_texts)].to(device)
575
+
576
+ # Compute average cosine similarity
577
+ similarities_before = []
578
+ for i in range(len(sample_texts)):
579
+ sim = torch.cosine_similarity(
580
+ student_embeddings_before[i].unsqueeze(0),
581
+ sample_teacher_embeddings[i].unsqueeze(0),
582
+ ).item()
583
+ similarities_before.append(sim)
584
+
585
+ avg_similarity_before = np.mean(similarities_before)
586
+ logger.info(f"πŸ“Š Pre-training average teacher-student similarity: {avg_similarity_before:.4f}")
587
+
588
+ # Training loop with validation
589
+ for epoch in range(epochs):
590
+ # Training phase
591
+ trainable_model.train()
592
+
593
+ # Try with current batch size, reduce if OOM
594
+ train_successful = False
595
+ while not train_successful and adaptive_batch_size >= 1:
596
+ try:
597
+ train_loader = train_dataset.to_dataloader(shuffle=True, batch_size=adaptive_batch_size)
598
+
599
+ epoch_loss = 0.0
600
+ num_batches = 0
601
+
602
+ for batch_idx, (tokens, targets_batch) in enumerate(train_loader):
603
+ batch_tokens = tokens.to(device)
604
+ batch_targets = targets_batch.to(device)
605
+
606
+ optimizer.zero_grad()
607
+ _, student_embeddings = trainable_model(batch_tokens)
608
+
609
+ # Debug dtype information on first batch
610
+ if batch_idx == 0:
611
+ logger.info(
612
+ f"Batch {batch_idx}: tokens shape {batch_tokens.shape}, dtype {batch_tokens.dtype}"
613
+ )
614
+ logger.info(
615
+ f"Batch {batch_idx}: targets shape {batch_targets.shape}, dtype {batch_targets.dtype}"
616
+ )
617
+ logger.info(
618
+ f"Batch {batch_idx}: student_embeddings shape {student_embeddings.shape}, dtype {student_embeddings.dtype}"
619
+ )
620
+
621
+ # Force both tensors to fp32 to avoid any precision loss
622
+ if student_embeddings.dtype != target_dtype:
623
+ logger.warning(
624
+ f"Student embeddings not in fp32: {student_embeddings.dtype}, converting to fp32"
625
+ )
626
+ student_embeddings = student_embeddings.to(target_dtype)
627
+ if batch_targets.dtype != target_dtype:
628
+ logger.info(f"Converting targets from {batch_targets.dtype} to fp32")
629
+ batch_targets = batch_targets.to(target_dtype)
630
+
631
+ try:
632
+ loss = mse_loss(student_embeddings, batch_targets)
633
+ loss.backward()
634
+ optimizer.step()
635
+ except RuntimeError as e:
636
+ if "expected scalar type" in str(e):
637
+ logger.exception("Dtype mismatch error occurred:")
638
+ logger.exception(
639
+ f"student_embeddings: {student_embeddings.shape}, {student_embeddings.dtype}"
640
+ )
641
+ logger.exception(f"batch_targets: {batch_targets.shape}, {batch_targets.dtype}")
642
+ logger.exception(
643
+ f"MSE loss input dtypes: {student_embeddings.dtype} vs {batch_targets.dtype}"
644
+ )
645
+ # Force explicit casting to fp32 for maximum precision
646
+ batch_targets = batch_targets.to(target_dtype)
647
+ student_embeddings = student_embeddings.to(target_dtype)
648
+ logger.info("Emergency dtype fix: forced both to fp32")
649
+ loss = mse_loss(student_embeddings, batch_targets)
650
+ loss.backward()
651
+ optimizer.step()
652
+ else:
653
+ raise
654
+
655
+ epoch_loss += loss.item()
656
+ num_batches += 1
657
+
658
+ # Save training checkpoint periodically
659
+ if checkpoint_manager and batch_idx % 100 == 0:
660
+ training_state = {
661
+ "epoch": epoch,
662
+ "batch": batch_idx,
663
+ "model_state": trainable_model.state_dict(),
664
+ "optimizer_state": optimizer.state_dict(),
665
+ "loss": epoch_loss / max(1, num_batches),
666
+ }
667
+ checkpoint_data = create_checkpoint_data("training", training_state, epoch)
668
+ checkpoint_manager.save_checkpoint("training", checkpoint_data, epoch)
669
+
670
+ train_successful = True
671
+
672
+ except torch.cuda.OutOfMemoryError:
673
+ logger.warning(
674
+ f"Training OOM with batch size {adaptive_batch_size}, reducing to {adaptive_batch_size // 2}"
675
+ )
676
+ adaptive_batch_size = max(1, adaptive_batch_size // 2)
677
+ if torch.cuda.is_available():
678
+ torch.cuda.empty_cache()
679
+
680
+ if not train_successful:
681
+ logger.error("Unable to train even with batch size 1, skipping training")
682
+ break
683
+
684
+ avg_train_loss = epoch_loss / num_batches if num_batches > 0 else 0.0
685
+
686
+ # Validation phase
687
+ trainable_model.eval()
688
+ val_loader = val_dataset.to_dataloader(shuffle=False, batch_size=adaptive_batch_size)
689
+ val_loss = 0.0
690
+ val_batches = 0
691
+
692
+ with torch.no_grad():
693
+ for tokens, targets_batch in val_loader:
694
+ batch_tokens = tokens.to(device)
695
+ batch_targets = targets_batch.to(device)
696
+
697
+ _, student_embeddings = trainable_model(batch_tokens)
698
+
699
+ # Force both tensors to fp32 to avoid any precision loss in validation
700
+ if student_embeddings.dtype != target_dtype:
701
+ student_embeddings = student_embeddings.to(target_dtype)
702
+ if batch_targets.dtype != target_dtype:
703
+ batch_targets = batch_targets.to(target_dtype)
704
+
705
+ loss = mse_loss(student_embeddings, batch_targets)
706
+ val_loss += loss.item()
707
+ val_batches += 1
708
+
709
+ avg_val_loss = val_loss / val_batches if val_batches > 0 else 0.0
710
+
711
+ logger.info(
712
+ f"Epoch {epoch + 1}/{epochs} - Train Loss: {avg_train_loss:.6f}, Val Loss: {avg_val_loss:.6f}, Batch Size: {adaptive_batch_size}"
713
+ )
714
+
715
+ # Save epoch checkpoint
716
+ if checkpoint_manager:
717
+ epoch_state = {
718
+ "epoch": epoch + 1,
719
+ "model_state": trainable_model.state_dict(),
720
+ "optimizer_state": optimizer.state_dict(),
721
+ "train_loss": avg_train_loss,
722
+ "val_loss": avg_val_loss,
723
+ }
724
+ checkpoint_data = create_checkpoint_data("epoch", epoch_state, epoch + 1)
725
+ checkpoint_manager.save_checkpoint("epoch", checkpoint_data, epoch + 1)
726
+
727
+ # Quality monitoring: compute embedding similarity after training
728
+ logger.info("πŸ” Quality monitoring: Computing post-training teacher-student similarity...")
729
+ trainable_model.eval()
730
+ with torch.no_grad():
731
+ # Use the same sample texts as before
732
+ sample_texts = training_texts[: min(5, len(training_texts))]
733
+ sample_tokens = trainable_model.tokenize(sample_texts)
734
+ sample_tokens = sample_tokens.to(device)
735
+
736
+ _, student_embeddings_after = trainable_model(sample_tokens)
737
+ sample_teacher_embeddings = targets[: len(sample_texts)].to(device)
738
+
739
+ # Compute average cosine similarity
740
+ similarities_after = []
741
+ for i in range(len(sample_texts)):
742
+ sim = torch.cosine_similarity(
743
+ student_embeddings_after[i].unsqueeze(0),
744
+ sample_teacher_embeddings[i].unsqueeze(0),
745
+ ).item()
746
+ similarities_after.append(sim)
747
+
748
+ avg_similarity_after = np.mean(similarities_after)
749
+ logger.info(f"πŸ“Š Post-training average teacher-student similarity: {avg_similarity_after:.4f}")
750
+
751
+ # Quality assessment
752
+ quality_change = avg_similarity_after - avg_similarity_before
753
+ logger.info(f"πŸ“ˆ Quality change: {quality_change:+.4f}")
754
+
755
+ if abs(quality_change) < 0.01:
756
+ logger.info("βœ… Quality well preserved during training!")
757
+ elif quality_change > 0:
758
+ logger.info("βœ… Quality improved during training!")
759
+ else:
760
+ logger.warning(f"⚠️ Quality degraded by {abs(quality_change):.4f} during training")
761
+
762
+ # Convert back to static model
763
+ refined_model = trainable_model.to_static_model()
764
+
765
+ # Save final refined model to beam volume
766
+ if model_manager:
767
+ # Save to temporary local directory first
768
+ temp_refined_path = Path("./temp_refined_save")
769
+ temp_refined_path.mkdir(exist_ok=True)
770
+ refined_model.save_pretrained(str(temp_refined_path))
771
+
772
+ # Upload to beam volume
773
+ model_manager.save_model("refined_model", str(temp_refined_path))
774
+
775
+ # Clean up temp directory
776
+ import shutil
777
+
778
+ shutil.rmtree(temp_refined_path, ignore_errors=True)
779
+
780
+ logger.info("πŸ’Ύ Saved refined model to beam volume")
781
+
782
+ logger.info("Code specialization training completed")
783
+ return refined_model
784
+
785
+ except Exception as e:
786
+ logger.warning(f"Code training failed: {e}")
787
+ return student_model
788
+
789
+
790
+ def apply_regularization(model: Any, weight: float = 0.01) -> Any:
791
+ """Apply light regularization with overflow protection."""
792
+ # Validate input
793
+ if model is None:
794
+ logger.error("Cannot apply regularization: model is None")
795
+ msg = "model cannot be None"
796
+ raise ValueError(msg)
797
+
798
+ if not hasattr(model, "embedding"):
799
+ logger.error(f"Cannot apply regularization: model of type {type(model)} does not have 'embedding' attribute")
800
+ msg = f"model must have 'embedding' attribute, got {type(model)}"
801
+ raise ValueError(msg)
802
+
803
+ logger.info(f"Applying regularization to model of type: {type(model)}")
804
+
805
+ try:
806
+ embeddings = model.embedding.copy()
807
+
808
+ # Check for extreme values and clip if necessary
809
+ max_val = np.abs(embeddings).max()
810
+ if max_val > 1e6: # Clip extremely large values
811
+ logger.warning(f"Large embedding values detected (max: {max_val:.2e}), clipping to prevent overflow")
812
+ embeddings = np.clip(embeddings, -1e6, 1e6)
813
+
814
+ # Apply regularization
815
+ regularized_embeddings = embeddings * (1.0 - weight)
816
+
817
+ # Stable normalization to prevent overflow
818
+ norms = np.linalg.norm(regularized_embeddings, axis=1, keepdims=True)
819
+
820
+ # Handle zero norms and potential overflow
821
+ norms = np.where(norms == 0, 1, norms)
822
+ norms = np.where(norms > 1e6, 1e6, norms) # Prevent extremely large norms
823
+
824
+ regularized_embeddings = regularized_embeddings / norms
825
+
826
+ # Create new model
827
+ from model2vec.model import StaticModel
828
+
829
+ regularized_model = StaticModel(
830
+ vectors=regularized_embeddings,
831
+ tokenizer=model.tokenizer,
832
+ config=model.config,
833
+ base_model_name=model.base_model_name,
834
+ language=model.language,
835
+ normalize=True,
836
+ )
837
+
838
+ logger.info("Regularization applied successfully")
839
+ return regularized_model
840
+
841
+ except Exception as e:
842
+ logger.warning(f"Regularization failed: {e}")
843
+ return model
844
+
845
+
846
+ def load_teacher_model_with_cache(
847
+ model_name: str,
848
+ output_dir: str,
849
+ device: str = "cuda",
850
+ resume: bool = True,
851
+ ) -> SentenceTransformer:
852
+ """Load teacher model with local caching to avoid re-downloading."""
853
+ cache_dir = Path(output_dir) / "teacher_model_cache"
854
+
855
+ # Check if cached model exists
856
+ if resume and cache_dir.exists():
857
+ try:
858
+ logger.info(f"Loading cached teacher model from {cache_dir}")
859
+ teacher_model = SentenceTransformer(str(cache_dir), device=device)
860
+
861
+ # Set optimized sequence length
862
+ max_seq_len = TEACHER_MODEL_CONFIG.get("max_seq_length", 8192)
863
+ if isinstance(max_seq_len, int):
864
+ teacher_model.max_seq_length = max_seq_len
865
+
866
+ logger.info("Successfully loaded cached teacher model")
867
+ return teacher_model
868
+ except Exception as e:
869
+ logger.warning(f"Failed to load cached teacher model: {e}")
870
+ logger.info("Will download fresh model")
871
+
872
+ # Download and cache the model
873
+ logger.info(f"Downloading teacher model {model_name} (this may take a while)")
874
+
875
+ # Prepare model kwargs with flash attention
876
+ model_kwargs = {
877
+ "torch_dtype": TEACHER_MODEL_CONFIG["torch_dtype"],
878
+ "device_map": TEACHER_MODEL_CONFIG["device_map"],
879
+ }
880
+
881
+ # Try to add flash attention if available
882
+ if TEACHER_MODEL_CONFIG.get("use_flash_attention", False):
883
+ try:
884
+ model_kwargs["attn_implementation"] = TEACHER_MODEL_CONFIG["attn_implementation"]
885
+ logger.info("Flash Attention 2 enabled")
886
+ except Exception as e:
887
+ logger.warning(f"Flash Attention not available, using default attention: {e}")
888
+
889
+ try:
890
+ teacher_model = SentenceTransformer(
891
+ model_name,
892
+ device=device,
893
+ trust_remote_code=bool(TEACHER_MODEL_CONFIG["trust_remote_code"]),
894
+ model_kwargs=model_kwargs,
895
+ )
896
+ except ImportError as e:
897
+ if "flash_attn" in str(e):
898
+ logger.warning("Flash Attention 2 not available, falling back to default attention")
899
+ # Remove flash attention from model_kwargs and retry
900
+ model_kwargs_fallback = {k: v for k, v in model_kwargs.items() if k != "attn_implementation"}
901
+ teacher_model = SentenceTransformer(
902
+ model_name,
903
+ device=device,
904
+ trust_remote_code=bool(TEACHER_MODEL_CONFIG["trust_remote_code"]),
905
+ model_kwargs=model_kwargs_fallback,
906
+ )
907
+ else:
908
+ raise
909
+
910
+ # Set optimized sequence length
911
+ max_seq_len = TEACHER_MODEL_CONFIG.get("max_seq_length", 8192)
912
+ if isinstance(max_seq_len, int):
913
+ teacher_model.max_seq_length = max_seq_len
914
+ logger.info(f"Set max_seq_length to {max_seq_len} for better performance")
915
+
916
+ # Cache the model for future use
917
+ try:
918
+ cache_dir.mkdir(parents=True, exist_ok=True)
919
+ teacher_model.save(str(cache_dir))
920
+ logger.info(f"Cached teacher model to {cache_dir}")
921
+ except Exception as e:
922
+ logger.warning(f"Failed to cache teacher model: {e}")
923
+ # Continue without caching
924
+
925
+ return teacher_model
926
+
927
+
928
+ def code_specialized_distillation(
929
+ model_name: str = MODEL_NAME,
930
+ output_dir: str = OUTPUT_DIR,
931
+ pca_dims: int = PCA_DIMS,
932
+ max_samples: int = MAX_TRAINING_SAMPLES,
933
+ resume: bool = True,
934
+ ) -> Any:
935
+ """Main code-specialized distillation function using CodeSearchNet dataset with checkpoint support."""
936
+ output_path = Path(output_dir)
937
+ output_path.mkdir(parents=True, exist_ok=True)
938
+
939
+ # Initialize Beam utilities
940
+ volume_mgr, checkpoint_mgr, model_mgr, eval_mgr = create_beam_utilities(VOLUME_NAME, VOLUME_PATH)
941
+
942
+ logger.info(f"Starting code-specialized distillation of {model_name}")
943
+ logger.info(f"Using CodeSearchNet dataset: {CODESEARCHNET_DATASET}")
944
+ logger.info(f"Resume mode: {resume}")
945
+
946
+ # GPU Diagnostics
947
+ logger.info("=== GPU DIAGNOSTICS ===")
948
+ logger.info(f"CUDA available: {torch.cuda.is_available()}")
949
+ if torch.cuda.is_available():
950
+ logger.info(f"CUDA version: {torch.version.cuda}")
951
+ logger.info(f"GPU count: {torch.cuda.device_count()}")
952
+ for i in range(torch.cuda.device_count()):
953
+ gpu_name = torch.cuda.get_device_name(i)
954
+ gpu_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3
955
+ logger.info(f"GPU {i}: {gpu_name} ({gpu_memory:.1f} GB)")
956
+
957
+ # Current GPU memory
958
+ current_device = torch.cuda.current_device()
959
+ allocated = torch.cuda.memory_allocated(current_device) / 1024**3
960
+ total = torch.cuda.get_device_properties(current_device).total_memory / 1024**3
961
+ logger.info(f"Current GPU {current_device}: {allocated:.2f}GB allocated, {total:.1f}GB total")
962
+ else:
963
+ logger.warning("CUDA not available - will use CPU (much slower)")
964
+ logger.info("======================")
965
+
966
+ start_time = time.time()
967
+
968
+ # Step 1: Basic Model2Vec distillation with checkpoint support
969
+ logger.info("Step 1: Basic Model2Vec distillation...")
970
+
971
+ # Check for existing distilled model in beam volume
972
+ m2v_model = None
973
+ if resume:
974
+ # Check if model files exist directly in the volume root
975
+ try:
976
+ # Try to load from the volume root where the model was successfully saved
977
+ volume_root_path = Path(VOLUME_PATH)
978
+ if (volume_root_path / "config.json").exists() and (volume_root_path / "model.safetensors").exists():
979
+ logger.info("βœ… Found existing model files in volume root")
980
+ from model2vec.model import StaticModel
981
+
982
+ m2v_model = StaticModel.from_pretrained(str(volume_root_path))
983
+ logger.info("βœ… Successfully loaded existing distilled model from volume")
984
+ else:
985
+ logger.info("No existing model files found in volume root")
986
+ except Exception as e:
987
+ logger.warning(f"Failed to load existing model from volume: {e}")
988
+ m2v_model = None
989
+
990
+ if m2v_model is None:
991
+ # Clear GPU cache before starting
992
+ if torch.cuda.is_available():
993
+ torch.cuda.empty_cache()
994
+ current_device = torch.cuda.current_device()
995
+ allocated = torch.cuda.memory_allocated(current_device) / 1024**3
996
+ total = torch.cuda.get_device_properties(current_device).total_memory / 1024**3
997
+ logger.info(f"GPU memory before distillation: {allocated:.2f}GB allocated / {total:.1f}GB total")
998
+ else:
999
+ logger.info("Using CPU for distillation")
1000
+
1001
+ try:
1002
+ m2v_model = distill(
1003
+ model_name=model_name,
1004
+ pca_dims=pca_dims,
1005
+ apply_zipf=None,
1006
+ sif_coefficient=1e-4,
1007
+ trust_remote_code=True,
1008
+ )
1009
+ logger.info("Basic distillation completed with preserved precision")
1010
+
1011
+ # Validate the distilled model
1012
+ if m2v_model is None:
1013
+ msg = "Distillation returned None - this should not happen"
1014
+ raise ValueError(msg) from None
1015
+
1016
+ logger.info(f"Distilled model type: {type(m2v_model)}")
1017
+ logger.info(f"Distilled model has embedding attribute: {hasattr(m2v_model, 'embedding')}")
1018
+
1019
+ # Save the base distilled model - DISABLED due to recursive directory bug
1020
+ # model_mgr.save_model("base_distilled_model", str(output_path))
1021
+
1022
+ except torch.cuda.OutOfMemoryError:
1023
+ logger.warning("GPU OOM during distillation, clearing cache and retrying...")
1024
+ torch.cuda.empty_cache()
1025
+
1026
+ # Force CPU-only distillation if GPU fails
1027
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""
1028
+
1029
+ logger.info("Retrying distillation on CPU...")
1030
+ m2v_model = distill(
1031
+ model_name=model_name,
1032
+ pca_dims=pca_dims,
1033
+ apply_zipf=None,
1034
+ sif_coefficient=1e-4,
1035
+ trust_remote_code=True,
1036
+ )
1037
+ logger.info("Basic distillation completed on CPU")
1038
+
1039
+ # Validate the distilled model
1040
+ if m2v_model is None:
1041
+ msg = "CPU distillation returned None - this should not happen"
1042
+ raise ValueError(msg) from None
1043
+
1044
+ logger.info(f"CPU distilled model type: {type(m2v_model)}")
1045
+ logger.info(f"CPU distilled model has embedding attribute: {hasattr(m2v_model, 'embedding')}")
1046
+
1047
+ # Save the base distilled model - DISABLED due to recursive directory bug
1048
+ # model_mgr.save_model("base_distilled_model", str(output_path))
1049
+
1050
+ except Exception:
1051
+ logger.exception("Distillation failed with error")
1052
+ raise
1053
+
1054
+ # Validate m2v_model before proceeding
1055
+ if m2v_model is None:
1056
+ msg = "m2v_model is None after distillation step - cannot proceed"
1057
+ raise ValueError(msg)
1058
+
1059
+ # Step 2: Load CodeSearchNet training data with resume
1060
+ logger.info("Step 2: Loading CodeSearchNet training data...")
1061
+ code_texts = load_codesearchnet_dataset_with_resume(max_samples, checkpoint_mgr)
1062
+
1063
+ if not code_texts:
1064
+ logger.warning("No code training data available, skipping code specialization")
1065
+ else:
1066
+ logger.info("Step 3: Code specialization training...")
1067
+
1068
+ # Check for existing refined model
1069
+ if resume:
1070
+ # Check if refined model exists in beam volume
1071
+ models = model_mgr.list_models()
1072
+ refined_model_exists = any(model["name"] == "refined_model" for model in models)
1073
+
1074
+ if refined_model_exists:
1075
+ # Download model to local path for loading
1076
+ temp_model_path = Path("./temp_refined_model")
1077
+ if model_mgr.load_model("refined_model", temp_model_path):
1078
+ try:
1079
+ from model2vec.model import StaticModel
1080
+
1081
+ refined_model = StaticModel.from_pretrained(str(temp_model_path / "refined_model"))
1082
+ logger.info("βœ… Resumed from existing refined model")
1083
+ m2v_model = refined_model
1084
+ # Clean up temp directory
1085
+ import shutil
1086
+
1087
+ shutil.rmtree(temp_model_path, ignore_errors=True)
1088
+ except Exception as e:
1089
+ logger.warning(f"Failed to load existing refined model: {e}")
1090
+ refined_model = None
1091
+ # Clean up temp directory
1092
+ import shutil
1093
+
1094
+ shutil.rmtree(temp_model_path, ignore_errors=True)
1095
+ else:
1096
+ refined_model = None
1097
+ else:
1098
+ refined_model = None
1099
+
1100
+ if refined_model is None:
1101
+ # Load teacher model with memory management
1102
+ try:
1103
+ device = "cuda" if torch.cuda.is_available() else "cpu"
1104
+ logger.info(f"Loading teacher model on {device} with optimized settings")
1105
+ logger.info(
1106
+ f"Using precision: {TEACHER_MODEL_CONFIG['precision']}, batch_size: {TEACHER_MODEL_CONFIG['batch_size']}"
1107
+ )
1108
+ logger.info("Attempting to enable Flash Attention 2 for maximum performance")
1109
+
1110
+ teacher_model = load_teacher_model_with_cache(model_name, output_dir, device=device, resume=resume)
1111
+
1112
+ # Generate teacher embeddings with checkpoints
1113
+ teacher_embeddings = generate_teacher_embeddings_with_checkpoints(
1114
+ teacher_model, code_texts, checkpoint_mgr
1115
+ )
1116
+
1117
+ # Refine with code training
1118
+ m2v_model = refine_with_code_training(
1119
+ m2v_model,
1120
+ code_texts,
1121
+ teacher_embeddings,
1122
+ epochs=TRAINING_EPOCHS,
1123
+ checkpoint_manager=checkpoint_mgr,
1124
+ model_manager=model_mgr,
1125
+ )
1126
+
1127
+ del teacher_model
1128
+ if torch.cuda.is_available():
1129
+ torch.cuda.empty_cache()
1130
+
1131
+ except torch.cuda.OutOfMemoryError:
1132
+ logger.warning("GPU OOM during code training, falling back to CPU...")
1133
+
1134
+ if torch.cuda.is_available():
1135
+ torch.cuda.empty_cache()
1136
+
1137
+ # Force CPU for teacher model with optimized settings (no flash attention on CPU)
1138
+ try:
1139
+ teacher_model = load_teacher_model_with_cache(
1140
+ model_name, output_dir, device="cpu", resume=resume
1141
+ )
1142
+ except ImportError as e:
1143
+ if "flash_attn" in str(e):
1144
+ logger.warning("Flash Attention 2 not available on CPU, using default attention")
1145
+ # Fallback without any special attention implementation
1146
+ teacher_model = load_teacher_model_with_cache(
1147
+ model_name, output_dir, device="cpu", resume=resume
1148
+ )
1149
+ else:
1150
+ raise
1151
+
1152
+ # Generate teacher embeddings on CPU with checkpoints
1153
+ teacher_embeddings = generate_teacher_embeddings_with_checkpoints(
1154
+ teacher_model, code_texts, checkpoint_mgr
1155
+ )
1156
+
1157
+ # Refine with code training on CPU
1158
+ m2v_model = refine_with_code_training(
1159
+ m2v_model,
1160
+ code_texts,
1161
+ teacher_embeddings,
1162
+ epochs=TRAINING_EPOCHS,
1163
+ checkpoint_manager=checkpoint_mgr,
1164
+ model_manager=model_mgr,
1165
+ )
1166
+
1167
+ del teacher_model
1168
+ else:
1169
+ # Fresh training without resume
1170
+ try:
1171
+ device = "cuda" if torch.cuda.is_available() else "cpu"
1172
+ logger.info(f"Loading teacher model on {device} with optimized settings")
1173
+ logger.info(
1174
+ f"Using precision: {TEACHER_MODEL_CONFIG['precision']}, batch_size: {TEACHER_MODEL_CONFIG['batch_size']}"
1175
+ )
1176
+ logger.info("Attempting to enable Flash Attention 2 for maximum performance")
1177
+
1178
+ teacher_model = load_teacher_model_with_cache(model_name, output_dir, device=device, resume=resume)
1179
+
1180
+ # Generate teacher embeddings with checkpoints
1181
+ teacher_embeddings = generate_teacher_embeddings_with_checkpoints(
1182
+ teacher_model, code_texts, checkpoint_mgr
1183
+ )
1184
+
1185
+ # Refine with code training
1186
+ m2v_model = refine_with_code_training(
1187
+ m2v_model,
1188
+ code_texts,
1189
+ teacher_embeddings,
1190
+ epochs=TRAINING_EPOCHS,
1191
+ checkpoint_manager=checkpoint_mgr,
1192
+ model_manager=model_mgr,
1193
+ )
1194
+
1195
+ del teacher_model
1196
+ if torch.cuda.is_available():
1197
+ torch.cuda.empty_cache()
1198
+
1199
+ except torch.cuda.OutOfMemoryError:
1200
+ logger.warning("GPU OOM during code training, falling back to CPU...")
1201
+
1202
+ if torch.cuda.is_available():
1203
+ torch.cuda.empty_cache()
1204
+
1205
+ # Force CPU for teacher model with optimized settings (no flash attention on CPU)
1206
+ try:
1207
+ teacher_model = load_teacher_model_with_cache(model_name, output_dir, device="cpu", resume=resume)
1208
+ except ImportError as e:
1209
+ if "flash_attn" in str(e):
1210
+ logger.warning("Flash Attention 2 not available on CPU, using default attention")
1211
+ # Fallback without any special attention implementation
1212
+ teacher_model = load_teacher_model_with_cache(
1213
+ model_name, output_dir, device="cpu", resume=resume
1214
+ )
1215
+ else:
1216
+ raise
1217
+
1218
+ # Generate teacher embeddings on CPU with checkpoints
1219
+ teacher_embeddings = generate_teacher_embeddings_with_checkpoints(
1220
+ teacher_model, code_texts, checkpoint_mgr
1221
+ )
1222
+
1223
+ # Refine with code training on CPU
1224
+ m2v_model = refine_with_code_training(
1225
+ m2v_model,
1226
+ code_texts,
1227
+ teacher_embeddings,
1228
+ epochs=TRAINING_EPOCHS,
1229
+ checkpoint_manager=checkpoint_mgr,
1230
+ model_manager=model_mgr,
1231
+ )
1232
+
1233
+ del teacher_model
1234
+
1235
+ # Step 4: Light regularization
1236
+ logger.info("Step 4: Applying regularization...")
1237
+ m2v_model = apply_regularization(m2v_model, REGULARIZATION_WEIGHT)
1238
+
1239
+ # Save final model
1240
+ logger.info("Saving code-specialized model...")
1241
+
1242
+ # Final validation before saving
1243
+ if m2v_model is None:
1244
+ msg = "Cannot save model: m2v_model is None"
1245
+ raise ValueError(msg)
1246
+
1247
+ if not hasattr(m2v_model, "save_pretrained"):
1248
+ msg = f"Cannot save model: m2v_model of type {type(m2v_model)} does not have save_pretrained method"
1249
+ raise ValueError(msg)
1250
+
1251
+ logger.info(f"Final model type: {type(m2v_model)}")
1252
+ logger.info(f"Final model has embedding attribute: {hasattr(m2v_model, 'embedding')}")
1253
+
1254
+ m2v_model.save_pretrained(str(output_path))
1255
+
1256
+ # Save final model to beam volume as well - DISABLED due to recursive directory bug
1257
+ # model_mgr.save_model("final_model", str(output_path))
1258
+
1259
+ total_time = time.time() - start_time
1260
+ logger.info(f"Code-specialized distillation completed in {total_time:.2f} seconds")
1261
+
1262
+ return m2v_model
1263
+
1264
+
1265
+ @function(
1266
+ gpu=GPU_NAME,
1267
+ volumes=[Volume(name=VOLUME_NAME, mount_path=VOLUME_PATH)],
1268
+ image=IMAGE,
1269
+ secrets=["HF_ACCESS_TOKEN"],
1270
+ env={
1271
+ "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True,max_split_size_mb:512",
1272
+ "TOKENIZERS_PARALLELISM": "false",
1273
+ "CUDA_LAUNCH_BLOCKING": "0", # Allow async CUDA operations
1274
+ "TORCH_CUDNN_V8_API_ENABLED": "1", # Enable optimized cuDNN
1275
+ "OMP_NUM_THREADS": "8", # Limit CPU threads for better GPU utilization
1276
+ },
1277
+ timeout=3600 * 12, # 12 hours
1278
+ )
1279
+ def beam_code_distillation(
1280
+ model_name: str = MODEL_NAME,
1281
+ output_dir: str = OUTPUT_DIR,
1282
+ pca_dims: int = PCA_DIMS,
1283
+ max_samples: int = MAX_TRAINING_SAMPLES,
1284
+ resume: bool = True,
1285
+ ) -> Any:
1286
+ # Apply all patches from the patches directory
1287
+ try:
1288
+ from .patch_utils import apply_all_patches
1289
+
1290
+ logger.info("Applying all patches from patches directory...")
1291
+ patches_applied = apply_all_patches()
1292
+ logger.info(f"Successfully applied {patches_applied} patches")
1293
+ except Exception as e:
1294
+ logger.warning(f"Failed to apply patches: {e}. Continuing without patches.")
1295
+
1296
+ return code_specialized_distillation(
1297
+ model_name=model_name,
1298
+ output_dir=output_dir,
1299
+ pca_dims=pca_dims,
1300
+ max_samples=max_samples,
1301
+ resume=resume,
1302
+ )
1303
+
1304
+
1305
+ if __name__ == "__main__":
1306
+ code_specialized_distillation()
src/distiller/distill_simplified.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simplified Code-Specialized Model2Vec Distillation Script.
3
+
4
+ This script implements a focused, simplified approach for creating code-specialized embeddings
5
+ using only the core Model2Vec distillation without additional fine-tuning that may degrade quality.
6
+
7
+ Can run locally or on Beam with the --use-beam flag.
8
+ """
9
+
10
+ import argparse
11
+ import json
12
+ import logging
13
+ import sys
14
+ import time
15
+ from pathlib import Path
16
+ from typing import Any
17
+
18
+ from beam import GpuType, Image, Volume, function
19
+ from model2vec.distill import distill
20
+
21
+ # =============================================================================
22
+ # SIMPLIFIED CONFIGURATION
23
+ # =============================================================================
24
+
25
+ # Use a code-specialized teacher model instead of general instruction model
26
+ # Ordered by success likelihood and performance:
27
+ CODE_TEACHER_MODELS = [
28
+ "sentence-transformers/all-MiniLM-L6-v2",
29
+ "sentence-transformers/all-mpnet-base-v2",
30
+ "microsoft/codebert-base",
31
+ "microsoft/graphcodebert-base",
32
+ "sentence-transformers/paraphrase-MiniLM-L6-v2",
33
+ "Alibaba-NLP/gte-Qwen2-7B-instruct",
34
+ ]
35
+
36
+ OUTPUT_BASE_DIR = "code_model2vec"
37
+
38
+ # Optimal Model2Vec parameters based on successful models
39
+ OPTIMAL_PCA_DIMS = 256 # Match other successful Model2Vec models
40
+ SIF_COEFFICIENT = 1e-3 # Slightly higher than default for code specialization
41
+ APPLY_ZIPF = True # Enable Zipf weighting for better word importance
42
+
43
+ # =============================================================================
44
+ # BEAM CONFIGURATION
45
+ # =============================================================================
46
+
47
+ GPU_NAME = GpuType.A100_40
48
+ VOLUME_NAME = "code_model2vec"
49
+ VOLUME_PATH = "./code_model2vec"
50
+ IMAGE = Image(python_version="python3.12").add_python_packages(
51
+ [
52
+ "torch>=2.7.0", # Install torch first
53
+ "transformers>=4.40.0", # Latest transformers with flash attention support
54
+ "lightning>=2.5.1.post0",
55
+ "model2vec[train]>=0.5.0",
56
+ "numpy>=1.26.4",
57
+ "scikit-learn>=1.6.1",
58
+ "sentence-transformers>=4.1.0",
59
+ "datasets>=3.2.0", # For evaluation
60
+ "pandas>=2.0.0",
61
+ "tqdm>=4.65.0",
62
+ ]
63
+ )
64
+
65
+ # =============================================================================
66
+
67
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
68
+ logger = logging.getLogger(__name__)
69
+
70
+ # Add beam utilities for proper model persistence
71
+ try:
72
+ from .beam_utils import (
73
+ create_beam_utilities,
74
+ )
75
+
76
+ BEAM_UTILS_AVAILABLE = True
77
+ except ImportError:
78
+ print("Beam utilities not available - models will only be saved locally")
79
+ BEAM_UTILS_AVAILABLE = False
80
+
81
+
82
def apply_local_patches() -> bool:
    """Apply patches locally without requiring Beam utilities.

    Returns True only when ``patch_utils`` was importable and its patches
    ran to completion. A missing ``patch_utils`` module and any runtime
    patching failure are both swallowed after logging, so callers can
    proceed unpatched.
    """
    try:
        try:
            from .patch_utils import apply_all_patches

            applied = apply_all_patches()
            logger.info(f"Successfully applied {applied} patches via patch_utils")
            return True
        except ImportError:
            logger.warning("patch_utils not available, trying direct patching")
            return False
    except Exception as e:
        logger.warning(f"Failed to apply patches: {e}")
        return False
100
+
101
+
102
def simplified_code_distillation(
    teacher_model: str,
    output_dir: str,
    pca_dims: int = OPTIMAL_PCA_DIMS,
) -> Any:
    """
    Simplified code-specialized distillation using only core Model2Vec.

    This approach:
    1. Uses a teacher model that already performs well on code tasks
    2. Applies optimal Model2Vec parameters
    3. Avoids additional training that may degrade quality

    Returns the distilled model, or None on failure (including the known
    Model2Vec token/vector-count mismatch for incompatible tokenizers).
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"Starting simplified distillation from {teacher_model}")
    logger.info(f"Target dimensions: {pca_dims}")
    logger.info(f"SIF coefficient: {SIF_COEFFICIENT}")
    logger.info(f"Zipf weighting: {APPLY_ZIPF}")

    started = time.time()

    try:
        # Core distillation with the tuned Model2Vec parameters.
        distilled = distill(
            model_name=teacher_model,
            pca_dims=pca_dims,
            apply_zipf=APPLY_ZIPF,
            sif_coefficient=SIF_COEFFICIENT,
            trust_remote_code=True,
        )

        logger.info("βœ… Core distillation completed successfully")

        distilled.save_pretrained(str(target_dir))
        logger.info(f"πŸ’Ύ Model saved to {target_dir}")

        # Log model info for debugging downstream loading issues.
        logger.info(f"Model type: {type(distilled)}")
        if hasattr(distilled, "embedding"):
            logger.info(f"Embedding shape: {distilled.embedding.shape}")
            logger.info(f"Embedding dtype: {distilled.embedding.dtype}")

        elapsed = time.time() - started
        logger.info(f"πŸŽ‰ Simplified distillation completed in {elapsed:.2f} seconds")
        return distilled

    except ValueError as e:
        # Known Model2Vec library issue with some tokenizers; treat as a
        # skippable failure rather than aborting the whole run.
        if "Number of tokens" in str(e) and "does not match number of vectors" in str(e):
            logger.warning(f"⚠️ Token-vector mismatch with {teacher_model} - this is a Model2Vec library issue")
            logger.warning(f"Error details: {e}")
            logger.warning("πŸ’‘ This model has incompatible tokenization. Skipping...")
            return None
        raise
    except Exception:
        logger.exception("❌ Distillation failed")
        return None
161
+
162
+
163
def core_distill_all_teachers(use_beam_utilities: bool = False) -> dict[str, Any]:
    """
    Core logic for distilling all teacher models.

    Iterates over CODE_TEACHER_MODELS, skipping teachers whose distilled
    model already exists on disk. Every attempted teacher ends up with a
    status entry in the results ("success", "skipped_existing", or
    "failed") so the summary accounts for all of them.

    Args:
        use_beam_utilities: Whether to use Beam utilities for persistence

    Returns:
        Dictionary with distillation results

    Raises:
        RuntimeError: If no teacher model could be distilled.
    """
    # Apply patches (best-effort; some teacher models are known to need them)
    logger.info("Applying all patches...")
    patch_success = apply_local_patches()
    if patch_success:
        logger.info("Successfully applied patches")
    else:
        logger.warning("Failed to apply patches - Microsoft models may fail")

    # Initialize Beam utilities if requested and available
    model_mgr = None
    if use_beam_utilities and BEAM_UTILS_AVAILABLE:
        try:
            _, _, model_mgr, _ = create_beam_utilities(VOLUME_NAME, VOLUME_PATH)
            logger.info("βœ… Beam utilities initialized for model persistence")
        except Exception as e:
            logger.warning(f"Failed to initialize Beam utilities: {e}")
            model_mgr = None

    results = {}
    successful_models = []

    logger.info("πŸš€ Starting comprehensive teacher model distillation")
    logger.info(f"πŸ“Š Processing {len(CODE_TEACHER_MODELS)} teacher models")

    # Models go to the Beam volume when persistence is on, else locally.
    base_output_path = VOLUME_PATH if use_beam_utilities else OUTPUT_BASE_DIR

    for teacher_model in CODE_TEACHER_MODELS:
        try:
            # Create output directory name based on teacher model
            teacher_name = teacher_model.split("/")[-1].replace("-", "_")
            output_dir = f"{base_output_path}/final/code_model2vec_{teacher_name}"

            logger.info(f"\n{'=' * 60}")
            logger.info(f"πŸ”„ Processing teacher model: {teacher_model}")
            logger.info(f"πŸ“ Output directory: {output_dir}")
            logger.info(f"{'=' * 60}")

            # Skip distillation when a complete model is already on disk.
            output_path = Path(output_dir)
            if output_path.exists():
                has_config = (output_path / "config.json").exists()
                has_model_file = any(
                    [
                        (output_path / "model.safetensors").exists(),
                        (output_path / "model.bin").exists(),
                        (output_path / "pytorch_model.bin").exists(),
                    ]
                )

                if has_config and has_model_file:
                    logger.info(f"βœ… Model {teacher_name} already exists - skipping distillation")
                    results[teacher_name] = {
                        "teacher_model": teacher_model,
                        "output_dir": output_dir,
                        "teacher_name": teacher_name,
                        "distillation_time": 0.0,
                        "status": "skipped_existing",
                    }
                    successful_models.append(teacher_name)
                    logger.info(f"πŸ“ Using existing model at: {output_dir}")
                    continue

            # Perform distillation
            start_time = time.time()
            model = simplified_code_distillation(
                teacher_model=teacher_model,
                output_dir=output_dir,
            )
            distill_time = time.time() - start_time

            if model is not None:
                logger.info(f"βœ… Distillation successful for {teacher_model}")

                # Save to Beam volume for persistence if available
                if model_mgr:
                    try:
                        beam_model_name = f"{teacher_name}_model"
                        model_mgr.save_model(beam_model_name, output_dir)
                        logger.info(f"πŸ’Ύ Saved {teacher_name} to Beam volume as {beam_model_name}")
                    except Exception as e:
                        logger.warning(f"Failed to save {teacher_name} to Beam volume: {e}")

                results[teacher_name] = {
                    "teacher_model": teacher_model,
                    "output_dir": output_dir,
                    "teacher_name": teacher_name,
                    "distillation_time": distill_time,
                    "status": "success",
                }
                successful_models.append(teacher_name)
                logger.info(f"πŸ’Ύ Model saved to: {output_dir}")
            else:
                # Fix: a graceful failure (None return, e.g. token-vector
                # mismatch) previously left no record at all, so the teacher
                # silently vanished from the summary.
                logger.warning(f"⚠️ Distillation returned no model for {teacher_model}")
                results[teacher_name] = {
                    "teacher_model": teacher_model,
                    "output_dir": output_dir,
                    "teacher_name": teacher_name,
                    "distillation_time": distill_time,
                    "status": "failed",
                    "error": "distillation returned no model",
                }

        except Exception as e:
            logger.exception(f"❌ Failed with {teacher_model}")
            results[teacher_model.split("/")[-1]] = {
                "teacher_model": teacher_model,
                "status": "failed",
                "error": str(e),
            }
            continue

    # Summary
    if successful_models:
        logger.info("\nπŸ† DISTILLATION COMPLETE!")
        logger.info(f"πŸ“Š Successful models: {len(successful_models)}")

        for model_name in successful_models:
            model_info = results[model_name]
            logger.info(f"βœ… {model_name}: {model_info['teacher_model']}")

        # Save comprehensive results
        results_summary = {
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "successful_models": successful_models,
            "all_results": results,
            "total_successful": len(successful_models),
            "total_attempted": len(CODE_TEACHER_MODELS),
        }

        # Persist the run summary next to the models.
        results_file = Path(f"{base_output_path}/distillation_results.json")
        results_file.parent.mkdir(parents=True, exist_ok=True)
        with results_file.open("w") as f:
            json.dump(results_summary, f, indent=2)

        logger.info(f"πŸ“Š Results summary saved to: {results_file}")

        return results_summary

    logger.error("❌ No models succeeded")
    msg = "All teacher models failed distillation"
    raise RuntimeError(msg)
317
+
318
+
319
def run_local_distillation() -> dict[str, Any]:
    """Run the teacher-model distillation pipeline on the local machine (no Beam persistence)."""
    logger.info("πŸ–₯️ Running simplified distillation locally")
    summary = core_distill_all_teachers(use_beam_utilities=False)
    return summary
323
+
324
+
325
@function(
    gpu=GPU_NAME,
    volumes=[Volume(name=VOLUME_NAME, mount_path=VOLUME_PATH)],
    image=IMAGE,
    secrets=["HF_ACCESS_TOKEN"],
    env={
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True,max_split_size_mb:512",
        "TOKENIZERS_PARALLELISM": "false",
        "CUDA_LAUNCH_BLOCKING": "0",  # Allow async CUDA operations
        "TORCH_CUDNN_V8_API_ENABLED": "1",  # Enable optimized cuDNN
    },
    timeout=3600 * 12,  # 12 hours
)
def beam_distill_all_teachers() -> dict[str, Any]:
    """
    Beam version: Try all teacher models and create distilled models from each.

    Returns information about all models that were successfully created.
    Runs the same core pipeline as the local path, but with Beam utilities
    enabled so models and result summaries persist to the mounted volume.
    """
    logger.info("☁️ Running simplified distillation on Beam")
    return core_distill_all_teachers(use_beam_utilities=True)
346
+
347
+
348
def main() -> None:
    """CLI entry point: parse arguments and dispatch to local or Beam execution."""
    global OUTPUT_BASE_DIR  # Declare global at the top  # noqa: PLW0603

    arg_parser = argparse.ArgumentParser(
        description="Simplified Code-Specialized Model2Vec Distillation",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python -m src.distiller.distill_simplified            # Run locally
  python -m src.distiller.distill_simplified --use-beam # Run on Beam
  distiller distill-simple                              # CLI shortcut (runs on Beam)
        """,
    )
    arg_parser.add_argument(
        "--use-beam",
        action="store_true",
        help="Run on Beam instead of locally",
    )
    arg_parser.add_argument(
        "--output-dir",
        type=str,
        default=OUTPUT_BASE_DIR,
        help=f"Output directory for models (default: {OUTPUT_BASE_DIR})",
    )
    opts = arg_parser.parse_args()

    # Update output directory if specified on the command line.
    if opts.output_dir != OUTPUT_BASE_DIR:
        OUTPUT_BASE_DIR = opts.output_dir

    try:
        if opts.use_beam:
            logger.info("πŸš€ Starting Beam execution...")
            summary = beam_distill_all_teachers()
        else:
            logger.info("πŸ–₯️ Starting local execution...")
            summary = run_local_distillation()

        # Print final summary
        print("\nπŸŽ‰ Distillation complete!")
        print(f"πŸ“Š Successfully created {summary['total_successful']} models")

        location = VOLUME_PATH if opts.use_beam else OUTPUT_BASE_DIR
        print(f"πŸ“ Models location: {location}/final/")

        print("\nβœ… Created models:")
        for created_name in summary["successful_models"]:
            created_info = summary["all_results"][created_name]
            print(f" β€’ {created_name} (from {created_info['teacher_model']})")

    except KeyboardInterrupt:
        logger.info("πŸ›‘ Distillation interrupted by user")
        sys.exit(1)
    except Exception:
        logger.exception("❌ Distillation failed with error")
        sys.exit(1)


if __name__ == "__main__":
    main()
src/distiller/evaluate.py ADDED
@@ -0,0 +1,839 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CodeSearchNet Evaluation Script for Code-Specialized Embedding Models.
3
+
4
+ This script evaluates embedding models on code search tasks using the CodeSearchNet
5
+ dataset and methodology. It implements the same evaluation approach as the original
6
+ CodeSearchNet challenge, including NDCG and other information retrieval metrics.
7
+
8
+ Usage:
9
+ distiller evaluate # Run evaluation on all default models with Beam
10
+ """
11
+
12
+ import json
13
+ import logging
14
+ import time
15
+ from pathlib import Path
16
+ from typing import Any
17
+
18
+ import numpy as np
19
+ import pandas as pd
20
+ from beam import GpuType, Image, Volume, function
21
+ from datasets import Dataset, load_dataset
22
+ from sentence_transformers import SentenceTransformer
23
+ from sklearn.metrics.pairwise import cosine_similarity
24
+ from tqdm import tqdm
25
+
26
+ from .beam_utils import (
27
+ BeamCheckpointManager,
28
+ BeamEvaluationManager,
29
+ create_beam_utilities,
30
+ )
31
+
32
+ # Configure logging
33
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
34
+ logger = logging.getLogger(__name__)
35
+
36
+ # =============================================================================
37
+ # BEAM CONFIGURATION
38
+ # =============================================================================
39
+
40
+ GPU_NAME = GpuType.A100_40
41
+ VOLUME_NAME = "code_model2vec" # Same volume as distill_simplified.py
42
+ VOLUME_PATH = "./code_model2vec" # Same mount path as distill_simplified.py
43
+ EVALUATION_RESULTS_DIR = "evaluation_results" # Subdirectory within volume
44
+ EVALUATION_CACHE_DIR = "evaluation_cache" # Cache for datasets and models
45
+
46
+ IMAGE = Image(python_version="python3.12").add_python_packages(
47
+ [
48
+ "torch>=2.7.0",
49
+ "transformers>=4.40.0",
50
+ "datasets>=3.2.0",
51
+ "sentence-transformers>=4.1.0",
52
+ "model2vec[train]>=0.5.0",
53
+ "numpy>=1.26.4",
54
+ "scikit-learn>=1.6.1",
55
+ "pandas>=2.0.0",
56
+ "tqdm>=4.65.0",
57
+ ]
58
+ )
59
+
60
+ # =============================================================================
61
+ # CONFIGURATION
62
+ # =============================================================================
63
+
64
+ CODESEARCHNET_EVAL_DATASET = "code_search_net"
65
+ BATCH_SIZE = 32
66
+ DEFAULT_OUTPUT_DIR = "code_evaluation_results" # Local fallback directory
67
+ EVALUATION_LANGUAGES = ["python", "javascript", "java", "php", "ruby", "go"]
68
+
69
+ # Default models to evaluate (can be overridden via command line)
70
+ DEFAULT_EVALUATION_MODELS = [
71
+ # Established Code Models
72
+ "sentence-transformers/all-MiniLM-L6-v2",
73
+ "microsoft/codebert-base",
74
+ "microsoft/graphcodebert-base",
75
+ "huggingface/CodeBERTa-small-v1",
76
+ "sentence-transformers/all-mpnet-base-v2",
77
+ "sentence-transformers/all-MiniLM-L12-v2",
78
+ # Model2Vec & Efficiency Models (Direct Competitors)
79
+ "minishlab/potion-base-8M",
80
+ "minishlab/potion-retrieval-32M",
81
+ # Small Transformer-Based Code Models
82
+ "Salesforce/codet5-base",
83
+ ]
84
+
85
+ # =============================================================================
86
+ # CHECKPOINT CONFIGURATION
87
+ # =============================================================================
88
+
89
+ # Prevent conflicts with distill.py checkpoints by using different prefixes
90
+ EVAL_CHECKPOINT_PREFIX = "evaluation_checkpoints"
91
+ DATASET_CHECKPOINT_PREFIX = "dataset_cache"
92
+ MODEL_CACHE_PREFIX = "model_cache"
93
+
94
+ # =============================================================================
95
+ # CORE EVALUATION CLASSES
96
+ # =============================================================================
97
+
98
+
99
class CodeSearchNetEvaluator:
    """Evaluator for CodeSearchNet-style code search tasks.

    Wraps a SentenceTransformer model and measures query->code retrieval
    quality (MRR, NDCG@k, Recall@k) per language, with optional Beam-backed
    checkpointing and result caching.
    """

    def __init__(
        self,
        model_path: str,
        model_name: str | None = None,
        checkpoint_manager: BeamCheckpointManager | None = None,
        eval_manager: BeamEvaluationManager | None = None,
    ) -> None:
        """Initialize the evaluator with a model and optional Beam utilities."""
        self.model_path = model_path
        self.model_name = model_name or Path(model_path).name
        self.model: SentenceTransformer | None = None
        self.checkpoint_manager = checkpoint_manager
        self.eval_manager = eval_manager
        self._load_model()

    def _load_model(self) -> None:
        """Load the embedding model with caching support."""
        logger.info(f"Loading model from {self.model_path}")

        # Check if we have a cached evaluation result for this model
        if self.eval_manager:
            cached_result = self.eval_manager.load_evaluation_results(self.model_name)
            if cached_result:
                logger.info(f"βœ… Found cached evaluation results for {self.model_name}")
                # Note: We still need to load the model for new evaluations

        try:
            self.model = SentenceTransformer(self.model_path, trust_remote_code=True)
            logger.info(f"Successfully loaded model: {self.model_name}")
        except Exception:
            logger.exception(f"Failed to load model from {self.model_path}")
            raise

    def encode_texts(self, texts: list[str], desc: str = "Encoding") -> np.ndarray:
        """Encode texts into normalized embeddings, batched by BATCH_SIZE."""
        if self.model is None:
            msg = "Model not loaded"
            raise RuntimeError(msg)

        embeddings = []

        for i in tqdm(range(0, len(texts), BATCH_SIZE), desc=desc):
            batch = texts[i : i + BATCH_SIZE]
            batch_embeddings = self.model.encode(batch, convert_to_tensor=False, normalize_embeddings=True)
            embeddings.append(batch_embeddings)

        return np.vstack(embeddings)

    def evaluate_language(self, language: str, max_queries: int = 1000) -> dict[str, Any]:
        """Evaluate on a specific programming language with checkpoint support.

        Returns an empty dict when the dataset is unusable or no valid
        query-code pairs are found (logged, not raised).
        """
        logger.info(f"Evaluating on {language} language (max {max_queries} queries)")

        # Check for existing evaluation checkpoint
        if self.checkpoint_manager:
            cached_result = self.checkpoint_manager.load_checkpoint(f"{EVAL_CHECKPOINT_PREFIX}_{language}", 0)
            if cached_result and cached_result.get("data", {}).get("model_name") == self.model_name:
                logger.info(f"βœ… Resuming from cached {language} evaluation")
                return cached_result.get("data", {})

        try:
            # Load test split for the language
            dataset = load_dataset(
                CODESEARCHNET_EVAL_DATASET,
                language,
                split="test",
                trust_remote_code=True,
            )

            # Ensure we have a Dataset object
            if not isinstance(dataset, Dataset):
                logger.error(f"Unexpected dataset type for {language}: {type(dataset)}")
                return {}

            # Sample queries for evaluation (to make it manageable)
            if len(dataset) > max_queries:
                rng = np.random.default_rng(42)  # Use seeded generator for reproducibility
                indices = rng.choice(len(dataset), max_queries, replace=False)
                dataset = dataset.select(indices)

            queries = []
            codes = []
            query_ids = []

            for i, example in enumerate(dataset):
                doc_string = example.get("func_documentation_string", "").strip()
                code_string = example.get("func_code_string", "").strip()

                # Require a minimally descriptive docstring (>= 3 words).
                if doc_string and code_string and len(doc_string.split()) >= 3:
                    queries.append(doc_string)
                    codes.append(code_string)
                    query_ids.append(f"{language}_{i}")

            if len(queries) == 0:
                logger.warning(f"No valid query-code pairs found for {language}")
                return {}

            logger.info(f"Found {len(queries)} valid query-code pairs for {language}")

            # Encode queries and codes
            query_embeddings = self.encode_texts(queries, f"Encoding {language} queries")
            code_embeddings = self.encode_texts(codes, f"Encoding {language} codes")

            # Compute similarities
            similarities = cosine_similarity(query_embeddings, code_embeddings)

            # Evaluate retrieval metrics
            metrics = self._compute_retrieval_metrics(similarities)

            result = {
                "language": language,
                "num_queries": len(queries),
                "metrics": metrics,
                "model_name": self.model_name,
            }

            # Save checkpoint
            if self.checkpoint_manager:
                checkpoint_data = {
                    "data": result,
                    "timestamp": time.time(),
                    "config": {
                        "language": language,
                        "max_queries": max_queries,
                        "model_name": self.model_name,
                    },
                }
                self.checkpoint_manager.save_checkpoint(f"{EVAL_CHECKPOINT_PREFIX}_{language}", checkpoint_data, 0)
                logger.info(f"πŸ’Ύ Saved {language} evaluation checkpoint")

            return result

        except Exception:
            logger.exception(f"Error evaluating {language}")
            return {}

    def _compute_retrieval_metrics(self, similarities: np.ndarray) -> dict[str, float]:
        """Compute retrieval metrics (MRR, NDCG@k, Recall@k, rank stats).

        For query i the correct code sits at column i, so the gold item is
        on the diagonal of the similarity matrix.

        Fix: the candidate ranking is now computed once per query and reused
        for every NDCG cutoff; the previous version re-sorted each query's
        scores two extra times (for NDCG@1 and NDCG@5), tripling the
        O(Q * N log N) sorting work while producing identical values.
        """
        num_queries = similarities.shape[0]

        ranks = []
        reciprocal_ranks = []
        ndcg_at_1 = []
        ndcg_at_5 = []
        ndcg_at_10 = []

        for i in range(num_queries):
            # Get similarity scores for query i
            scores = similarities[i]

            # Rank all codes by similarity to query i (descending order)
            ranked_indices = np.argsort(scores)[::-1]

            # Find rank of the correct code (index i)
            correct_rank = np.where(ranked_indices == i)[0][0] + 1  # 1-indexed
            ranks.append(correct_rank)
            reciprocal_ranks.append(1.0 / correct_rank)

            # Reuse the single ranking for all NDCG cutoffs.
            ndcg_at_1.append(self._compute_ndcg(ranked_indices, i, k=1))
            ndcg_at_5.append(self._compute_ndcg(ranked_indices, i, k=5))
            ndcg_at_10.append(self._compute_ndcg(ranked_indices, i, k=10))

        return {
            "mrr": float(np.mean(reciprocal_ranks)),
            "ndcg@1": float(np.mean(ndcg_at_1)),
            "ndcg@5": float(np.mean(ndcg_at_5)),
            "ndcg@10": float(np.mean(ndcg_at_10)),
            "recall@1": float(np.mean([1.0 if rank == 1 else 0.0 for rank in ranks])),
            "recall@5": float(np.mean([1.0 if rank <= 5 else 0.0 for rank in ranks])),
            "recall@10": float(np.mean([1.0 if rank <= 10 else 0.0 for rank in ranks])),
            "mean_rank": float(np.mean(ranks)),
            "median_rank": float(np.median(ranks)),
        }

    def _compute_ndcg(self, ranked_indices: np.ndarray, correct_idx: int, k: int) -> float:
        """Compute NDCG@k for a single query (one relevant item -> 1/log2(pos+2))."""
        if k == 0:
            return 0.0

        # Find position of correct item in top-k
        top_k = ranked_indices[:k]
        if correct_idx in top_k:
            position = np.where(top_k == correct_idx)[0][0]
            return 1.0 / np.log2(position + 2)  # +2 because log2(1) is 0
        return 0.0

    def evaluate_all_languages(
        self, max_queries_per_lang: int = 1000, languages: list[str] | None = None
    ) -> dict[str, Any]:
        """Evaluate on all supported programming languages with comprehensive result saving."""
        if languages is None:
            languages = EVALUATION_LANGUAGES

        logger.info(f"Starting evaluation on all languages for model: {self.model_name}")

        # Check for existing comprehensive evaluation results
        if self.eval_manager:
            cached_comprehensive = self.eval_manager.load_evaluation_results(self.model_name)
            if cached_comprehensive:
                logger.info(f"βœ… Found comprehensive cached evaluation for {self.model_name}")
                return cached_comprehensive

        start_time = time.time()

        results: dict[str, Any] = {
            "model_name": self.model_name,
            "model_path": self.model_path,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "languages": {},
            "overall": {},
        }

        all_metrics = []

        for language in languages:
            logger.info(f"Evaluating {language}...")
            lang_results = self.evaluate_language(language, max_queries_per_lang)

            if lang_results:
                results["languages"][language] = lang_results
                all_metrics.append(lang_results["metrics"])
            else:
                logger.warning(f"Skipping {language} due to evaluation error")

        # Compute overall metrics (average across languages)
        if all_metrics:
            overall_metrics = {}
            for metric_name in all_metrics[0]:
                values = [m[metric_name] for m in all_metrics if metric_name in m]
                overall_metrics[metric_name] = np.mean(values)

            results["overall"] = overall_metrics

        total_time = time.time() - start_time
        results["evaluation_time_seconds"] = total_time

        # Save comprehensive results to Beam volume
        if self.eval_manager:
            self.eval_manager.save_evaluation_results(self.model_name, results)
            logger.info("πŸ’Ύ Saved comprehensive evaluation results to Beam volume")

        logger.info(f"Evaluation completed in {total_time:.2f} seconds")
        return results
346
+
347
+
348
def load_peer_models(peers_file: str) -> list[tuple[str, str]]:
    """Load peer models from a CSV file.

    Accepts either ``model_name``/``model_path`` or legacy ``Model``/``Path``
    column names; ``model_path`` falls back to the model name when absent.

    Returns:
        A list of ``(model_name, model_path)`` tuples; an unreadable or
        malformed file yields an empty list (logged, not raised).
    """
    try:
        df = pd.read_csv(peers_file)
        models = []
        for _, row in df.iterrows():
            model_name = row.get("model_name", row.get("Model", ""))
            model_path = row.get("model_path", row.get("Path", model_name))
            if model_name:
                models.append((model_name, model_path))
        logger.info(f"Loaded {len(models)} peer models from {peers_file}")
        return models
    except Exception:
        # Bug fix: the message was a plain string containing a literal
        # "{peers_file}" (missing f-prefix); use lazy %-style args instead.
        logger.exception("Error loading peer models from %s", peers_file)
        return []
363
+
364
+
365
def save_results(
    results: dict[str, Any],
    output_dir: str,
    model_name: str,
    eval_manager: BeamEvaluationManager | None = None,
    volume_results_dir: Path | None = None,
) -> None:
    """Save evaluation results to JSON file with Beam volume support.

    Writes up to three copies: the Beam volume (best-effort), the
    eval_manager store (best-effort), and always a local JSON backup.
    """
    # 1) Beam volume copy, if a volume directory was supplied.
    if volume_results_dir:
        volume_file = volume_results_dir / f"codesearchnet_eval_{model_name}.json"
        try:
            with volume_file.open("w") as f:
                json.dump(results, f, indent=2, default=str)
            logger.info(f"πŸ’Ύ Results saved to Beam volume: {volume_file}")
        except Exception as e:
            logger.warning(f"⚠️ Failed to save to Beam volume: {e}")

    # 2) eval_manager copy (kept for compatibility with older tooling).
    if eval_manager:
        if eval_manager.save_evaluation_results(model_name, results):
            logger.info(f"πŸ’Ύ Results also saved via eval_manager for {model_name}")
        else:
            logger.warning(f"⚠️ Failed to save via eval_manager for {model_name}")

    # 3) Local backup, always written.
    backup_dir = Path(output_dir)
    backup_dir.mkdir(parents=True, exist_ok=True)

    # Strip characters that are unsafe in filenames.
    sanitized = "".join(c for c in model_name if c.isalnum() or c in ("-", "_", "."))
    backup_file = backup_dir / f"codesearchnet_eval_{sanitized}.json"

    with backup_file.open("w") as f:
        json.dump(results, f, indent=2, default=str)

    logger.info(f"πŸ“„ Local backup saved to {backup_file}")
404
+
405
+
406
def print_results_summary(results: dict[str, Any]) -> None:
    """Print a human-readable summary of one model's evaluation results.

    Expects ``results`` to carry a "model_name" key plus optional
    "overall" metrics and per-language "languages" sections.
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"CodeSearchNet Evaluation Results: {results['model_name']}")
    print(f"{banner}")

    overall = results.get("overall", {})
    if overall:
        print("\nOverall Metrics (averaged across languages):")
        # Label/key pairs printed in a fixed, stable order.
        overall_rows = [
            ("MRR", "mrr"),
            ("NDCG@1", "ndcg@1"),
            ("NDCG@5", "ndcg@5"),
            ("NDCG@10", "ndcg@10"),
            ("Recall@1", "recall@1"),
            ("Recall@5", "recall@5"),
            ("Recall@10", "recall@10"),
        ]
        for label, key in overall_rows:
            print(f" {label}: {overall.get(key, 0):.4f}")

    print("\nPer-Language Results:")
    for lang, lang_results in results.get("languages", {}).items():
        metrics = lang_results.get("metrics", {})
        parts = ", ".join(
            [
                f"MRR={metrics.get('mrr', 0):.3f}",
                f"NDCG@10={metrics.get('ndcg@10', 0):.3f}",
                f"Recall@5={metrics.get('recall@5', 0):.3f}",
            ]
        )
        print(f" {lang:12s}: {parts}")
433
+
434
+
435
def create_comparison_report(all_results: list[dict[str, Any]], output_dir: str) -> None:
    """Create a comparison report (CSV file + printed table) across all evaluated models.

    Args:
        all_results: One results dict per model, each with an "overall" metrics section.
        output_dir: Directory where ``codesearchnet_comparison.csv`` is written.
    """
    if not all_results:
        return

    output_path = Path(output_dir)
    # Robustness fix: ensure the target directory exists before writing the
    # CSV, so this function does not depend on callers pre-creating it.
    output_path.mkdir(parents=True, exist_ok=True)

    # Flatten each model's overall metrics into one comparison row.
    comparison_data = []
    for results in all_results:
        overall = results.get("overall", {})
        row = {
            "Model": results["model_name"],
            "MRR": overall.get("mrr", 0),
            "NDCG@1": overall.get("ndcg@1", 0),
            "NDCG@5": overall.get("ndcg@5", 0),
            "NDCG@10": overall.get("ndcg@10", 0),
            "Recall@1": overall.get("recall@1", 0),
            "Recall@5": overall.get("recall@5", 0),
            "Recall@10": overall.get("recall@10", 0),
            "Mean Rank": overall.get("mean_rank", 0),
        }
        comparison_data.append(row)

    df = pd.DataFrame(comparison_data)
    df = df.sort_values("NDCG@10", ascending=False)  # Best NDCG@10 first

    # Save to CSV
    csv_path = output_path / "codesearchnet_comparison.csv"
    df.to_csv(csv_path, index=False, float_format="%.4f")
    logger.info(f"Comparison report saved to {csv_path}")

    # Print comparison table
    print(f"\n{'=' * 80}")
    print("CodeSearchNet Model Comparison")
    print(f"{'=' * 80}")
    print(df.to_string(index=False, float_format="%.4f"))
472
+
473
+
474
def beam_evaluate_models(
    models: list[str],
    max_queries: int = 1000,
    languages: list[str] | None = None,
    output_dir: str = DEFAULT_OUTPUT_DIR,
    volume_name: str = VOLUME_NAME,
    mount_path: str = VOLUME_PATH,
) -> list[dict[str, Any]]:
    """Main evaluation function for Beam execution with checkpoint support.

    Evaluates each entry of ``models`` on CodeSearchNet, skipping any model
    whose results JSON already exists on the Beam volume, and saves per-model
    results plus (for 2+ models) a cross-model comparison report.

    Args:
        models: Local paths or HuggingFace-style model names to evaluate.
        max_queries: Maximum queries per language, forwarded to the evaluator.
        languages: Languages to evaluate; None falls back to EVALUATION_LANGUAGES.
        output_dir: Directory for local backup copies of results.
        volume_name: Beam volume name (used for utilities setup and logging).
        mount_path: Mount point of the Beam volume.

    Returns:
        One results dict per processed model (previously saved results are
        loaded and included for skipped models).
    """
    logger.info("πŸš€ Starting Beam-powered CodeSearchNet evaluation")
    logger.info(f"πŸ“Š Evaluating {len(models)} models on {len(languages or EVALUATION_LANGUAGES)} languages")

    # Initialize Beam utilities
    # NOTE(review): volume_mgr is never used below — confirm create_beam_utilities
    # has no required side effects tied to it before removing.
    volume_mgr, checkpoint_mgr, model_mgr, eval_mgr = create_beam_utilities(volume_name, mount_path)

    # Create evaluation results directory in volume
    results_dir = Path(mount_path) / EVALUATION_RESULTS_DIR
    results_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"πŸ“ Using Beam volume: {volume_name} at {mount_path}")
    logger.info(f"πŸ’Ύ Evaluation results directory: {results_dir}")

    all_results = []
    skipped_models = []

    for model_path in models:
        # The last path component doubles as the model's identifier.
        model_name = Path(model_path).name

        # Checkpointing: reuse existing evaluation results when present.
        existing_result_file = results_dir / f"codesearchnet_eval_{model_name}.json"
        if existing_result_file.exists():
            logger.info(f"βœ… Model {model_name} already evaluated - loading existing results")
            try:
                with existing_result_file.open("r") as f:
                    existing_results = json.load(f)
                all_results.append(existing_results)
                skipped_models.append(model_name)
                continue
            except Exception as e:
                logger.warning(f"⚠️ Failed to load existing results for {model_name}: {e}")
                # Continue with evaluation if loading fails

        logger.info(f"\n{'=' * 60}")
        logger.info(f"πŸ” Evaluating model: {model_name}")
        logger.info(f"πŸ“‚ Path: {model_path}")
        logger.info(f"{'=' * 60}")

        try:
            # Heuristic: an "org/name"-style string with no matching local
            # path is treated as a HuggingFace hub id.
            is_huggingface_model = (
                "/" in model_path and not model_path.startswith("/") and not Path(model_path).exists()
            )

            if is_huggingface_model:
                # This is a HuggingFace model name - pass directly to evaluator
                logger.info(f"πŸ“₯ Loading HuggingFace model: {model_path}")
                evaluator = CodeSearchNetEvaluator(
                    model_path,
                    model_name,
                    checkpoint_manager=checkpoint_mgr,
                    eval_manager=eval_mgr,
                )
            else:
                # This is a local path - check if it exists in Beam volume
                actual_model_path = model_path  # Default to original path
                if not Path(model_path).exists() and not model_path.startswith("/"):
                    # Try to load from Beam volume
                    local_model_path = Path(mount_path) / MODEL_CACHE_PREFIX / model_name
                    logger.info(f"πŸ” Trying to load {model_name} from Beam volume: {local_model_path}")
                    if model_mgr.load_model(model_name, local_model_path.parent):
                        actual_model_path = str(local_model_path)
                        logger.info(f"βœ… Loaded model from Beam volume: {actual_model_path}")
                    else:
                        # Nothing to evaluate for this entry; move on.
                        logger.warning(f"⚠️ Model not found locally or in Beam volume: {model_name}")
                        continue

                evaluator = CodeSearchNetEvaluator(
                    actual_model_path,
                    model_name,
                    checkpoint_manager=checkpoint_mgr,
                    eval_manager=eval_mgr,
                )

            results = evaluator.evaluate_all_languages(max_queries, languages)

            # Save results with Beam support
            save_results(results, output_dir, model_name, eval_mgr, results_dir)

            # Print summary
            print_results_summary(results)

            all_results.append(results)

        except Exception:
            # One failing model must not abort the whole batch.
            logger.exception(f"❌ Failed to evaluate {model_name}")
            continue

    # Create comparison report in Beam volume (only meaningful for 2+ models)
    if len(all_results) > 1:
        comparison_dir = Path(mount_path) / EVALUATION_RESULTS_DIR / "comparisons"
        comparison_dir.mkdir(parents=True, exist_ok=True)
        create_comparison_report(all_results, str(comparison_dir))
        logger.info(f"πŸ“Š Comparison report saved to Beam volume: {comparison_dir}")

    # Log summary of what was done
    newly_evaluated = len(all_results) - len(skipped_models)
    logger.info("\nβœ… Beam evaluation complete!")
    logger.info(f"πŸ“Š Newly evaluated: {newly_evaluated} models")
    logger.info(f"⏭️ Skipped (already done): {len(skipped_models)} models")
    logger.info(f"πŸ“ Total results: {len(all_results)} models")
    logger.info(f"πŸ’Ύ Results available in Beam volume: {volume_name}")

    if skipped_models:
        logger.info(f"⏭️ Skipped models: {', '.join(skipped_models)}")

    return all_results
590
+
591
+
592
@function(
    gpu=GPU_NAME,
    volumes=[Volume(name=VOLUME_NAME, mount_path=VOLUME_PATH)],
    image=IMAGE,
    secrets=["HF_ACCESS_TOKEN"],
    env={
        "TOKENIZERS_PARALLELISM": "false",
        "CUDA_LAUNCH_BLOCKING": "0",
    },
    timeout=3600 * 6,  # 6 hours for evaluation
)
def main(skip_third_party: bool = False) -> None:
    """Main evaluation function - runs all default models on Beam.

    Builds the model list (default 3rd-party peers unless ``skip_third_party``
    is True, plus any discovered simplified models) and delegates to
    beam_evaluate_models with the standard volume configuration.

    Args:
        skip_third_party: When True, evaluate only simplified distillation
            models and omit the 3rd-party peer models.
    """
    logger.info("πŸš€ Starting comprehensive CodeSearchNet evaluation on Beam")

    # Use default models or skip them based on flag
    if skip_third_party:
        logger.info("⏭️ Skipping 3rd party models - evaluating only simplified distillation models")
        models = []
    else:
        logger.info("πŸ“Š Including 3rd party peer models for comparison")
        models = DEFAULT_EVALUATION_MODELS.copy()

    # Discover simplified distillation models in the current directory
    logger.info("πŸ” Discovering simplified distillation models...")
    discovered_models = discover_simplified_models(".")

    # Add discovered models (they're already sorted alphabetically)
    if discovered_models:
        logger.info(f"βœ… Found {len(discovered_models)} simplified models:")
        for model_path in discovered_models:
            models.append(model_path)
            logger.info(f" πŸ“ {model_path}")
    else:
        logger.warning("⚠️ No simplified distillation models found")
        if skip_third_party:
            # With 3rd-party models skipped there is nothing left to run.
            logger.error("❌ No models to evaluate! Either create simplified models or include 3rd party models.")
            return

    logger.info(f"πŸ“Š Evaluating {len(models)} models:")
    for i, model in enumerate(models, 1):
        logger.info(f" {i}. {model}")

    logger.info("\nπŸ’‘ Checkpoint Info:")
    logger.info(" - Already evaluated models will be skipped")
    logger.info(" - Results are saved persistently to Beam volume")

    # Run comprehensive evaluation using Beam utilities
    results = beam_evaluate_models(
        models=models,
        max_queries=1000,
        languages=EVALUATION_LANGUAGES,
        output_dir=str(Path(VOLUME_PATH) / EVALUATION_RESULTS_DIR),
        volume_name=VOLUME_NAME,
        mount_path=VOLUME_PATH,
    )

    # Print final summary
    print("\n🎯 Evaluation Summary:")
    print(f"πŸ“Š Total models processed: {len(results)}")
    print(f"πŸ’Ύ Results saved to Beam volume: {VOLUME_NAME}")
    print(f"πŸ“ Directory: {EVALUATION_RESULTS_DIR}")
    if skip_third_party:
        print("⏭️ 3rd party models were skipped")
    print("\nπŸ” To view analysis:")
    print(" beam run src.distiller.analyze:beam_analysis")
    print("\nπŸ“ˆ To run evaluations again:")
    print(" distiller evaluate (will skip already completed models)")
    print(" distiller evaluate --skip-third-party (evaluate only simplified models)")
661
+
662
+
663
def discover_simplified_models(base_path: str = ".") -> list[str]:
    """
    Find simplified distillation model directories under ``base_path``.

    Scans ``<base_path>/code_model2vec/final`` for directories named
    ``code_model2vec_*`` that contain a ``config.json``.

    Returns:
        Alphabetically sorted list of model directory paths (as strings).
    """
    final_dir = Path(base_path) / "code_model2vec" / "final"

    if not final_dir.exists():
        logger.warning(f"Models directory not found: {final_dir}")
        return []

    found: list[str] = []
    # A directory only counts as a model when its config.json is present.
    for candidate in final_dir.glob("code_model2vec_*"):
        if candidate.is_dir() and (candidate / "config.json").exists():
            found.append(str(candidate))
            logger.info(f"πŸ” Discovered simplified model: {candidate}")

    # Sorted for deterministic ordering across filesystems.
    return sorted(found)
689
+
690
+
691
@function(
    gpu=GPU_NAME,
    volumes=[Volume(name=VOLUME_NAME, mount_path=VOLUME_PATH)],
    image=IMAGE,
    secrets=["HF_ACCESS_TOKEN"],
    env={
        "TOKENIZERS_PARALLELISM": "false",
        "CUDA_LAUNCH_BLOCKING": "0",
    },
    timeout=3600 * 6,  # 6 hours for evaluation
)
def evaluate_simplified_only() -> None:
    """Evaluate only simplified distillation models, skipping 3rd party models.

    Thin Beam entry point: delegates to main() with skip_third_party=True so
    both remote functions share a single code path.
    """
    main(skip_third_party=True)
705
+
706
+
707
def run_local_evaluation(
    models: list[str] | None = None,
    max_queries: int = 1000,
    languages: list[str] | None = None,
    output_dir: str = DEFAULT_OUTPUT_DIR,
) -> list[dict[str, Any]]:
    """Main evaluation function for local execution without Beam utilities.

    Mirrors beam_evaluate_models but stores everything under ``output_dir``
    and creates evaluators without checkpoint/eval managers.

    Args:
        models: Models to evaluate; None defaults to DEFAULT_EVALUATION_MODELS.
        max_queries: Maximum queries per language, forwarded to the evaluator.
        languages: Languages to evaluate; None falls back to EVALUATION_LANGUAGES.
        output_dir: Local directory for results and the comparison report.

    Returns:
        One results dict per processed model (existing results are loaded and
        included for skipped models).
    """
    logger.info("πŸ–₯️ Running CodeSearchNet evaluation locally")

    if models is None:
        models = DEFAULT_EVALUATION_MODELS.copy()

    # Discover simplified distillation models in the current directory
    logger.info("πŸ” Discovering simplified distillation models...")
    discovered_models = discover_simplified_models(".")

    # Add discovered models
    # NOTE(review): discovered models are appended even when already present in
    # ``models`` (e.g. when called via run_local_evaluation_simplified), so an
    # entry can appear twice; the second pass then reloads the just-saved
    # results. Also note a caller-provided list is mutated in place — confirm
    # both behaviors are intended.
    if discovered_models:
        logger.info(f"βœ… Found {len(discovered_models)} simplified models:")
        for model_path in discovered_models:
            models.append(model_path)
            logger.info(f" πŸ“ {model_path}")
    else:
        logger.warning("⚠️ No simplified distillation models found")

    if languages is None:
        languages = EVALUATION_LANGUAGES

    logger.info(f"πŸ“Š Evaluating {len(models)} models on {len(languages)} languages")
    logger.info(f"πŸ“ Using local output directory: {output_dir}")

    # Create local output directory
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    all_results = []
    skipped_models = []

    for model_path in models:
        model_name = Path(model_path).name

        # Check for existing evaluation results locally (same sanitization as
        # save_results, so the lookup matches the saved filename).
        safe_name = "".join(c for c in model_name if c.isalnum() or c in ("-", "_", "."))
        result_file = output_path / f"codesearchnet_eval_{safe_name}.json"

        if result_file.exists():
            logger.info(f"βœ… Model {model_name} already evaluated - loading existing results")
            try:
                with result_file.open("r") as f:
                    existing_results = json.load(f)
                all_results.append(existing_results)
                skipped_models.append(model_name)
                continue
            except Exception as e:
                # Fall through to a fresh evaluation if the file is unreadable.
                logger.warning(f"⚠️ Failed to load existing results for {model_name}: {e}")

        logger.info(f"\n{'=' * 60}")
        logger.info(f"πŸ” Evaluating model: {model_name}")
        logger.info(f"πŸ“‚ Path: {model_path}")
        logger.info(f"{'=' * 60}")

        try:
            # Create evaluator without Beam utilities (no checkpointing)
            evaluator = CodeSearchNetEvaluator(
                model_path,
                model_name,
                checkpoint_manager=None,  # No checkpointing for local evaluation
                eval_manager=None,
            )

            results = evaluator.evaluate_all_languages(max_queries, languages)

            # Save results locally only
            save_results(results, output_dir, model_name, eval_manager=None, volume_results_dir=None)

            # Print summary
            print_results_summary(results)

            all_results.append(results)

        except Exception:
            # One failing model must not abort the whole batch.
            logger.exception(f"❌ Failed to evaluate {model_name}")
            continue

    # Create comparison report locally (only meaningful for 2+ models)
    if len(all_results) > 1:
        create_comparison_report(all_results, output_dir)
        logger.info(f"πŸ“Š Comparison report saved locally: {output_dir}")

    # Log summary
    newly_evaluated = len(all_results) - len(skipped_models)
    logger.info("\nβœ… Local evaluation complete!")
    logger.info(f"πŸ“Š Newly evaluated: {newly_evaluated} models")
    logger.info(f"⏭️ Skipped (already done): {len(skipped_models)} models")
    logger.info(f"πŸ“ Total results: {len(all_results)} models")
    logger.info(f"πŸ’Ύ Results available locally: {output_dir}")

    if skipped_models:
        logger.info(f"⏭️ Skipped models: {', '.join(skipped_models)}")

    return all_results
808
+
809
+
810
def run_local_evaluation_simplified(
    max_queries: int = 1000,
    languages: list[str] | None = None,
    output_dir: str = DEFAULT_OUTPUT_DIR,
) -> list[dict[str, Any]]:
    """Run the local evaluation pipeline restricted to simplified models.

    Discovers simplified distillation models on disk and delegates to
    run_local_evaluation(); returns an empty list when none are found.
    """
    logger.info("πŸ–₯️ Running simplified model evaluation locally")

    # Only simplified distillation models are considered here.
    logger.info("πŸ” Discovering simplified distillation models...")
    candidates = discover_simplified_models(".")

    if not candidates:
        logger.error("❌ No simplified distillation models found! Run 'distiller distill-simple' first.")
        return []

    logger.info(f"βœ… Found {len(candidates)} simplified models:")
    for candidate in candidates:
        logger.info(f" πŸ“ {candidate}")

    return run_local_evaluation(
        models=candidates,
        max_queries=max_queries,
        languages=languages,
        output_dir=output_dir,
    )
836
+
837
+
838
# Allow running this module directly (outside the distiller CLI).
# NOTE(review): main is wrapped by @function — confirm whether direct
# invocation executes locally or dispatches to Beam.
if __name__ == "__main__":
    main()
src/distiller/patch_utils.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Patch utilities for applying fixes to installed packages.
3
+
4
+ This module provides functionality to automatically apply all patches
5
+ from the patches directory to fix bugs in third-party libraries.
6
+ """
7
+
8
+ import logging
9
+ import subprocess
10
+ import sys
11
+ from pathlib import Path
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
def find_patches_directory() -> Path:
    """Locate the patches directory, preferring the project-root location."""
    # Three levels up from this module: distiller -> src -> project root.
    project_root = Path(__file__).parent.parent.parent
    candidate = project_root / "patches"
    if candidate.exists():
        return candidate
    # Fall back to a patches/ directory relative to the working directory
    # (returned even if it does not exist; callers check existence).
    return Path("patches")
28
+
29
+
30
def get_site_packages_path() -> Path:
    """Return the active environment's site-packages directory."""
    import site

    # Primary source: the interpreter's registered site-packages list.
    registered = site.getsitepackages()
    if registered:
        return Path(registered[0])

    # Fallback: derive <venv>/lib/pythonX.Y/site-packages from the executable
    # (standard virtual-environment layout).
    exe = Path(sys.executable)
    if exe.name.startswith("python"):
        for entry in (exe.parent.parent / "lib").iterdir():
            if entry.name.startswith("python"):
                candidate = entry / "site-packages"
                if candidate.exists():
                    return candidate

    # Last resort: the current directory.
    return Path()
54
+
55
+
56
def apply_patch_file(patch_file: Path, target_dir: Path) -> bool:
    """
    Apply one .patch file inside ``target_dir`` using the system patch tool.

    Args:
        patch_file: Path to the .patch file.
        target_dir: Directory the patch paths are relative to (usually site-packages).

    Returns:
        True when the patch applied cleanly or was already applied, False otherwise.
    """
    try:
        logger.info(f"Applying patch: {patch_file.name}")

        # patch flags: -p1 strips one path component, -d runs inside the
        # target directory, -f never prompts, -N skips already-applied hunks.
        completed = subprocess.run(  # noqa: S603
            ["patch", "-p1", "-d", str(target_dir), "-f", "-N"],  # noqa: S607
            input=patch_file.read_text(),
            text=True,
            capture_output=True,
            check=False,  # Non-zero exit is handled below, not raised.
        )

        if completed.returncode == 0:
            logger.info(f"Successfully applied patch: {patch_file.name}")
            return True

        stderr_lower = completed.stderr.lower()
        if "already applied" in stderr_lower or "reversed" in stderr_lower:
            # `patch -N` reports previously-applied patches on stderr.
            logger.info(f"Patch {patch_file.name} already applied")
            return True

        logger.warning(f"Failed to apply patch {patch_file.name}: {completed.stderr}")
        return False

    except FileNotFoundError:
        logger.exception("'patch' command not found. Please install patch utility.")
        return False
    except Exception:
        logger.exception(f"Error applying patch {patch_file.name}")
        return False
98
+
99
+
100
def apply_all_patches() -> int:
    """
    Apply every .patch file found in the patches directory.

    Returns:
        Number of patches successfully applied.
    """
    patches_dir = find_patches_directory()
    if not patches_dir.exists():
        logger.warning(f"Patches directory not found: {patches_dir}")
        return 0

    patch_files = list(patches_dir.glob("*.patch"))
    if not patch_files:
        logger.info("No patch files found")
        return 0

    target_dir = get_site_packages_path()
    logger.info(f"Applying patches to: {target_dir}")

    # Apply in sorted order so repeated runs behave deterministically.
    applied = sum(1 for patch in sorted(patch_files) if apply_patch_file(patch, target_dir))

    logger.info(f"Applied {applied}/{len(patch_files)} patches successfully")
    return applied
133
+
134
+
135
def main() -> None:
    """Entry point for standalone execution: apply every available patch."""
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")

    print("Applying all patches...")
    applied = apply_all_patches()
    print(f"Done. Applied {applied} patches.")
142
+
143
+
144
# Allow applying patches standalone: `python -m ...patch_utils`.
if __name__ == "__main__":
    main()
src/distiller/sync.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sync utility for downloading files from Beam volume to local directory.
3
+
4
+ This module provides functionality to download generated files from the Beam volume
5
+ back to the local filesystem, including:
6
+ - Final distilled model files (model.safetensors, tokenizer.json, etc.)
7
+ - Analysis reports and charts (README.md, comparison charts, etc.)
8
+ """
9
+
10
+ import logging
11
+ import shutil
12
+ from pathlib import Path
13
+
14
+ from .beam_utils import create_beam_utilities
15
+
16
+ # Configure logging
17
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
18
+ logger = logging.getLogger(__name__)
19
+
20
# Beam volume configuration (must match distill.py)
VOLUME_NAME = "gte_qwen2_m2v_code"
VOLUME_PATH = "./gte_qwen2_m2v_code"

# Model files to sync (checked at several candidate locations on the volume)
MODEL_FILES = [
    "model.safetensors",
    "tokenizer.json",
    "modules.json",
    "config.json",
    "pytorch_model.bin",  # Backup format
    "vocab.txt",  # If present
]

# Analysis directories and files
# NOTE(review): ANALYSIS_DIRS and ANALYSIS_FILES are not referenced by the
# sync functions in this module, which hard-code the same paths — confirm
# whether these constants are used elsewhere or are dead configuration.
ANALYSIS_DIRS = [
    "analysis_results/reports",
    "analysis_results/charts",
    "evaluation_results",
]

ANALYSIS_FILES = [
    "analysis_results/reports/analysis_report.md",
    "analysis_results/reports/README.md",
    "analysis_results/charts/*.png",
    "analysis_results/charts/*.html",
    "evaluation_results/*.json",
    "evaluation_results/comparisons/*.csv",
]
49
+
50
+
51
def sync_model_files(output_dir: str) -> bool:
    """Download final model files from the Beam volume into ``output_dir``.

    First logs a diagnostic listing of the volume contents, then copies any
    file named in MODEL_FILES from the known candidate locations.

    Args:
        output_dir: Local destination directory (created if missing).

    Returns:
        True if at least one model file was copied, False otherwise.
    """
    logger.info("πŸ”„ Syncing model files from Beam volume...")

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Diagnostic pass: log what's actually in the volume so failed syncs are
    # easy to debug from the output alone.
    volume_root = Path(VOLUME_PATH)
    logger.info(f"πŸ” Debugging volume contents at: {volume_root}")

    if volume_root.exists():
        logger.info("πŸ“ Volume root directory contents:")
        for item in volume_root.iterdir():
            if item.is_file():
                logger.info(f" πŸ“„ {item.name} ({item.stat().st_size} bytes)")
            elif item.is_dir():
                logger.info(f" πŸ“ {item.name}/ (directory)")
                # List files in important subdirectories
                if item.name in ["models", "checkpoints", "gte_qwen2_m2v_code"]:
                    try:
                        logger.info(f" Contents of {item.name}/:")
                        for subitem in item.iterdir():
                            if subitem.is_file():
                                logger.info(f" πŸ“„ {subitem.name} ({subitem.stat().st_size} bytes)")
                            else:
                                logger.info(f" πŸ“ {subitem.name}/")
                                # Check one level deeper for model files
                                if subitem.is_dir():
                                    for subsubitem in subitem.iterdir():
                                        if subsubitem.is_file() and subsubitem.name in MODEL_FILES:
                                            logger.info(f" 🎯 FOUND MODEL FILE: {subsubitem}")
                    except Exception as e:
                        # Best-effort diagnostics: never fail the sync here.
                        logger.warning(f" Error exploring {item.name}: {e}")

        # Also check for model files directly in root
        logger.info("πŸ” Checking for model files directly in volume root:")
        for model_file in MODEL_FILES:
            root_file = volume_root / model_file
            if root_file.exists():
                logger.info(f" 🎯 FOUND: {model_file} in root ({root_file.stat().st_size} bytes)")
    else:
        logger.error(f"❌ Volume root does not exist: {volume_root}")
        return False

    # Since training completed successfully, look for model files in all possible locations
    model_locations = [
        Path(VOLUME_PATH),  # Root of volume (where final model was saved)
        Path(VOLUME_PATH) / "models" / "refined_model",  # Refined model directory
    ]

    synced_files = []

    for location in model_locations:
        logger.info(f"πŸ“‚ Checking model location: {location}")

        if not location.exists():
            logger.info(f" ⚠️ Location does not exist: {location}")
            continue

        # Try to download each model file directly
        # NOTE(review): a file present in both locations is copied twice, the
        # second copy overwriting the first — confirm the intended precedence.
        for model_file in MODEL_FILES:
            source_path = location / model_file
            dest_path = output_path / model_file

            if source_path.exists():
                try:
                    shutil.copy2(source_path, dest_path)
                    synced_files.append(model_file)
                    logger.info(f"βœ… Downloaded: {model_file}")
                except Exception as e:
                    logger.warning(f"⚠️ Failed to copy {model_file}: {e}")

    if synced_files:
        logger.info(f"πŸŽ‰ Successfully synced {len(synced_files)} model files:")
        for file in synced_files:
            logger.info(f" βœ“ {file}")
        return True
    logger.error("❌ No model files found to sync")
    return False
131
+
132
+
133
def sync_analysis_files(output_dir: str) -> bool:
    """Download analysis reports, charts, and evaluation JSONs from the Beam volume.

    Copies Markdown reports (mirroring README.md/analysis_report.md to the
    output root as README.md), chart files, and evaluation result JSONs.

    Args:
        output_dir: Local destination directory (created if missing).

    Returns:
        True if at least one analysis file was copied, False otherwise.
    """
    logger.info("πŸ”„ Syncing analysis files from Beam volume...")

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    synced_files = []

    # Sync analysis reports (including README.md)
    analysis_reports_dir = Path(VOLUME_PATH) / "analysis_results" / "reports"
    if analysis_reports_dir.exists():
        for report_file in analysis_reports_dir.glob("*.md"):
            dest_path = output_path / report_file.name
            try:
                shutil.copy2(report_file, dest_path)
                synced_files.append(str(report_file.name))
                logger.info(f"βœ… Downloaded report: {report_file.name}")

                # Special handling for README.md - copy to root
                if report_file.name in {"analysis_report.md", "README.md"}:
                    root_readme = Path(output_dir) / "README.md"
                    shutil.copy2(report_file, root_readme)
                    logger.info("βœ… Updated root README.md")

            except Exception as e:
                # Best effort: keep syncing the remaining reports.
                logger.warning(f"⚠️ Failed to copy {report_file.name}: {e}")

    # Sync charts (any file type, e.g. .png/.html)
    charts_dir = Path(VOLUME_PATH) / "analysis_results" / "charts"
    local_charts_dir = output_path / "charts"
    if charts_dir.exists():
        local_charts_dir.mkdir(exist_ok=True)

        for chart_file in charts_dir.glob("*"):
            if chart_file.is_file():
                dest_path = local_charts_dir / chart_file.name
                try:
                    shutil.copy2(chart_file, dest_path)
                    synced_files.append(f"charts/{chart_file.name}")
                    logger.info(f"βœ… Downloaded chart: {chart_file.name}")
                except Exception as e:
                    logger.warning(f"⚠️ Failed to copy chart {chart_file.name}: {e}")

    # Sync evaluation results (JSON only)
    eval_dir = Path(VOLUME_PATH) / "evaluation_results"
    local_eval_dir = output_path / "evaluation_results"
    if eval_dir.exists():
        local_eval_dir.mkdir(exist_ok=True)

        for eval_file in eval_dir.glob("*.json"):
            dest_path = local_eval_dir / eval_file.name
            try:
                shutil.copy2(eval_file, dest_path)
                synced_files.append(f"evaluation_results/{eval_file.name}")
                logger.info(f"βœ… Downloaded evaluation: {eval_file.name}")
            except Exception as e:
                logger.warning(f"⚠️ Failed to copy evaluation {eval_file.name}: {e}")

    if synced_files:
        logger.info(f"πŸŽ‰ Successfully synced {len(synced_files)} analysis files:")
        for file in synced_files[:10]:  # Show first 10
            logger.info(f" βœ“ {file}")
        if len(synced_files) > 10:
            logger.info(f" ... and {len(synced_files) - 10} more files")
        return True
    logger.error("❌ No analysis files found to sync")
    return False
201
+
202
+
203
+ def sync_files(
204
+ model_files: bool = False,
205
+ analysis_files: bool = False,
206
+ all_files: bool = False,
207
+ output_dir: str = ".",
208
+ ) -> None:
209
+ """Main sync function to download files from Beam volume."""
210
+ logger.info("πŸš€ Starting file sync from Beam volume")
211
+ logger.info(f"πŸ“ Local output directory: {output_dir}")
212
+
213
+ # Initialize Beam utilities (read-only)
214
+ try:
215
+ volume_mgr, checkpoint_mgr, model_mgr, eval_mgr = create_beam_utilities(VOLUME_NAME, VOLUME_PATH)
216
+ logger.info(f"βœ… Connected to Beam volume: {VOLUME_NAME}")
217
+ except Exception:
218
+ logger.exception("❌ Failed to connect to Beam volume")
219
+ logger.info("Make sure you have run the distillation/evaluation on Beam first")
220
+ return
221
+
222
+ # Check what files to sync
223
+ sync_model = model_files or all_files
224
+ sync_analysis = analysis_files or all_files
225
+
226
+ if not (sync_model or sync_analysis):
227
+ logger.error("❌ No file types specified. Use --model-files, --analysis-files, or --all")
228
+ return
229
+
230
+ success_count = 0
231
+
232
+ # Sync model files
233
+ if sync_model:
234
+ logger.info("\n" + "=" * 60) # noqa: G003
235
+ logger.info("MODEL FILES SYNC")
236
+ logger.info("=" * 60)
237
+ if sync_model_files(output_dir):
238
+ success_count += 1
239
+
240
+ # Sync analysis files
241
+ if sync_analysis:
242
+ logger.info("\n" + "=" * 60) # noqa: G003
243
+ logger.info("ANALYSIS FILES SYNC")
244
+ logger.info("=" * 60)
245
+ if sync_analysis_files(output_dir):
246
+ success_count += 1
247
+
248
+ # Summary
249
+ logger.info("\n" + "=" * 60) # noqa: G003
250
+ logger.info("SYNC SUMMARY")
251
+ logger.info("=" * 60)
252
+
253
+ total_requested = sum([sync_model, sync_analysis])
254
+
255
+ if success_count == total_requested:
256
+ logger.info("πŸŽ‰ All requested files synced successfully!")
257
+ elif success_count > 0:
258
+ logger.info(f"⚠️ Partial sync: {success_count}/{total_requested} file types synced")
259
+ else:
260
+ logger.error("❌ No files were synced")
261
+
262
+ logger.info(f"πŸ“‚ Files saved to: {Path(output_dir).absolute()}")