Spaces:
Running
Running
File size: 10,609 Bytes
"""Load models once and run four-way Java summarization comparisons."""
from __future__ import annotations
import logging
import pickle
import time
from dataclasses import dataclass
from pathlib import Path
from datasets import concatenate_datasets, load_dataset
from engine.models import CodeT5Model, LexRankModel, SentenceTransformerModel, TFIDFModel
from engine.preprocessing import split_code_statements, split_java_methods
logger = logging.getLogger(__name__)

# Hugging Face dataset used to fit the corpus-level IDF weights, and the
# splits of its "java" config that are concatenated for that fit.
DATASET = "google/code_x_glue_ct_code_to_text"
FIT_SPLITS = ("train", "validation")

# "cache" directory located two levels above this file (i.e. next to the
# directory that contains this module).
CACHE_DIR = Path(__file__).resolve().parent.parent / "cache"
# Pickle file holding the fitted {"idf": ..., "N": ...} payload.
IDF_CACHE = CACHE_DIR.joinpath("idf_weights_train_val.pkl")
@dataclass
class MethodSummary:
    """CodeT5 output for one Java method found in the uploaded source."""

    # Method name as reported by ``split_java_methods``.
    name: str
    # Text CodeT5 generated for that method's code.
    summary: str
@dataclass
class ModelSummary:
    """One model's result within a four-way comparison.

    Descriptive fields are copied from the model's ``MODEL_CATALOG`` entry;
    ``summary``, ``elapsed_ms`` and ``methods`` come from running the model.
    """

    # Catalog "id" of the model (e.g. "tfidf", "codet5").
    model_id: str
    # Catalog display "name".
    model: str
    # Catalog "tier" label.
    tier: str
    # Catalog "approach" label ("Extractive" or "Abstractive").
    approach: str
    # Catalog "accent" color string.
    accent: str
    # Combined summary text produced by the model.
    summary: str
    # Wall-clock time spent producing this summary, in milliseconds.
    elapsed_ms: float
    # Per-method breakdown; left empty by the extractive models.
    methods: list[MethodSummary]
@dataclass
class ComparisonResult:
    """Everything ``SummarizationPipeline.compare`` returns for one upload."""

    # Filename supplied by the caller, echoed back unchanged.
    filename: str
    # Length of the stripped source in characters.
    char_count: int
    # Number of whitespace-separated tokens in the stripped source.
    token_count: int
    # Number of fragments produced by ``split_code_statements``.
    statement_count: int
    # Number of methods produced by ``split_java_methods``.
    method_count: int
    # Statements requested from each extractive model.
    top_n: int
    # One entry per model, in the order the models were run.
    summaries: list[ModelSummary]
    # Overall wall-clock time measured by ``compare``, in milliseconds.
    total_elapsed_ms: float
