"""Load models once and run four-way Java summarization comparisons.""" from __future__ import annotations import logging import pickle import time from dataclasses import dataclass from pathlib import Path from datasets import concatenate_datasets, load_dataset from engine.models import CodeT5Model, LexRankModel, SentenceTransformerModel, TFIDFModel from engine.preprocessing import split_code_statements, split_java_methods logger = logging.getLogger(__name__) CACHE_DIR = Path(__file__).resolve().parent.parent / "cache" IDF_CACHE = CACHE_DIR / "idf_weights_train_val.pkl" DATASET = "google/code_x_glue_ct_code_to_text" FIT_SPLITS = ("train", "validation") @dataclass class MethodSummary: name: str summary: str @dataclass class ModelSummary: model_id: str model: str tier: str approach: str accent: str summary: str elapsed_ms: float methods: list[MethodSummary] @dataclass class ComparisonResult: filename: str char_count: int token_count: int statement_count: int method_count: int top_n: int summaries: list[ModelSummary] total_elapsed_ms: float MODEL_CATALOG = [ { "id": "tfidf", "name": "TF-IDF", "glyph": "TF", "accent": "#2dd4bf", "family": "Statistical · Bag-of-words", "tier": "Corpus-fitted extractive", "approach": "Extractive", "input": "Statement fragments", "checkpoint": None, "description": "Scores each code statement by TF-IDF using IDF weights fitted on the CodeXGLUE Java train + validation corpus.", "tagline": "Picks statements packed with rare, high-signal terms.", "steps": [ "Fit inverse-document-frequency (IDF) weights over the Java train + validation corpus.", "Tokenize each statement: split identifiers, lowercase, drop stopwords.", "Score every statement as the sum of term-frequency x IDF.", "Return the top-N highest-scoring statements as the summary.", ], "strengths": ["Fast and fully offline", "Interpretable scores", "No GPU required"], "limitations": ["Output is code-like, not prose", "Ignores word order and context"], "speed": "Instant", }, { "id": "lexrank", "name": "LexRank", "glyph": "LR", "accent": "#38bdf8", "family": "Graph · Centrality", "tier": "Corpus-fitted extractive", "approach": "Extractive", "input": "Statement fragments", "checkpoint": None, "description": "Builds a similarity graph over statements and runs PageRank to pick the most central fragments.", "tagline": "Selects statements most representative of the whole file.", "steps": [ "Build a TF-IDF vector for each statement using shared corpus IDF.", "Compute pairwise cosine similarity to form a statement graph.", "Threshold weak edges, then run PageRank over the graph.", "Return the most central (highest-ranked) statements.", ], "strengths": ["Captures redundancy / centrality", "Offline and interpretable", "Robust on longer files"], "limitations": ["Needs several statements to rank", "Still extractive, not generative"], "speed": "Fast", }, { "id": "sentence_transformers", "name": "SentenceTransformers", "glyph": "ST", "accent": "#a78bfa", "family": "Neural · Sentence embeddings", "tier": "General-language pretrained", "approach": "Extractive", "input": "Statement fragments", "checkpoint": "sentence-transformers/all-MiniLM-L6-v2", "description": "Encodes statements with all-MiniLM-L6-v2 and selects those closest to the centroid embedding.", "tagline": "Uses semantic meaning to find the most central statements.", "steps": [ "Embed each statement with the all-MiniLM-L6-v2 transformer.", "Average the embeddings into a single centroid vector.", "Rank statements by cosine similarity to the centroid.", "Return the statements closest to the semantic center.", ], "strengths": ["Understands English semantics", "Order-aware encoder", "No corpus fitting needed"], "limitations": ["Pretrained on prose, not code", "Heavier than TF-IDF/LexRank"], "speed": "Moderate", }, { "id": "codet5", "name": "CodeT5", "glyph": "T5", "accent": "#f59e0b", "family": "Transformer · Seq2seq", "tier": "Code-specific fine-tuned", "approach": "Abstractive", "input": "Per-method Java source (256-token window each)", "checkpoint": "Salesforce/codet5-base-codexglue-sum-java", "description": "Generates natural-language summaries from raw Java source using a CodeT5 checkpoint fine-tuned on CodeXGLUE.", "tagline": "Writes a fresh English sentence describing the code.", "steps": [ "Split the file into individual Java methods.", "Byte-level BPE tokenize each method (first 256 tokens).", "Decode with beam search — one English sentence per method, same as evaluation.", "Show each method summary separately in the results view.", ], "strengths": ["True natural-language output", "Fine-tuned on Java code-comment pairs", "Best quality summaries"], "limitations": ["Slow on CPU", "256-token input limit", "Can hallucinate details"], "speed": "Slowest", }, ] class SummarizationPipeline: def __init__(self, top_n: int = 5) -> None: self.top_n = top_n self.ready = False self.loading = False self.error: str | None = None self.tfidf: TFIDFModel | None = None self.lexrank: LexRankModel | None = None self.st_model: SentenceTransformerModel | None = None self.codet5: CodeT5Model | None = None def load(self) -> None: if self.ready or self.loading: return self.loading = True self.error = None try: idf, n = self._load_or_build_idf() logger.info("Fitting TF-IDF and LexRank from cached IDF (%d terms)", len(idf)) self.tfidf = TFIDFModel().load_idf(idf, n) self.lexrank = LexRankModel().load_idf(idf) logger.info("Loading SentenceTransformers ...") self.st_model = SentenceTransformerModel() logger.info("Loading CodeT5 ...") self.codet5 = CodeT5Model() self.ready = True logger.info("Pipeline ready.") except Exception as exc: self.error = str(exc) logger.exception("Failed to load summarization pipeline") raise finally: self.loading = False def _load_or_build_idf(self) -> tuple[dict[str, float], int]: CACHE_DIR.mkdir(parents=True, exist_ok=True) if IDF_CACHE.exists(): logger.info("Loading IDF cache from %s", IDF_CACHE) with IDF_CACHE.open("rb") as f: payload = pickle.load(f) return payload["idf"], payload["N"] logger.info("Building IDF from %s (%s) ...", DATASET, " + ".join(FIT_SPLITS)) dataset = load_dataset(DATASET, "java") fit_data = concatenate_datasets([dataset[split] for split in FIT_SPLITS]) fit_corpus: list[str] = [] for row in fit_data: fit_corpus.extend(split_code_statements(row["code"])) tfidf = TFIDFModel().fit(fit_corpus) with IDF_CACHE.open("wb") as f: pickle.dump({"idf": tfidf.idf, "N": tfidf.N}, f) logger.info( "Cached IDF weights (%d terms, %d statements from %d methods)", len(tfidf.idf), len(fit_corpus), len(fit_data), ) return tfidf.idf, tfidf.N def compare(self, java_source: str, filename: str = "upload.java") -> ComparisonResult: if not self.ready: raise RuntimeError("Pipeline is not ready. Call load() first.") source = java_source.strip() if not source: raise ValueError("Java source is empty.") statements = split_code_statements(source) java_methods = split_java_methods(source) summaries: list[ModelSummary] = [] t_total = time.perf_counter() catalog_by_id = {m["id"]: m for m in MODEL_CATALOG} extractive_runners = [ ("tfidf", lambda: " ".join(self.tfidf.summarize(statements, self.top_n))), ("lexrank", lambda: " ".join(self.lexrank.summarize(statements, self.top_n))), ( "sentence_transformers", lambda: " ".join(self.st_model.summarize(statements, self.top_n)), ), ] for model_id, run in extractive_runners: meta = catalog_by_id[model_id] t0 = time.perf_counter() text = run() summaries.append(ModelSummary( model_id=model_id, model=meta["name"], tier=meta["tier"], approach=meta["approach"], accent=meta["accent"], summary=text, elapsed_ms=(time.perf_counter() - t0) * 1000, methods=[], )) codet5_meta = catalog_by_id["codet5"] t0 = time.perf_counter() codet5_methods: list[MethodSummary] = [] for method in java_methods: codet5_methods.append(MethodSummary( name=method["name"], summary=self.codet5.summarize(method["code"]), )) codet5_combined = "\n".join( m.summary.strip() for m in codet5_methods if m.summary.strip() ) summaries.append(ModelSummary( model_id="codet5", model=codet5_meta["name"], tier=codet5_meta["tier"], approach=codet5_meta["approach"], accent=codet5_meta["accent"], summary=codet5_combined, elapsed_ms=(time.perf_counter() - t0) * 1000, methods=codet5_methods, )) return ComparisonResult( filename=filename, char_count=len(source), token_count=len(source.split()), statement_count=len(statements), method_count=len(java_methods), top_n=self.top_n, summaries=summaries, total_elapsed_ms=(time.perf_counter() - t_total) * 1000, )