# Static metadata for the four summarizers, in presentation order.
# ``SummarizationPipeline.compare`` looks entries up by "id" and copies
# "name", "tier", "approach" and "accent" into each ModelSummary; the other
# keys are descriptive copy (presumably rendered by the UI — confirm).
MODEL_CATALOG = [
    dict(
        id="tfidf",
        name="TF-IDF",
        glyph="TF",
        accent="#2dd4bf",
        family="Statistical · Bag-of-words",
        tier="Corpus-fitted extractive",
        approach="Extractive",
        input="Statement fragments",
        checkpoint=None,
        description="Scores each code statement by TF-IDF using IDF weights fitted on the CodeXGLUE Java train + validation corpus.",
        tagline="Picks statements packed with rare, high-signal terms.",
        steps=[
            "Fit inverse-document-frequency (IDF) weights over the Java train + validation corpus.",
            "Tokenize each statement: split identifiers, lowercase, drop stopwords.",
            "Score every statement as the sum of term-frequency x IDF.",
            "Return the top-N highest-scoring statements as the summary.",
        ],
        strengths=["Fast and fully offline", "Interpretable scores", "No GPU required"],
        limitations=["Output is code-like, not prose", "Ignores word order and context"],
        speed="Instant",
    ),
    dict(
        id="lexrank",
        name="LexRank",
        glyph="LR",
        accent="#38bdf8",
        family="Graph · Centrality",
        tier="Corpus-fitted extractive",
        approach="Extractive",
        input="Statement fragments",
        checkpoint=None,
        description="Builds a similarity graph over statements and runs PageRank to pick the most central fragments.",
        tagline="Selects statements most representative of the whole file.",
        steps=[
            "Build a TF-IDF vector for each statement using shared corpus IDF.",
            "Compute pairwise cosine similarity to form a statement graph.",
            "Threshold weak edges, then run PageRank over the graph.",
            "Return the most central (highest-ranked) statements.",
        ],
        strengths=["Captures redundancy / centrality", "Offline and interpretable", "Robust on longer files"],
        limitations=["Needs several statements to rank", "Still extractive, not generative"],
        speed="Fast",
    ),
    dict(
        id="sentence_transformers",
        name="SentenceTransformers",
        glyph="ST",
        accent="#a78bfa",
        family="Neural · Sentence embeddings",
        tier="General-language pretrained",
        approach="Extractive",
        input="Statement fragments",
        checkpoint="sentence-transformers/all-MiniLM-L6-v2",
        description="Encodes statements with all-MiniLM-L6-v2 and selects those closest to the centroid embedding.",
        tagline="Uses semantic meaning to find the most central statements.",
        steps=[
            "Embed each statement with the all-MiniLM-L6-v2 transformer.",
            "Average the embeddings into a single centroid vector.",
            "Rank statements by cosine similarity to the centroid.",
            "Return the statements closest to the semantic center.",
        ],
        strengths=["Understands English semantics", "Order-aware encoder", "No corpus fitting needed"],
        limitations=["Pretrained on prose, not code", "Heavier than TF-IDF/LexRank"],
        speed="Moderate",
    ),
    dict(
        id="codet5",
        name="CodeT5",
        glyph="T5",
        accent="#f59e0b",
        family="Transformer · Seq2seq",
        tier="Code-specific fine-tuned",
        approach="Abstractive",
        input="Per-method Java source (256-token window each)",
        checkpoint="Salesforce/codet5-base-codexglue-sum-java",
        description="Generates natural-language summaries from raw Java source using a CodeT5 checkpoint fine-tuned on CodeXGLUE.",
        tagline="Writes a fresh English sentence describing the code.",
        steps=[
            "Split the file into individual Java methods.",
            "Byte-level BPE tokenize each method (first 256 tokens).",
            "Decode with beam search — one English sentence per method, same as evaluation.",
            "Show each method summary separately in the results view.",
        ],
        strengths=["True natural-language output", "Fine-tuned on Java code-comment pairs", "Best quality summaries"],
        limitations=["Slow on CPU", "256-token input limit", "Can hallucinate details"],
        speed="Slowest",
    ),
]
class SummarizationPipeline:
    """Hold the four summarizers and run them side by side on one Java file.

    Models are loaded by :meth:`load`; :meth:`compare` refuses to run until
    a load has completed successfully.
    """

    def __init__(self, top_n: int = 5) -> None:
        """Create an unloaded pipeline.

        Args:
            top_n: Number of statements each extractive model is asked for.
        """
        self.top_n = top_n
        self.ready = False  # True once load() has finished successfully
        self.loading = False  # True while load() is in progress
        self.error: str | None = None  # message from the last failed load()
        self.tfidf: TFIDFModel | None = None
        self.lexrank: LexRankModel | None = None
        self.st_model: SentenceTransformerModel | None = None
        self.codet5: CodeT5Model | None = None

    def load(self) -> None:
        """Load every model; return immediately if already loaded or loading.

        On failure the message is stored in ``self.error``, the exception is
        logged and re-raised, and ``ready`` stays False, so a later call will
        try again.

        NOTE(review): ``ready``/``loading`` are plain attributes, not a lock;
        concurrent callers are not serialized.
        """
        if self.ready or self.loading:
            return
        self.loading = True
        self.error = None
        try:
            idf, n = self._load_or_build_idf()
            logger.info("Fitting TF-IDF and LexRank from cached IDF (%d terms)", len(idf))
            self.tfidf = TFIDFModel().load_idf(idf, n)
            self.lexrank = LexRankModel().load_idf(idf)
            logger.info("Loading SentenceTransformers ...")
            self.st_model = SentenceTransformerModel()
            logger.info("Loading CodeT5 ...")
            self.codet5 = CodeT5Model()
            self.ready = True
            logger.info("Pipeline ready.")
        except Exception as exc:
            self.error = str(exc)
            logger.exception("Failed to load summarization pipeline")
            raise
        finally:
            self.loading = False

    def _load_or_build_idf(self) -> tuple[dict[str, float], int]:
        """Return ``(idf_weights, N)``, preferring the on-disk cache.

        If ``IDF_CACHE`` is missing or unreadable, IDF is fitted on the
        statements of every row in the ``FIT_SPLITS`` of ``DATASET``'s "java"
        config, and the result is written back to ``IDF_CACHE``.
        """
        CACHE_DIR.mkdir(parents=True, exist_ok=True)
        if IDF_CACHE.exists():
            logger.info("Loading IDF cache from %s", IDF_CACHE)
            # SECURITY: pickle can execute code while loading. This file is
            # only ever written by this method, so the cache directory must
            # not be writable by untrusted users.
            try:
                with IDF_CACHE.open("rb") as f:
                    payload = pickle.load(f)
                return payload["idf"], payload["N"]
            except (pickle.UnpicklingError, EOFError, ValueError, KeyError, TypeError) as exc:
                # Without this, a truncated or foreign cache file would make
                # every load() fail; rebuild it instead.
                logger.warning("Ignoring unreadable IDF cache %s (%s); rebuilding", IDF_CACHE, exc)

        logger.info("Building IDF from %s (%s) ...", DATASET, " + ".join(FIT_SPLITS))
        dataset = load_dataset(DATASET, "java")
        fit_data = concatenate_datasets([dataset[split] for split in FIT_SPLITS])
        fit_corpus: list[str] = []
        for row in fit_data:
            fit_corpus.extend(split_code_statements(row["code"]))
        tfidf = TFIDFModel().fit(fit_corpus)

        # Write to a sibling temp file, then rename it over the real path, so
        # an interrupted write never leaves a truncated cache behind.
        tmp_path = IDF_CACHE.with_name(IDF_CACHE.name + ".tmp")
        with tmp_path.open("wb") as f:
            pickle.dump({"idf": tfidf.idf, "N": tfidf.N}, f)
        tmp_path.replace(IDF_CACHE)

        logger.info(
            "Cached IDF weights (%d terms, %d statements from %d methods)",
            len(tfidf.idf),
            len(fit_corpus),
            len(fit_data),
        )
        return tfidf.idf, tfidf.N

    @staticmethod
    def _model_summary(
        meta: dict,
        text: str,
        elapsed_ms: float,
        methods: list[MethodSummary] | None = None,
    ) -> ModelSummary:
        """Build a ModelSummary from a MODEL_CATALOG entry and one model's output."""
        return ModelSummary(
            model_id=meta["id"],
            model=meta["name"],
            tier=meta["tier"],
            approach=meta["approach"],
            accent=meta["accent"],
            summary=text,
            elapsed_ms=elapsed_ms,
            methods=[] if methods is None else methods,
        )

    def compare(self, java_source: str, filename: str = "upload.java") -> ComparisonResult:
        """Run all four models on *java_source* and collect their summaries.

        Args:
            java_source: Raw Java file contents; surrounding whitespace is ignored.
            filename: Name echoed back in the result.

        Returns:
            A ComparisonResult whose ``summaries`` are, in order, TF-IDF,
            LexRank, SentenceTransformers and CodeT5, plus size statistics
            for the stripped source.

        Raises:
            RuntimeError: If load() has not completed successfully.
            ValueError: If the source is empty or whitespace-only.
        """
        if not self.ready:
            raise RuntimeError("Pipeline is not ready. Call load() first.")
        source = java_source.strip()
        if not source:
            raise ValueError("Java source is empty.")

        # Start the overall clock before preprocessing, so total_elapsed_ms
        # also covers the statement/method splitting done for this request.
        t_total = time.perf_counter()
        statements = split_code_statements(source)
        java_methods = split_java_methods(source)
        catalog_by_id = {m["id"]: m for m in MODEL_CATALOG}
        summaries: list[ModelSummary] = []

        # The extractive models share the summarize(statements, top_n)
        # interface; their selected statements are joined with spaces.
        extractive_models = (
            ("tfidf", self.tfidf),
            ("lexrank", self.lexrank),
            ("sentence_transformers", self.st_model),
        )
        for model_id, model in extractive_models:
            meta = catalog_by_id[model_id]
            t0 = time.perf_counter()
            text = " ".join(model.summarize(statements, self.top_n))
            summaries.append(self._model_summary(meta, text, (time.perf_counter() - t0) * 1000))

        # CodeT5 is run once per extracted method; non-blank results are
        # joined one per line into the combined summary.
        # NOTE(review): when split_java_methods finds no methods, CodeT5
        # contributes an empty summary — confirm that is the intended UX.
        codet5_meta = catalog_by_id["codet5"]
        t0 = time.perf_counter()
        codet5_methods = [
            MethodSummary(name=method["name"], summary=self.codet5.summarize(method["code"]))
            for method in java_methods
        ]
        codet5_combined = "\n".join(
            m.summary.strip() for m in codet5_methods if m.summary.strip()
        )
        summaries.append(
            self._model_summary(
                codet5_meta,
                codet5_combined,
                (time.perf_counter() - t0) * 1000,
                codet5_methods,
            )
        )

        return ComparisonResult(
            filename=filename,
            char_count=len(source),
            token_count=len(source.split()),
            statement_count=len(statements),
            method_count=len(java_methods),
            top_n=self.top_n,
            summaries=summaries,
            total_elapsed_ms=(time.perf_counter() - t_total) * 1000,
        )
|