Sky-Blue-da-ba-dee commited on
Commit
ac9ddbb
·
1 Parent(s): 629e980

added files

Browse files
Dockerfile ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11
2
+
3
+ # User
4
+ RUN useradd -m -u 1000 user
5
+ USER user
6
+ ENV PATH="/home/user/.local/bin:$PATH"
7
+
8
+ WORKDIR /app
9
+
10
+ COPY --chown=user requirements.txt requirements.txt
11
+
12
+ RUN grep -v "codecommentclassification" requirements.txt > requirements-docker.txt \
13
+ && pip install --no-cache-dir --upgrade -r requirements-docker.txt
14
+
15
+ COPY --chown=user api /app/api
16
+ COPY --chown=user codecommentclassification /app/codecommentclassification
17
+
18
+ COPY --chown=user models/model_cards /app/models/model_cards
19
+
20
+ RUN mkdir -p /app/models/api
21
+
22
+ ENV PYTHONPATH=/app
23
+ ENV MODELS_DIR=/app/models/api
24
+
25
+ EXPOSE 7860
26
+
27
+ CMD ["uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "7860"]
api/__pycache__/main.cpython-311.pyc ADDED
Binary file (9.67 kB). View file
 
api/__pycache__/schemas.cpython-311.pyc ADDED
Binary file (2.73 kB). View file
 
api/__pycache__/sync_models.cpython-311.pyc ADDED
Binary file (7.42 kB). View file
 
api/main.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Main API for Code Comment Classification using FastAPI."""
2
+ from contextlib import asynccontextmanager
3
+ from datetime import datetime
4
+ from functools import lru_cache, wraps
5
+ from http import HTTPStatus
6
+ import inspect
7
+ import logging
8
+ import os
9
+ from pathlib import Path
10
+
11
+ from api.schemas import PredictRequest
12
+ from api.sync_models import sync_best_models_to_disk
13
+ from fastapi import FastAPI, Request, Response
14
+ from fastapi.middleware.cors import CORSMiddleware
15
+ from fastapi.responses import JSONResponse
16
+
17
+ from codecommentclassification import ModelPredictor
18
+
19
+ MODELS_DIR = Path(os.getenv("MODELS_DIR", "models/api"))
20
+
21
+
22
+ logging.basicConfig(
23
+ level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
24
+ )
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ @lru_cache(maxsize=3)
29
+ def get_predictor(lang: str, model_type: str) -> ModelPredictor:
30
+ """Lazily loads the heavy model only when requested."""
31
+ logger.info(f"Loading model for {lang} - {model_type}...")
32
+ return ModelPredictor(lang=lang, model_type=model_type, model_root=str(MODELS_DIR))
33
+
34
+
35
+ @asynccontextmanager
36
+ async def lifespan(app: FastAPI):
37
+ """Lifespan context manager to sync models at startup."""
38
+ try:
39
+ logger.info(f"Syncing champion models from MLflow to {MODELS_DIR}...")
40
+ sync_best_models_to_disk(
41
+ models_root=MODELS_DIR.parent,
42
+ api_subdir=MODELS_DIR.name,
43
+ )
44
+ except Exception as e:
45
+ logger.error(f"Failed to sync models at startup: {e}")
46
+
47
+ if not MODELS_DIR.exists():
48
+ logger.warning(f"Models directory not found at: {MODELS_DIR.resolve()}")
49
+ else:
50
+ logger.info(f"Using models from: {MODELS_DIR.resolve()}")
51
+ yield
52
+ get_predictor.cache_clear()
53
+
54
+
55
+ app = FastAPI(
56
+ title="Code Comment Classification API",
57
+ description="API for classifying code comments using SetFit models.",
58
+ version="0.1",
59
+ lifespan=lifespan,
60
+ )
61
+
62
+ frontend_origins = os.getenv("FRONTEND_ORIGINS")
63
+
64
+ if frontend_origins:
65
+ origins = [o.strip() for o in frontend_origins.split(",") if o.strip()]
66
+ else:
67
+ # default di sviluppo
68
+ origins = [
69
+ "http://localhost:5173",
70
+ "http://127.0.0.1:5173",
71
+ "http://localhost",
72
+ ]
73
+
74
+ app.add_middleware(
75
+ CORSMiddleware,
76
+ allow_origins=origins,
77
+ allow_credentials=True,
78
+ allow_methods=["*"],
79
+ allow_headers=["*"],
80
+ )
81
+
82
+
83
+ def _build_response(results: dict, request: Request):
84
+ if isinstance(results, (Response, JSONResponse)):
85
+ return results
86
+
87
+ response = {
88
+ "message": results["message"],
89
+ "method": request.method,
90
+ "status-code": results["status-code"],
91
+ "timestamp": datetime.now().isoformat(),
92
+ "url": request.url._url,
93
+ }
94
+
95
+ if "data" in results:
96
+ response["data"] = results["data"]
97
+
98
+ return response
99
+
100
+
101
+ def construct_response(f):
102
+ """Construct a JSON response for an endpoint's results (sync and async)."""
103
+ if inspect.iscoroutinefunction(f):
104
+
105
+ @wraps(f)
106
+ async def wrap(request: Request, *args, **kwargs):
107
+ results = await f(request, *args, **kwargs)
108
+ return _build_response(results, request)
109
+ else:
110
+
111
+ @wraps(f)
112
+ def wrap(request: Request, *args, **kwargs):
113
+ results = f(request, *args, **kwargs)
114
+ return _build_response(results, request)
115
+
116
+ return wrap
117
+
118
+
119
+ @app.get("/", tags=["General"])
120
+ @construct_response
121
+ def _index(request: Request):
122
+ """Root endpoint."""
123
+ return {
124
+ "message": HTTPStatus.OK.phrase,
125
+ "status-code": HTTPStatus.OK,
126
+ "data": {
127
+ "message": "Welcome to the Code Comment Classification API! Please use /docs for API documentation."
128
+ },
129
+ }
130
+
131
+
132
+ @app.get("/privacy", tags=["General"])
133
+ @construct_response
134
+ async def get_privacy_notice(request: Request):
135
+ """Return the Privacy Notice for the API."""
136
+ return {
137
+ "message": "Privacy Notice",
138
+ "status-code": HTTPStatus.OK,
139
+ "data": {
140
+ "policy": "This API processes text data for classification purposes only. No data is permanently stored.",
141
+ "compliance_link": "https://behavizapi.peopleware.ai/api/docs#section/Getting-Started/Privacy-Notice",
142
+ },
143
+ }
144
+
145
+
146
+ @app.get("/status")
147
+ def get_status():
148
+ """Endpoint to check if the API is running."""
149
+ return {"status": "API is running"}
150
+
151
+
152
+ @app.get("/models", tags=["Prediction"])
153
+ @construct_response
154
+ def _get_models_list(request: Request):
155
+ """Return the list of available languages based on directories found in models/ ."""
156
+ # Since we aren't pre-loading, we scan the directory to see what IS available
157
+ if MODELS_DIR.exists():
158
+ available_languages = [
159
+ {"language": d.name, "model_types": mt.name}
160
+ for d in MODELS_DIR.iterdir()
161
+ if d.is_dir()
162
+ for mt in d.iterdir()
163
+ if mt.is_dir()
164
+ ]
165
+ else:
166
+ available_languages = []
167
+
168
+ return {
169
+ "message": HTTPStatus.OK.phrase,
170
+ "status-code": HTTPStatus.OK,
171
+ "data": available_languages,
172
+ }
173
+
174
+
175
+ @app.post("/predict", tags=["Prediction"])
176
+ @construct_response
177
+ def predict(
178
+ request: Request,
179
+ payload: PredictRequest,
180
+ ):
181
+ """Inference endpoint."""
182
+ if payload.model_type is None:
183
+ return {
184
+ "message": "Model type must be specified.",
185
+ "status-code": HTTPStatus.BAD_REQUEST,
186
+ }
187
+
188
+ try:
189
+ predictor = get_predictor(payload.language.value, payload.model_type.value)
190
+ result = predictor.predict(payload.text)
191
+ predictions_list = result.tolist() if hasattr(result, "tolist") else result
192
+
193
+ return {
194
+ "message": HTTPStatus.OK.phrase,
195
+ "status-code": HTTPStatus.OK,
196
+ "data": {
197
+ "language": payload.language,
198
+ "model_type": payload.model_type,
199
+ "predictions": predictions_list,
200
+ },
201
+ }
202
+
203
+ except FileNotFoundError:
204
+ return {
205
+ "message": f"Model for language '{payload.language}' not found.",
206
+ "status-code": HTTPStatus.NOT_FOUND,
207
+ }
208
+ except ValueError as e:
209
+ return {
210
+ "message": str(e),
211
+ "status-code": HTTPStatus.BAD_REQUEST,
212
+ }
213
+ except Exception as e:
214
+ return {
215
+ "message": f"Internal Error: {str(e)}",
216
+ "status-code": HTTPStatus.INTERNAL_SERVER_ERROR,
217
+ }
api/schemas.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """API Schemas for Predict Request and Response."""
2
+ from enum import Enum
3
+
4
+ from pydantic import BaseModel, ConfigDict, ValidationError
5
+
6
+
7
+ class ProgrammingLanguage(str, Enum):
8
+ """Programming languages supported for prediction."""
9
+
10
+ JAVA = "java"
11
+ PYTHON = "python"
12
+ PHARO = "pharo"
13
+
14
+
15
+ class ModelType(str, Enum):
16
+ """Model types for prediction."""
17
+
18
+ SETFIT = "setfit"
19
+ RANDOM_FOREST = "random_forest"
20
+ TRANSFORMER = "transformer"
21
+
22
+
23
+ class PredictRequest(BaseModel):
24
+ """Schema for Predict Request."""
25
+
26
+ text: str
27
+ language: ProgrammingLanguage
28
+ model_type: ModelType
29
+
30
+ model_config = ConfigDict(
31
+ json_schema_extra={
32
+ "example": {
33
+ "text": "This method calculates the average score.",
34
+ "language": "python",
35
+ "model_type": "transformer",
36
+ }
37
+ }
38
+ )
39
+
40
+
41
+ class PredictResponse(BaseModel):
42
+ """Schema for Predict Response."""
43
+
44
+ label: str
45
+ score: float
46
+
47
+
48
+ """ Demonstration of object instantiation, printing,
49
+ and validation error handling with dummy use cases"""
50
+ if __name__ == "__main__":
51
+ print("\n--- 1. Object Instantiation & Printing ---")
52
+
53
+ valid_data = {
54
+ "text": "This method calculates the average score.",
55
+ "language": "java",
56
+ "model_type": "setfit",
57
+ }
58
+
59
+ # Instantiate the object
60
+ request = PredictRequest(**valid_data)
61
+
62
+ # Print object as dictionary (.model_dump() is Pydantic V2 syntax)
63
+ print(f"Valid Request Object: {request.model_dump()}")
64
+
65
+ print("\n--- 2. Handling Invalid Data ---")
66
+
67
+ try:
68
+ print("Attempting to create request with language='c++'...")
69
+ # This should fail because 'c++' is not in ProgrammingLanguage Enum
70
+ invalid_request = PredictRequest(
71
+ text="std::cout << 'Hello';", language="c++", model_type="setfit"
72
+ )
73
+ except ValidationError as e:
74
+ print("SUCCESS: Validation Error Caught!")
75
+ print(e.json())
api/sync_models.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Synchronise champion MLflow models from the remote registry to the local filesystem."""
2
+
3
+ import logging
4
+ import os
5
+ from pathlib import Path
6
+ import shutil
7
+
8
+ import mlflow
9
+ from mlflow.tracking import MlflowClient
10
+
11
+ logger = logging.getLogger(__name__)
12
+ LANGUAGES = ("python", "java", "pharo")
13
+
14
+
15
+ def _get_mlflow_client() -> MlflowClient:
16
+ """Return an MLflow client configured from environment variables.
17
+
18
+ If ``MLFLOW_TRACKING_URI`` is defined, it is passed to
19
+ :func:`mlflow.set_tracking_uri`. Authentication (for example on DagsHub)
20
+ is handled by MLflow itself via the standard environment variables
21
+ ``MLFLOW_TRACKING_USERNAME`` and ``MLFLOW_TRACKING_PASSWORD``.
22
+ """
23
+ tracking_uri = os.getenv("MLFLOW_TRACKING_URI")
24
+ if tracking_uri:
25
+ mlflow.set_tracking_uri(tracking_uri)
26
+ return MlflowClient()
27
+
28
+
29
+ def _find_champion_version_for_language(
30
+ client: MlflowClient,
31
+ lang: str,
32
+ ):
33
+ """Return the champion model version for the given language, if any.
34
+
35
+ The function searches all registered models and looks for models whose name
36
+ starts with ``"<lang>-"`` (for example ``"python-transformer"``). For each
37
+ matching model it tries to resolve the alias ``"<lang>-champion"`` using
38
+ :meth:`MlflowClient.get_model_version_by_alias`.
39
+
40
+ Args:
41
+ client: Initialised MLflow client.
42
+ lang: Language identifier, such as ``"python"``, ``"java"`` or
43
+ ``"pharo"``.
44
+
45
+ Returns:
46
+ The matching :class:`mlflow.entities.model_registry.ModelVersion` if a
47
+ champion is found, otherwise ``None``.
48
+
49
+ """
50
+ alias_name = f"{lang}-champion"
51
+ prefix = f"{lang}-"
52
+
53
+ # Get all registered models and filter by language prefix.
54
+ for rm in client.search_registered_models():
55
+ model_name = rm.name
56
+ if not model_name.startswith(prefix):
57
+ continue
58
+
59
+ try:
60
+ mv = client.get_model_version_by_alias(
61
+ name=model_name,
62
+ alias=alias_name,
63
+ )
64
+ logger.info(
65
+ "Found champion model for %s: %s (version %s)",
66
+ lang,
67
+ model_name,
68
+ mv.version,
69
+ )
70
+ return mv
71
+ except Exception: # noqa: BLE001
72
+ logger.info("Alias not defined for model %s, trying next one.", model_name)
73
+ continue
74
+
75
+ logger.warning("No champion model found for %s.", lang)
76
+ return None
77
+
78
+
79
+ def sync_best_models_to_disk(
80
+ models_root: str | Path = "models",
81
+ api_subdir: str = "api",
82
+ ) -> None:
83
+ """Download champion models from MLflow and write them to disk.
84
+
85
+ For each language in :data:`LANGUAGES`, this function looks up the model
86
+ version with alias ``"<lang>-champion"`` and downloads its artifacts. After
87
+ download, the directory structure is normalised so that the final layout is:
88
+
89
+ .. code-block:: text
90
+
91
+ models/
92
+ <api_subdir>/
93
+ python/
94
+ <model_type>/
95
+ ...
96
+ java/
97
+ <model_type>/
98
+ ...
99
+ pharo/
100
+ <model_type>/
101
+ ...
102
+
103
+ For transformer models logged via ``mlflow.transformers``, the inner
104
+ ``model/`` directory is flattened so that the Hugging Face files
105
+ (``config.json``, ``model.safetensors``, ``tokenizer.json``, and so on)
106
+ live directly under ``<model_type>/``.
107
+
108
+ Args:
109
+ models_root: Base directory under which models are written. Can be a
110
+ string or :class:`pathlib.Path`. Defaults to ``"models"``.
111
+ api_subdir: Optional subdirectory appended under ``models_root`` (for
112
+ example ``"api"``). If empty, models are stored directly under
113
+ ``models_root``.
114
+
115
+ Raises:
116
+ OSError: If creating directories, moving files, or removing directories
117
+ fails at the OS level.
118
+
119
+ """
120
+ client = _get_mlflow_client()
121
+
122
+ root = Path(models_root)
123
+ if api_subdir:
124
+ root = root / api_subdir
125
+ root.mkdir(parents=True, exist_ok=True)
126
+ logger.info("Syncing best models to: %s", root.resolve())
127
+
128
+ for lang in LANGUAGES:
129
+ mv = _find_champion_version_for_language(client, lang)
130
+ if mv is None:
131
+ continue
132
+
133
+ model_name = mv.name
134
+ try:
135
+ lang_from_name, model_type = model_name.split("-", 1)
136
+ except ValueError:
137
+ logger.error("Unexpected model name format: %s", model_name)
138
+ continue
139
+
140
+ if lang_from_name != lang:
141
+ logger.warning(
142
+ "Language mismatch for model %s: expected %s, got %s",
143
+ model_name,
144
+ lang,
145
+ lang_from_name,
146
+ )
147
+
148
+ dest_dir = root / lang / model_type
149
+ if dest_dir.exists():
150
+ shutil.rmtree(dest_dir)
151
+ dest_dir.mkdir(parents=True, exist_ok=True)
152
+
153
+ logger.info(
154
+ "Downloading model '%s' version %s to %s...",
155
+ model_name,
156
+ mv.version,
157
+ dest_dir.resolve(),
158
+ )
159
+
160
+ try:
161
+ # Download the artifact (for example ".../java_transformer_model").
162
+ downloaded_path = Path(
163
+ mlflow.artifacts.download_artifacts(
164
+ artifact_uri=mv.source,
165
+ dst_path=str(dest_dir),
166
+ ),
167
+ )
168
+
169
+ # For transformer models logged with mlflow.transformers, artifacts
170
+ # are stored under an inner "model/" directory.
171
+ model_subdir = downloaded_path / "model"
172
+ if model_subdir.is_dir():
173
+ # Move the contents of "model" directly into dest_dir.
174
+ for item in model_subdir.iterdir():
175
+ shutil.move(str(item), dest_dir / item.name)
176
+
177
+ # Remove the wrapper directory (with MLmodel, conda.yaml, etc.).
178
+ if downloaded_path != dest_dir:
179
+ shutil.rmtree(downloaded_path)
180
+
181
+ except Exception as e:
182
+ logger.error(
183
+ "Failed to download/reshape model '%s' version %s: %s",
184
+ model_name,
185
+ mv.version,
186
+ e,
187
+ )
188
+
189
+
190
+ if __name__ == "__main__":
191
+ logging.basicConfig(level=logging.INFO)
192
+ sync_best_models_to_disk()
codecommentclassification/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """CodeCommentClassification package initialization."""
2
+
3
+ from .predictor import ModelPredictor
4
+
5
+ __all__ = ["ModelPredictor"]
codecommentclassification/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (260 Bytes). View file
 
codecommentclassification/__pycache__/predictor.cpython-311.pyc ADDED
Binary file (7.51 kB). View file
 
codecommentclassification/modeling/__pycache__/evaluate_models.cpython-311.pyc ADDED
Binary file (12 kB). View file
 
codecommentclassification/modeling/__pycache__/train.cpython-311.pyc ADDED
Binary file (9.88 kB). View file
 
codecommentclassification/modeling/__pycache__/utils.cpython-311.pyc ADDED
Binary file (4.22 kB). View file
 
codecommentclassification/modeling/evaluate_models.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module for evaluating models on test set."""
2
+
3
+ import argparse
4
+ import json
5
+ import os
6
+ import time
7
+
8
+ import dagshub
9
+ import joblib
10
+ import mlflow
11
+ import numpy as np
12
+ import pandas as pd
13
+ from setfit import SetFitModel
14
+ import torch
15
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
16
+
17
+ from .utils import load_dataset_splits, parse_labels_column
18
+
19
+ LABELS = {
20
+ "java": ["summary", "Ownership", "Expand", "usage", "Pointer", "deprecation", "rational"],
21
+ "python": ["Usage", "Parameters", "DevelopmentNotes", "Expand", "Summary"],
22
+ "pharo": [
23
+ "Keyimplementationpoints",
24
+ "Example",
25
+ "Responsibilities",
26
+ "Intent",
27
+ "Keymessages",
28
+ "Collaborators",
29
+ ],
30
+ }
31
+
32
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
+
34
+ dagshub.init(repo_owner="se4ai2526-uniba", repo_name="TheClouds", mlflow=True)
35
+
36
+
37
+ def evaluate_and_benchmark(lang, model_type, model_path, data_path, metrics_output_path):
38
+ """Load a trained model, run detailed benchmarking for performance and metrics,
39
+ and log the results to a new MLflow run.
40
+
41
+ """
42
+ mlflow.set_experiment("Model Benchmarking")
43
+ print(f"Starting Evaluation & Benchmarking for language: {lang} and model: {model_type}")
44
+
45
+ with mlflow.start_run(run_name=f"evaluation_local_{lang}_{model_type}"):
46
+ mlflow.log_param("language", lang)
47
+ mlflow.log_param("model_type", model_type)
48
+ mlflow.log_param("model_path", model_path)
49
+ mlflow.log_param("data_path", data_path)
50
+
51
+ avg_runtime_sec = 0.0
52
+ avg_gflops = 0.0
53
+
54
+ # -----------------------
55
+ # SETFIT
56
+ # -----------------------
57
+ if model_type == "setfit":
58
+ ds = load_dataset_splits(base_dir=data_path, langs=[lang])
59
+ eval_df = parse_labels_column(ds[f"{lang}_test"])
60
+
61
+ x_eval = eval_df["combo"].astype(str).tolist()
62
+ y_true = np.array(eval_df["labels"].tolist(), dtype=int)
63
+
64
+ model = SetFitModel.from_pretrained(model_path)
65
+
66
+ with torch.profiler.profile(with_flops=True) as p:
67
+ begin = time.time()
68
+ for _ in range(10):
69
+ y_pred = model(x_eval)
70
+ total_runtime = time.time() - begin
71
+
72
+ avg_runtime_sec = total_runtime / 10
73
+ avg_gflops = (sum(k.flops for k in p.key_averages()) / 1e9) / 10
74
+
75
+ y_pred = np.array(y_pred)
76
+
77
+ # -----------------------
78
+ # RANDOM FOREST
79
+ # -----------------------
80
+ elif model_type == "random_forest":
81
+ ds = load_dataset_splits(base_dir=data_path, langs=[lang])
82
+ eval_df = parse_labels_column(ds[f"{lang}_test"])
83
+
84
+ x_eval = eval_df["combo"].astype(str).tolist()
85
+ y_true = np.array(eval_df["labels"].tolist(), dtype=int)
86
+
87
+ model = joblib.load(f"{model_path}.joblib")
88
+
89
+ begin = time.time()
90
+ for _ in range(10):
91
+ y_pred = model.predict(x_eval)
92
+ total_runtime = time.time() - begin
93
+
94
+ avg_runtime_sec = total_runtime / 10
95
+ avg_gflops = 0.0 # not applicable
96
+
97
+ y_pred = np.array(y_pred)
98
+
99
+ # -----------------------
100
+ # TRANSFORMER
101
+ # -----------------------
102
+ elif model_type == "transformer":
103
+ test_csv_path = os.path.join(data_path, f"{lang}_test.csv")
104
+ if not os.path.exists(test_csv_path):
105
+ raise FileNotFoundError(f"Test CSV for transformer not found: {test_csv_path}")
106
+
107
+ df_test = pd.read_csv(test_csv_path)
108
+ df_test = parse_labels_column(df_test)
109
+
110
+ # Ensure 'combo' exists
111
+ if "combo" not in df_test.columns:
112
+ df_test["combo"] = (
113
+ df_test["comment_sentence"].astype(str) + " | " + df_test["class"].astype(str)
114
+ )
115
+
116
+ texts = df_test["combo"].astype(str).tolist()
117
+ y_true = np.array(df_test["labels"].tolist(), dtype=int)
118
+
119
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
120
+ model = AutoModelForSequenceClassification.from_pretrained(model_path).to(DEVICE)
121
+ model.eval()
122
+
123
+ enc = tokenizer(
124
+ texts,
125
+ padding=True,
126
+ truncation=True,
127
+ max_length=128, # keep consistent with training config
128
+ return_tensors="pt",
129
+ )
130
+ enc = {k: v.to(DEVICE) for k, v in enc.items()}
131
+
132
+ with torch.no_grad():
133
+ with torch.profiler.profile(with_flops=True) as p:
134
+ begin = time.time()
135
+ for _ in range(10):
136
+ outputs = model(**enc)
137
+ total_runtime = time.time() - begin
138
+
139
+ logits = outputs.logits
140
+ probs = torch.sigmoid(logits)
141
+ y_pred = (probs > 0.5).long().cpu().numpy()
142
+
143
+ avg_runtime_sec = total_runtime / 10
144
+ avg_gflops = (sum(k.flops for k in p.key_averages()) / 1e9) / 10
145
+
146
+ else:
147
+ raise ValueError(f"Unsupported model_type: {model_type}")
148
+
149
+ print(f"Avg runtime in seconds: {avg_runtime_sec:.4f}")
150
+ mlflow.log_metric("avg_runtime_sec", avg_runtime_sec)
151
+ mlflow.log_metric("avg_gflops", avg_gflops)
152
+
153
+ # -----------------------
154
+ # Manual per-label metrics (common)
155
+ # -----------------------
156
+ scores = []
157
+ y_true_transposed = y_true.T
158
+ y_pred_transposed = y_pred.T
159
+
160
+ for i in range(len(y_pred_transposed)):
161
+ tp = np.logical_and(y_true_transposed[i] == 1, y_pred_transposed[i] == 1).sum()
162
+ fp = np.logical_and(y_true_transposed[i] == 0, y_pred_transposed[i] == 1).sum()
163
+ fn = np.logical_and(y_true_transposed[i] == 1, y_pred_transposed[i] == 0).sum()
164
+
165
+ precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
166
+ recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
167
+ f1 = (2 * tp) / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else 0.0
168
+
169
+ scores.append(
170
+ {
171
+ "lan": lang,
172
+ "cat": LABELS[lang][i],
173
+ "precision": precision,
174
+ "recall": recall,
175
+ "f1": f1,
176
+ }
177
+ )
178
+
179
+ lan_scores_df = pd.DataFrame(scores)
180
+
181
+ avg_f1 = lan_scores_df["f1"].mean()
182
+ avg_precision = lan_scores_df["precision"].mean()
183
+ avg_recall = lan_scores_df["recall"].mean()
184
+
185
+ mlflow.log_metric("avg_f1_score", avg_f1)
186
+ mlflow.log_metric("avg_precision", avg_precision)
187
+ mlflow.log_metric("avg_recall", avg_recall)
188
+
189
+ dvc_metrics = {
190
+ "avg_f1_score": avg_f1,
191
+ "avg_precision": avg_precision,
192
+ "avg_recall": avg_recall,
193
+ "avg_runtime_sec": avg_runtime_sec,
194
+ "avg_gflops": avg_gflops,
195
+ }
196
+ os.makedirs(os.path.dirname(metrics_output_path), exist_ok=True)
197
+ with open(metrics_output_path, "w") as f:
198
+ json.dump(dvc_metrics, f, indent=4)
199
+
200
+
201
+ if __name__ == "__main__":
202
+ parser = argparse.ArgumentParser()
203
+ parser.add_argument("--lang", type=str, required=True)
204
+ parser.add_argument("--model_type", type=str, required=True)
205
+ parser.add_argument(
206
+ "--data_path",
207
+ type=str,
208
+ default="data/raw",
209
+ help=(
210
+ "Path to evaluation data. "
211
+ "For setfit/random_forest: base dir with raw CSVs (e.g. data/raw). "
212
+ "For transformer: directory with {lang}_test.csv (e.g. data/processed/transformer)."
213
+ ),
214
+ )
215
+ args = parser.parse_args()
216
+
217
+ evaluate_and_benchmark(
218
+ lang=args.lang,
219
+ model_type=args.model_type,
220
+ model_path=f"models/{args.lang}/{args.model_type}",
221
+ data_path=args.data_path,
222
+ metrics_output_path=f"reports/metrics/{args.lang}/{args.model_type}_metrics.json",
223
+ )
codecommentclassification/modeling/train.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module for training different types of models for code comment classification."""
2
+
3
+ import argparse
4
+ import logging
5
+ import os
6
+
7
+ import dagshub
8
+ from datasets import Dataset
9
+ import mlflow
10
+ import yaml
11
+
12
+ from .utils import load_dataset_splits, parse_labels_column
13
+
14
+ logging.basicConfig(
15
+ level=logging.INFO,
16
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
17
+ )
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ dagshub.init(repo_owner="se4ai2526-uniba", repo_name="TheClouds", mlflow=True)
22
+
23
+
24
+ def train_model(lang, model_type, data_path, model_output_path, params):
25
+ """Trains and saves a model for a specific language and model type."""
26
+ print(f"--- Starting training for language: {lang} with model: {model_type} ---")
27
+
28
+ ds = load_dataset_splits(data_path)
29
+
30
+ train_df = ds[f"{lang}_train"]
31
+ eval_df = ds[f"{lang}_test"]
32
+
33
+ train_df = parse_labels_column(train_df)
34
+ eval_df = parse_labels_column(eval_df)
35
+
36
+ # converto i DataFrame in HuggingFace Dataset
37
+ train_dataset = Dataset.from_pandas(train_df, preserve_index=False)
38
+ eval_dataset = Dataset.from_pandas(eval_df, preserve_index=False)
39
+
40
+ if model_type == "setfit":
41
+ from setfit import SetFitModel, Trainer, TrainingArguments
42
+
43
+ mlflow.set_experiment("SetFit Training")
44
+ with mlflow.start_run(run_name=f"train-{lang}-{model_type}"):
45
+ mlflow.log_param("language", lang)
46
+ mlflow.log_param("model_type", model_type)
47
+ model = SetFitModel.from_pretrained(
48
+ "sentence-transformers/paraphrase-MiniLM-L6-v2",
49
+ multi_target_strategy="multi-output",
50
+ )
51
+ args = TrainingArguments(**params)
52
+ trainer = Trainer(
53
+ model=model,
54
+ args=args,
55
+ train_dataset=train_dataset,
56
+ eval_dataset=eval_dataset,
57
+ column_mapping={"combo": "text", "labels": "label"},
58
+ )
59
+
60
+ mlflow.log_param("num_epochs", args.num_epochs)
61
+ mlflow.log_param("num_iterations", args.num_iterations)
62
+
63
+ trainer.train()
64
+
65
+ eval_metrics = trainer.evaluate()
66
+ for metric_name, metric_value in eval_metrics.items():
67
+ mlflow.log_metric(metric_name, metric_value)
68
+
69
+ trainer.model.save_pretrained(model_output_path)
70
+
71
+ mlflow.transformers.log_model(
72
+ transformers_model=model_output_path,
73
+ artifact_path=f"{lang}_setfit_model",
74
+ task="text-classification",
75
+ )
76
+ mlflow.end_run()
77
+
78
+ elif model_type == "random_forest":
79
+ import joblib
80
+ import numpy as np
81
+ from sklearn.ensemble import RandomForestClassifier
82
+ from sklearn.feature_extraction.text import TfidfVectorizer
83
+ from sklearn.multioutput import MultiOutputClassifier
84
+ from sklearn.pipeline import Pipeline
85
+
86
+ mlflow.set_experiment("Random Forest Training")
87
+ with mlflow.start_run(run_name=f"train-{lang}-{model_type}"):
88
+ mlflow.log_param("language", lang)
89
+ mlflow.log_param("model_type", model_type)
90
+ mlflow.log_params(params)
91
+
92
+ tfidf_params = {
93
+ "ngram_range": tuple(params.pop("ngram_range", (1, 1))),
94
+ "max_features": params.pop("max_features", None),
95
+ "min_df": params.pop("min_df", 1),
96
+ "max_df": params.pop("max_df", 1.0),
97
+ }
98
+
99
+ rf_params = params
100
+ pipeline = Pipeline(
101
+ [
102
+ ("tfidf", TfidfVectorizer(**tfidf_params)),
103
+ (
104
+ "clf",
105
+ MultiOutputClassifier(
106
+ RandomForestClassifier(
107
+ random_state=42, class_weight="balanced", **rf_params
108
+ )
109
+ ),
110
+ ),
111
+ ]
112
+ )
113
+
114
+ X_train = train_dataset["combo"]
115
+ y_train = np.array(train_dataset["labels"])
116
+
117
+ pipeline.fit(X_train, y_train)
118
+
119
+ X_test = eval_dataset["combo"]
120
+ y_test = np.array(eval_dataset["labels"])
121
+
122
+ score = pipeline.score(X_test, y_test)
123
+ mlflow.log_metric("accuracy", score)
124
+
125
+ os.makedirs(os.path.dirname(model_output_path), exist_ok=True)
126
+ joblib.dump(pipeline, f"{model_output_path}.joblib")
127
+
128
+ mlflow.sklearn.log_model(
129
+ sk_model=pipeline, artifact_path=f"{lang}_random_forest_model"
130
+ )
131
+ mlflow.end_run()
132
+
133
+ elif model_type == "transformer":
134
+ from .transformer import (
135
+ TransformerConfig,
136
+ TransformerTrainer,
137
+ )
138
+
139
+ mlflow.set_experiment("Transformer Training")
140
+ with mlflow.start_run(run_name=f"train-{lang}-{model_type}"):
141
+ mlflow.log_param("language", lang)
142
+ mlflow.log_param("model_type", model_type)
143
+ mlflow.log_params(params)
144
+
145
+ cfg = TransformerConfig(
146
+ lang=lang,
147
+ raw_data_dir="data/raw",
148
+ processed_data_dir="data/processed/transformer",
149
+ model_output_path=model_output_path,
150
+ pretrained_model_name=params.get(
151
+ "pretrained_model_name", "microsoft/codebert-base"
152
+ ),
153
+ max_length=params.get("max_length", 128),
154
+ batch_size=params.get("batch_size", 16),
155
+ lr=params.get("lr", 2e-5),
156
+ num_epochs=params.get("num_epochs", 5),
157
+ warmup_ratio=params.get("warmup_ratio", 0.1),
158
+ pos_weight_cap=params.get("pos_weight_cap", 30.0),
159
+ threshold=params.get("threshold", 0.5),
160
+ preprocessing=params.get("preprocessing", False),
161
+ preprocessing_factor=params.get("preprocessing_factor", 1.0),
162
+ )
163
+
164
+ logger.info(
165
+ "Starting transformer training for language '%s' with config: %s",
166
+ lang,
167
+ cfg,
168
+ )
169
+
170
+ trainer = TransformerTrainer(cfg)
171
+ metrics = trainer.run()
172
+
173
+ logger.info("Final transformer metrics for %s: %s", lang, metrics)
174
+
175
+ for name, value in metrics.items():
176
+ mlflow.log_metric(f"final_{name}", value)
177
+
178
+ mlflow.end_run()
179
+
180
+ else:
181
+ raise ValueError(f"Unsupported model_type: {model_type}")
182
+
183
+ print(f"Model for {lang}-{model_type} saved to {model_output_path}")
184
+
185
+
186
+ if __name__ == "__main__":
187
+ parser = argparse.ArgumentParser()
188
+ parser.add_argument("--lang", type=str, required=True)
189
+ parser.add_argument("--model_type", type=str, required=True)
190
+ args = parser.parse_args()
191
+
192
+ with open("params.yaml", "r") as f:
193
+ all_params = yaml.safe_load(f)
194
+
195
+ model_params = all_params[args.model_type].copy()
196
+
197
+ train_model(
198
+ lang=args.lang,
199
+ model_type=args.model_type,
200
+ data_path="data/raw",
201
+ model_output_path=f"models/{args.lang}/{args.model_type}",
202
+ params=model_params,
203
+ )
codecommentclassification/modeling/transformer/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ """Transformer model trainer module."""
2
+
3
+ import logging
4
+
5
+ from .trainer import TransformerConfig, TransformerTrainer
6
+
7
+ logger = logging.getLogger(__name__)
8
+ logger.addHandler(logging.NullHandler())
9
+
10
+ __all__ = ["TransformerConfig", "TransformerTrainer"]
codecommentclassification/modeling/transformer/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (562 Bytes). View file
 
codecommentclassification/modeling/transformer/__pycache__/preprocessing.cpython-311.pyc ADDED
Binary file (10.8 kB). View file
 
codecommentclassification/modeling/transformer/__pycache__/trainer.cpython-311.pyc ADDED
Binary file (26.7 kB). View file
 
codecommentclassification/modeling/transformer/preprocessing.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Preprocessing helpers for transformer training.
2
+
3
+ This module provides utilities to parse multi-label strings, ensure the
4
+ `combo` column exists, perform label-aware supersampling of a training
5
+ DataFrame, and a light-weight `load_or_prepare_data` entrypoint that loads
6
+ raw CSVs, optionally applies preprocessing, and writes processed CSVs.
7
+ """
8
+
9
+ import logging
10
+ import os
11
+ from typing import Tuple
12
+
13
+ import numpy as np
14
+ import pandas as pd
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ def parse_label_str(s: str) -> np.ndarray:
20
+ """Convert a string like '[0 0 1 0 0 0 0]' into a float32 numpy array."""
21
+ return np.fromstring(str(s).strip("[]"), sep=" ", dtype=np.float32)
22
+
23
+
24
+ def ensure_combo_column(df: pd.DataFrame) -> pd.DataFrame:
25
+ """Ensure that the 'combo' column exists.
26
+
27
+ If missing, create it from 'comment_sentence' and 'class'.
28
+ """
29
+ if "combo" not in df.columns:
30
+ logger.info("Column 'combo' not found, creating it from 'comment_sentence' and 'class'.")
31
+ df = df.copy()
32
+ df["combo"] = df["comment_sentence"].astype(str) + " | " + df["class"].astype(str)
33
+ else:
34
+ logger.info("Column 'combo' already present, reusing it.")
35
+ return df
36
+
37
+
38
+ def supersample_dataframe(
39
+ df: pd.DataFrame,
40
+ factor: float,
41
+ random_state: int = 42,
42
+ ) -> pd.DataFrame:
43
+ """Offline label-aware supersampling of the training DataFrame.
44
+
45
+ - Keeps all original rows.
46
+ - For each label j, duplicates rows that contain that label until:
47
+ target_j = min(max_freq, freq_j * factor)
48
+ where freq_j is the original count for label j and max_freq is the
49
+ maximum frequency across labels.
50
+ - Shuffles the resulting indices.
51
+
52
+ Assumes:
53
+ - df['labels'] is a string representation of a multi-hot vector.
54
+ """
55
+ if factor <= 1.0:
56
+ logger.info(
57
+ "Supersampling factor <= 1.0 (%.2f), returning original DataFrame.",
58
+ factor,
59
+ )
60
+ return df.copy()
61
+
62
+ rng = np.random.default_rng(random_state)
63
+
64
+ labels_array = np.stack(df["labels"].map(parse_label_str).values)
65
+ if labels_array.ndim == 1:
66
+ labels_array = labels_array[:, None]
67
+
68
+ num_samples, num_labels = labels_array.shape
69
+ freq = labels_array.sum(axis=0).astype(int)
70
+ max_freq = int(freq.max())
71
+
72
+ logger.info("Original label frequencies: %s", freq.tolist())
73
+ logger.info("Max label frequency: %d", max_freq)
74
+
75
+ if max_freq == 0:
76
+ logger.warning("All label frequencies are zero, skipping supersampling.")
77
+ return df.copy()
78
+
79
+ target = np.minimum(max_freq, (freq * factor).astype(int))
80
+ logger.info(
81
+ "Target label frequencies after supersampling (capped by max_freq): %s",
82
+ target.tolist(),
83
+ )
84
+
85
+ indices_by_label = {j: np.where(labels_array[:, j] == 1)[0] for j in range(num_labels)}
86
+
87
+ new_indices = list(range(num_samples))
88
+
89
+ for j in range(num_labels):
90
+ current = int(freq[j])
91
+ desired = int(target[j])
92
+ if desired <= current:
93
+ continue
94
+
95
+ candidate_indices = indices_by_label[j]
96
+ if candidate_indices.size == 0:
97
+ continue
98
+
99
+ needed = desired - current
100
+ extra = rng.choice(candidate_indices, size=needed, replace=True)
101
+ new_indices.extend(extra.tolist())
102
+ logger.info(
103
+ "Label %d: current=%d, target=%d, added=%d samples.",
104
+ j,
105
+ current,
106
+ desired,
107
+ needed,
108
+ )
109
+
110
+ rng.shuffle(new_indices)
111
+ df_sup = df.iloc[new_indices].reset_index(drop=True)
112
+
113
+ labels_array_after = np.stack(df_sup["labels"].map(parse_label_str).values)
114
+ freq_after = labels_array_after.sum(axis=0).astype(int)
115
+ logger.info("Final label frequencies after supersampling: %s", freq_after.tolist())
116
+ logger.info("Training rows before: %d, after: %d", num_samples, len(df_sup))
117
+
118
+ return df_sup
119
+
120
+
121
+ def load_or_prepare_data(
122
+ lang: str,
123
+ raw_data_dir: str,
124
+ processed_data_dir: str,
125
+ preprocessing_enabled: bool,
126
+ preprocessing_factor: float,
127
+ random_state: int = 42,
128
+ ) -> Tuple[pd.DataFrame, pd.DataFrame, str]:
129
+ """Load raw CSVs for the given language, optionally apply preprocessing.
130
+
131
+ (supersampling) on the train split, and save processed CSVs.
132
+
133
+ - Test split is NEVER supersampled or augmented.
134
+ - Train split:
135
+ - always gets 'combo' and 'labels_array'
136
+ - supersampled only if preprocessing_enabled=True and preprocessing_factor>1.0
137
+
138
+ Parameters
139
+ ----------
140
+ lang : str
141
+ Language key (e.g., 'java', 'python', 'pharo').
142
+ raw_data_dir : str
143
+ Directory containing {lang}_train.csv and {lang}_test.csv.
144
+ processed_data_dir : str
145
+ Directory where processed CSVs will be saved.
146
+ preprocessing_enabled : bool
147
+ Whether to apply supersampling on the training split.
148
+ preprocessing_factor : float
149
+ Supersampling factor (ignored if preprocessing_enabled=False).
150
+ random_state : int
151
+ RNG seed.
152
+
153
+ Returns
154
+ -------
155
+ train_df : pd.DataFrame
156
+ eval_df : pd.DataFrame
157
+ preprocessing_used : str
158
+ One of: 'none', 'supersampling'.
159
+
160
+ """
161
+ logger.info("Loading raw CSVs for language '%s' from '%s'.", lang, raw_data_dir)
162
+ raw_train_path = os.path.join(raw_data_dir, f"{lang}_train.csv")
163
+ raw_eval_path = os.path.join(raw_data_dir, f"{lang}_test.csv")
164
+
165
+ if not os.path.exists(raw_train_path):
166
+ raise FileNotFoundError(f"Raw train CSV not found: {raw_train_path}")
167
+ if not os.path.exists(raw_eval_path):
168
+ raise FileNotFoundError(f"Raw test CSV not found: {raw_eval_path}")
169
+
170
+ train_df = pd.read_csv(raw_train_path)
171
+ eval_df = pd.read_csv(raw_eval_path)
172
+
173
+ train_df = ensure_combo_column(train_df)
174
+ eval_df = ensure_combo_column(eval_df)
175
+
176
+ if preprocessing_enabled and preprocessing_factor > 1.0:
177
+ logger.info(
178
+ "Preprocessing enabled: applying supersampling with factor=%.2f.",
179
+ preprocessing_factor,
180
+ )
181
+ train_df = supersample_dataframe(
182
+ train_df,
183
+ factor=preprocessing_factor,
184
+ random_state=random_state,
185
+ )
186
+ preprocessing_used = "supersampling"
187
+ else:
188
+ logger.info(
189
+ "Preprocessing disabled or factor <= 1.0 (%.2f). Using original training data.",
190
+ preprocessing_factor,
191
+ )
192
+ preprocessing_used = "none"
193
+
194
+ # Save processed CSVs (for inspection / reproducibility)
195
+ os.makedirs(processed_data_dir, exist_ok=True)
196
+ processed_train_path = os.path.join(processed_data_dir, f"{lang}_train.csv")
197
+ processed_eval_path = os.path.join(processed_data_dir, f"{lang}_test.csv")
198
+ train_df.to_csv(processed_train_path, index=False)
199
+ eval_df.to_csv(processed_eval_path, index=False)
200
+ logger.info("Saved processed train/test CSVs to '%s'.", processed_data_dir)
201
+
202
+ # Ensure 'labels_array' exists for both splits
203
+ for df, split_name in ((train_df, "train"), (eval_df, "test")):
204
+ if "labels_array" not in df.columns:
205
+ logger.info("Parsing label strings into arrays for split '%s'.", split_name)
206
+ df["labels_array"] = df["labels"].apply(parse_label_str)
207
+
208
+ return train_df, eval_df, preprocessing_used
codecommentclassification/modeling/transformer/trainer.py ADDED
@@ -0,0 +1,531 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Training utilities for transformer-based multi-label classification.
2
+
3
+ This module contains a small training harness around HuggingFace
4
+ `AutoModelForSequenceClassification` specialized for the project's
5
+ multi-label code-comment classification task. It provides:
6
+
7
+ - `TransformerConfig` dataclass for configurable training runs.
8
+ - `CommentDataset` to wrap tokenization of pandas DataFrames.
9
+ - `TransformerTrainer` which runs the training loop, evaluation and
10
+ model export (with MLflow logging hooks).
11
+
12
+ The helpers are intended for experimental, small-scale training and
13
+ instrumentation rather than production-grade distributed training.
14
+ """
15
+
16
+ from dataclasses import asdict, dataclass
17
+ import logging
18
+ import os
19
+ from typing import Dict, List, Tuple
20
+
21
+ import mlflow
22
+ import numpy as np
23
+ import pandas as pd
24
+ from sklearn.metrics import (
25
+ accuracy_score,
26
+ classification_report,
27
+ f1_score,
28
+ precision_score,
29
+ recall_score,
30
+ )
31
+ import torch
32
+ from torch.utils.data import DataLoader, Dataset
33
+ from tqdm.auto import tqdm
34
+ from transformers import (
35
+ AutoModelForSequenceClassification,
36
+ AutoTokenizer,
37
+ get_linear_schedule_with_warmup,
38
+ )
39
+
40
+ from .preprocessing import load_or_prepare_data
41
+
42
+ logger = logging.getLogger(__name__)
43
+
44
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
+ print(f"Using device: {DEVICE}")
46
+
47
+
48
+ # Label names per language, order must match the label vector in the CSV
49
+ LABELS: Dict[str, Tuple[str, ...]] = {
50
+ "java": (
51
+ "summary",
52
+ "Ownership",
53
+ "Expand",
54
+ "usage",
55
+ "Pointer",
56
+ "deprecation",
57
+ "rational",
58
+ ),
59
+ "python": (
60
+ "Usage",
61
+ "Parameters",
62
+ "DevelopmentNotes",
63
+ "Expand",
64
+ "Summary",
65
+ ),
66
+ "pharo": (
67
+ "Keyimplementationpoints",
68
+ "Example",
69
+ "Responsibilities",
70
+ "Intent",
71
+ "Keymessages",
72
+ "Collaborators",
73
+ ),
74
+ }
75
+
76
+
77
+ @dataclass
78
+ class TransformerConfig:
79
+ """Configuration for transformer training runs.
80
+
81
+ Attributes are intentionally simple dataclass fields and map directly to
82
+ CLI/YAML configuration keys used by the training harness.
83
+ """
84
+
85
+ lang: str
86
+ raw_data_dir: str
87
+ processed_data_dir: str
88
+ model_output_path: str
89
+ pretrained_model_name: str = "microsoft/codebert-base"
90
+ max_length: int = 128
91
+ batch_size: int = 16
92
+ lr: float = 2e-5
93
+ num_epochs: int = 5
94
+ warmup_ratio: float = 0.1
95
+ pos_weight_cap: float = 30.0
96
+ threshold: float = 0.5
97
+ preprocessing: bool = False
98
+ preprocessing_factor: float = 1.0
99
+
100
+ def __post_init__(self) -> None:
101
+ """Force correct types even if YAML provides strings."""
102
+ self.max_length = int(self.max_length)
103
+ self.batch_size = int(self.batch_size)
104
+ self.lr = float(self.lr)
105
+ self.num_epochs = int(self.num_epochs)
106
+ self.warmup_ratio = float(self.warmup_ratio)
107
+ self.pos_weight_cap = float(self.pos_weight_cap)
108
+ self.threshold = float(self.threshold)
109
+ self.preprocessing_factor = float(self.preprocessing_factor)
110
+
111
+ # allow 'true'/'false' as strings from YAML
112
+ if isinstance(self.preprocessing, str):
113
+ self.preprocessing = self.preprocessing.lower() == "true"
114
+
115
+
116
+ class CommentDataset(Dataset):
117
+ """Simple Dataset wrapper around a pandas DataFrame with 'combo' and 'labels_array'."""
118
+
119
+ def __init__(self, df: pd.DataFrame, tokenizer: AutoTokenizer, max_length: int):
120
+ """Create a dataset that tokenizes rows on demand.
121
+
122
+ Parameters
123
+ ----------
124
+ df : pandas.DataFrame
125
+ Input frame containing at least `combo` and `labels_array` columns.
126
+ tokenizer : transformers.AutoTokenizer
127
+ Tokenizer used to encode text into model inputs.
128
+ max_length : int
129
+ Maximum tokenization length (used for padding/truncation).
130
+
131
+ """
132
+ self.df = df.reset_index(drop=True)
133
+ self.tokenizer = tokenizer
134
+ self.max_length = max_length
135
+
136
+ def __len__(self) -> int:
137
+ """Return the number of examples in the dataset."""
138
+ return len(self.df)
139
+
140
+ def __getitem__(self, idx: int):
141
+ """Return a single tokenized example and its labels as tensors.
142
+
143
+ The returned dict contains tokenized inputs (PyTorch tensors) and a
144
+ `labels` tensor suitable for BCEWithLogitsLoss for multi-label tasks.
145
+ """
146
+ row = self.df.iloc[idx]
147
+ text = str(row["combo"])
148
+ labels = np.asarray(row["labels_array"], dtype=np.float32)
149
+
150
+ enc = self.tokenizer(
151
+ text,
152
+ truncation=True,
153
+ max_length=self.max_length,
154
+ padding="max_length",
155
+ return_tensors="pt",
156
+ )
157
+
158
+ item = {k: v.squeeze(0) for k, v in enc.items()}
159
+ item["labels"] = torch.from_numpy(labels)
160
+ return item
161
+
162
+
163
+ class TransformerTrainer:
164
+ """End-to-end transformer trainer for the code comment multi-label task."""
165
+
166
+ def __init__(self, cfg: TransformerConfig) -> None:
167
+ """Initialize training state, data loaders, model and optimizer.
168
+
169
+ Parameters
170
+ ----------
171
+ cfg : TransformerConfig
172
+ Training configuration containing data paths and hyperparameters.
173
+
174
+ """
175
+ self.cfg = cfg
176
+ if cfg.lang not in LABELS:
177
+ raise ValueError(f"No LABELS defined for language '{cfg.lang}'.")
178
+
179
+ self.label_names = LABELS[cfg.lang]
180
+ self.num_labels = len(self.label_names)
181
+
182
+ logger.info("Initializing TransformerTrainer for language '%s'.", cfg.lang)
183
+ logger.info("Raw data directory: %s", cfg.raw_data_dir)
184
+ logger.info("Processed data directory: %s", cfg.processed_data_dir)
185
+ logger.info("Model output path: %s", cfg.model_output_path)
186
+
187
+ # --- data loading / preprocessing ---
188
+ self.train_df, self.eval_df, self.preprocessing_used = load_or_prepare_data(
189
+ lang=cfg.lang,
190
+ raw_data_dir=cfg.raw_data_dir,
191
+ processed_data_dir=cfg.processed_data_dir,
192
+ preprocessing_enabled=cfg.preprocessing,
193
+ preprocessing_factor=cfg.preprocessing_factor,
194
+ random_state=42,
195
+ )
196
+
197
+ logger.info("Preprocessing used for this run: %s", self.preprocessing_used)
198
+ logger.info("Using device: %s", DEVICE)
199
+ logger.info(
200
+ "Train size: %d rows, Eval size: %d rows",
201
+ len(self.train_df),
202
+ len(self.eval_df),
203
+ )
204
+
205
+ # --- log config and dataset info to MLflow ---
206
+ try:
207
+ cfg_dict = asdict(self.cfg)
208
+ mlflow.log_params({f"cfg_{k}": v for k, v in cfg_dict.items()})
209
+ mlflow.log_param("num_labels", self.num_labels)
210
+ mlflow.log_param("label_names", ",".join(self.label_names))
211
+ mlflow.log_param("train_samples", len(self.train_df))
212
+ mlflow.log_param("eval_samples", len(self.eval_df))
213
+ mlflow.log_param("preprocessing_used", self.preprocessing_used)
214
+ except Exception as e:
215
+ logger.warning("Could not log transformer config to MLflow: %s", e)
216
+
217
+ # tokenizer
218
+ logger.info("Loading tokenizer '%s'.", cfg.pretrained_model_name)
219
+ self.tokenizer = AutoTokenizer.from_pretrained(cfg.pretrained_model_name)
220
+
221
+ # label statistics and pos_weight
222
+ y_train = np.stack(self.train_df["labels_array"].to_numpy())
223
+ self.pos_weight = self._compute_pos_weight(y_train)
224
+
225
+ # dataloaders
226
+ train_dataset = CommentDataset(self.train_df, self.tokenizer, cfg.max_length)
227
+ eval_dataset = CommentDataset(self.eval_df, self.tokenizer, cfg.max_length)
228
+
229
+ self.train_loader = DataLoader(
230
+ train_dataset,
231
+ batch_size=cfg.batch_size,
232
+ shuffle=True,
233
+ )
234
+ self.eval_loader = DataLoader(
235
+ eval_dataset,
236
+ batch_size=cfg.batch_size,
237
+ shuffle=False,
238
+ )
239
+
240
+ logger.info(
241
+ "Hyperparameters – lr=%s (type=%s), batch_size=%s, num_epochs=%s",
242
+ self.cfg.lr,
243
+ type(self.cfg.lr),
244
+ self.cfg.batch_size,
245
+ self.cfg.num_epochs,
246
+ )
247
+
248
+ # model
249
+ logger.info("Loading base model '%s'.", cfg.pretrained_model_name)
250
+ self.model = AutoModelForSequenceClassification.from_pretrained(
251
+ cfg.pretrained_model_name,
252
+ num_labels=self.num_labels,
253
+ problem_type="multi_label_classification",
254
+ ).to(DEVICE)
255
+
256
+ self.loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=self.pos_weight.to(DEVICE))
257
+ self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.cfg.lr)
258
+
259
+ num_training_steps = cfg.num_epochs * len(self.train_loader)
260
+ num_warmup_steps = int(cfg.warmup_ratio * num_training_steps)
261
+ logger.info(
262
+ "Total training steps: %d, warmup steps: %d.",
263
+ num_training_steps,
264
+ num_warmup_steps,
265
+ )
266
+
267
+ self.scheduler = get_linear_schedule_with_warmup(
268
+ self.optimizer,
269
+ num_warmup_steps=num_warmup_steps,
270
+ num_training_steps=num_training_steps,
271
+ )
272
+
273
+ self.best_state_dict = None
274
+ self.best_val_macro_f1 = 0.0
275
+
276
+ def _compute_pos_weight(self, y: np.ndarray) -> torch.Tensor:
277
+ if y.ndim == 1:
278
+ y = y[:, None]
279
+ freq = y.sum(axis=0).astype(np.float64)
280
+ num_samples = y.shape[0]
281
+
282
+ pos_weight = (num_samples - freq) / np.clip(freq, 1.0, None)
283
+ pos_weight = np.clip(pos_weight, 1.0, self.cfg.pos_weight_cap)
284
+
285
+ logger.info("Positive class weights (clipped): %s", pos_weight.tolist())
286
+ return torch.tensor(pos_weight, dtype=torch.float32)
287
+
288
+ def _step_batch(self, batch, train: bool):
289
+ batch = {k: v.to(DEVICE) for k, v in batch.items()}
290
+ labels = batch.pop("labels")
291
+
292
+ outputs = self.model(**batch)
293
+ logits = outputs.logits
294
+ loss = self.loss_fn(logits, labels)
295
+
296
+ if train:
297
+ loss.backward()
298
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
299
+ self.optimizer.step()
300
+ self.scheduler.step()
301
+ self.optimizer.zero_grad()
302
+
303
+ return loss, logits, labels
304
+
305
+ def train_one_epoch(self, epoch: int) -> float:
306
+ """Run a single training epoch over `self.train_loader`.
307
+
308
+ Returns
309
+ -------
310
+ float
311
+ The average training loss over the epoch.
312
+
313
+ """
314
+ self.model.train()
315
+ total_loss = 0.0
316
+ n_samples = 0
317
+
318
+ num_batches = len(self.train_loader)
319
+ logger.info("Starting epoch %d training. Number of batches: %d", epoch, num_batches)
320
+
321
+ progress_bar = tqdm(
322
+ self.train_loader,
323
+ desc=f"Epoch {epoch} [train]",
324
+ total=num_batches,
325
+ leave=False,
326
+ )
327
+
328
+ for step, batch in enumerate(progress_bar, start=1):
329
+ loss, _, _ = self._step_batch(batch, train=True)
330
+ batch_size = batch["input_ids"].size(0)
331
+ total_loss += loss.item() * batch_size
332
+ n_samples += batch_size
333
+
334
+ avg_loss_so_far = total_loss / max(n_samples, 1)
335
+ progress_bar.set_postfix({"loss": f"{avg_loss_so_far:.4f}"})
336
+
337
+ avg_loss = total_loss / max(n_samples, 1)
338
+ logger.info("Epoch %d training completed. Average loss: %.4f.", epoch, avg_loss)
339
+
340
+ mlflow.log_metric("train_loss", avg_loss, step=epoch)
341
+
342
+ return avg_loss
343
+
344
+ def evaluate(
345
+ self,
346
+ epoch: int,
347
+ split_name: str = "eval",
348
+ ) -> Tuple[float, float, float, np.ndarray, np.ndarray]:
349
+ """Evaluate the model on `self.eval_loader` and compute metrics.
350
+
351
+ Parameters
352
+ ----------
353
+ epoch : int
354
+ Current epoch number (used for logging).
355
+ split_name : str
356
+ Name of the evaluation split used for MLflow metric keys.
357
+
358
+ Returns
359
+ -------
360
+ tuple
361
+ (avg_loss, micro_f1, macro_f1, y_true, y_pred)
362
+
363
+ """
364
+ self.model.eval()
365
+ total_loss = 0.0
366
+ n_samples = 0
367
+ all_preds: List[np.ndarray] = []
368
+ all_labels: List[np.ndarray] = []
369
+
370
+ logger.info("Starting evaluation for epoch %d on split '%s'.", epoch, split_name)
371
+
372
+ num_batches = len(self.eval_loader)
373
+ progress_bar = tqdm(
374
+ self.eval_loader,
375
+ desc=f"Epoch {epoch} [{split_name}]",
376
+ total=num_batches,
377
+ leave=False,
378
+ )
379
+
380
+ with torch.no_grad():
381
+ for batch in progress_bar:
382
+ loss, logits, labels = self._step_batch(batch, train=False)
383
+ batch_size = logits.size(0)
384
+ total_loss += loss.item() * batch_size
385
+ n_samples += batch_size
386
+
387
+ probs = torch.sigmoid(logits)
388
+ preds = (probs > self.cfg.threshold).long()
389
+
390
+ all_preds.append(preds.cpu().numpy())
391
+ all_labels.append(labels.cpu().numpy())
392
+
393
+ avg_loss_so_far = total_loss / max(n_samples, 1)
394
+ progress_bar.set_postfix({"loss": f"{avg_loss_so_far:.4f}"})
395
+
396
+ avg_loss = total_loss / max(n_samples, 1)
397
+ y_pred = np.concatenate(all_preds, axis=0)
398
+ y_true = np.concatenate(all_labels, axis=0)
399
+
400
+ # F1
401
+ micro_f1 = f1_score(y_true, y_pred, average="micro", zero_division=0)
402
+ macro_f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
403
+
404
+ # Precision
405
+ micro_precision = precision_score(y_true, y_pred, average="micro", zero_division=0)
406
+ macro_precision = precision_score(y_true, y_pred, average="macro", zero_division=0)
407
+
408
+ # Recall
409
+ micro_recall = recall_score(y_true, y_pred, average="micro", zero_division=0)
410
+ macro_recall = recall_score(y_true, y_pred, average="macro", zero_division=0)
411
+
412
+ # Accuracy (multi-label)
413
+ # subset_accuracy = exact match of all labels for each sample
414
+ subset_accuracy = accuracy_score(y_true, y_pred)
415
+ # micro_accuracy = accuracy over flattened label indicators
416
+ micro_accuracy = accuracy_score(y_true.flatten(), y_pred.flatten())
417
+
418
+ logger.info(
419
+ "Eval results [%s] - loss: %.4f | "
420
+ "micro-F1: %.4f, macro-F1: %.4f | "
421
+ "micro-P: %.4f, macro-P: %.4f | "
422
+ "micro-R: %.4f, macro-R: %.4f | "
423
+ "subset-acc: %.4f, micro-acc: %.4f",
424
+ split_name,
425
+ avg_loss,
426
+ micro_f1,
427
+ macro_f1,
428
+ micro_precision,
429
+ macro_precision,
430
+ micro_recall,
431
+ macro_recall,
432
+ subset_accuracy,
433
+ micro_accuracy,
434
+ )
435
+
436
+ # MLflow logging (per epoch)
437
+ mlflow.log_metric(f"{split_name}_loss", avg_loss, step=epoch)
438
+ mlflow.log_metric(f"{split_name}_micro_f1", micro_f1, step=epoch)
439
+ mlflow.log_metric(f"{split_name}_macro_f1", macro_f1, step=epoch)
440
+ mlflow.log_metric(f"{split_name}_micro_precision", micro_precision, step=epoch)
441
+ mlflow.log_metric(f"{split_name}_macro_precision", macro_precision, step=epoch)
442
+ mlflow.log_metric(f"{split_name}_micro_recall", micro_recall, step=epoch)
443
+ mlflow.log_metric(f"{split_name}_macro_recall", macro_recall, step=epoch)
444
+ mlflow.log_metric(f"{split_name}_subset_accuracy", subset_accuracy, step=epoch)
445
+ mlflow.log_metric(f"{split_name}_micro_accuracy", micro_accuracy, step=epoch)
446
+
447
+ return avg_loss, micro_f1, macro_f1, y_true, y_pred
448
+
449
+ def run(self) -> Dict[str, float]:
450
+ """Execute the full training loop and save the best model.
451
+
452
+ Returns
453
+ -------
454
+ dict
455
+ Summary metrics from the final evaluation (micro/macro F1).
456
+
457
+ """
458
+ logger.info("Starting training loop for %d epochs.", self.cfg.num_epochs)
459
+ for epoch in range(1, self.cfg.num_epochs + 1):
460
+ train_loss = self.train_one_epoch(epoch)
461
+ val_loss, val_micro_f1, val_macro_f1, _, _ = self.evaluate(epoch, split_name="eval")
462
+
463
+ logger.info(
464
+ "[%s] epoch=%d train_loss=%.4f val_loss=%.4f val_micro_f1=%.4f val_macro_f1=%.4f",
465
+ self.cfg.lang,
466
+ epoch,
467
+ train_loss,
468
+ val_loss,
469
+ val_micro_f1,
470
+ val_macro_f1,
471
+ )
472
+
473
+ if val_macro_f1 > self.best_val_macro_f1:
474
+ logger.info(
475
+ "New best macro-F1: %.4f (previous: %.4f). Saving current model state.",
476
+ val_macro_f1,
477
+ self.best_val_macro_f1,
478
+ )
479
+ self.best_val_macro_f1 = val_macro_f1
480
+ self.best_state_dict = {k: v.cpu() for k, v in self.model.state_dict().items()}
481
+
482
+ if self.best_state_dict is not None:
483
+ logger.info("Loading best model weights (macro-F1 = %.4f).", self.best_val_macro_f1)
484
+ self.model.load_state_dict(self.best_state_dict)
485
+
486
+ # final evaluation
487
+ _, micro_f1, macro_f1, y_true, y_pred = self.evaluate(
488
+ epoch=self.cfg.num_epochs,
489
+ split_name="eval",
490
+ )
491
+
492
+ logger.info(
493
+ "[%s] FINAL micro-F1 = %.4f, macro-F1 = %.4f.",
494
+ self.cfg.lang,
495
+ micro_f1,
496
+ macro_f1,
497
+ )
498
+ logger.info(
499
+ "Per-label classification report:\n%s",
500
+ classification_report(y_true, y_pred, target_names=self.label_names, zero_division=0),
501
+ )
502
+
503
+ # save model and tokenizer
504
+ os.makedirs(self.cfg.model_output_path, exist_ok=True)
505
+ logger.info("Saving model and tokenizer to '%s'.", self.cfg.model_output_path)
506
+ self.model.save_pretrained(self.cfg.model_output_path)
507
+ self.tokenizer.save_pretrained(self.cfg.model_output_path)
508
+
509
+ # log model directory as MLflow artifact
510
+ logger.info("Logging final model artifacts to MLflow.")
511
+ mlflow.log_artifacts(
512
+ self.cfg.model_output_path,
513
+ artifact_path=f"{self.cfg.lang}_transformer_model",
514
+ )
515
+
516
+ logger.info("Logging HF transformers model to MLflow via mlflow.transformers.log_model.")
517
+ model_info = mlflow.transformers.log_model(
518
+ transformers_model=self.cfg.model_output_path,
519
+ artifact_path=f"{self.cfg.lang}_transformer_model",
520
+ task="text-classification",
521
+ )
522
+
523
+ logger.info(
524
+ "Logged transformers model to MLflow with URI: %s",
525
+ model_info.model_uri,
526
+ )
527
+
528
+ return {
529
+ "micro_f1": float(micro_f1),
530
+ "macro_f1": float(macro_f1),
531
+ }
codecommentclassification/modeling/utils.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility functions for model training and evaluation."""
2
+
3
+ import os
4
+ from typing import List
5
+
6
+ LANGS: List[str] = ["java", "python", "pharo"]
7
+
8
+
9
+ def load_dataset_splits(base_dir=None, langs=None):
10
+ """Load dataset splits from CSV files under data/raw.
11
+
12
+ Expects files like data/raw/java_train.csv, data/raw/java_test.csv, etc.
13
+ Returns a dict mapping split names (e.g. "java_test") to pandas DataFrames.
14
+
15
+ Raises:
16
+ FileNotFoundError: se la directory base o un file atteso non esiste.
17
+ ImportError: se pandas non è installato.
18
+
19
+ """
20
+ if base_dir is None:
21
+ base_dir = os.path.join("data", "raw")
22
+
23
+ if langs is None:
24
+ langs = LANGS
25
+
26
+ if not os.path.isdir(base_dir):
27
+ raise FileNotFoundError(
28
+ f"CSV datasets not found under {base_dir}; cannot load dataset splits."
29
+ )
30
+
31
+ try:
32
+ import pandas as pd
33
+ except Exception as e:
34
+ raise ImportError("pandas is required to load dataset splits") from e
35
+
36
+ datasets = {}
37
+ for lang in langs:
38
+ for split in ("train", "test"):
39
+ fname = f"{lang}_{split}.csv"
40
+ path = os.path.join(base_dir, fname)
41
+ if not os.path.isfile(path):
42
+ raise FileNotFoundError(f"Expected dataset file missing: {path}")
43
+ df = pd.read_csv(path)
44
+ datasets[f"{lang}_{split}"] = df
45
+
46
+ return datasets
47
+
48
+
49
+ def parse_labels_column(df):
50
+ """Parse the 'labels' column of a DataFrame into lists of integers."""
51
+
52
+ def _parse_one(x):
53
+ if isinstance(x, str):
54
+ s = x.strip()
55
+ if s.startswith("[") and s.endswith("]"):
56
+ s = s[1:-1]
57
+ return [int(tok) for tok in s.split() if tok]
58
+ try:
59
+ import numpy as np
60
+
61
+ if isinstance(x, np.ndarray):
62
+ return [int(v) for v in x.tolist()]
63
+ except ImportError:
64
+ pass
65
+ if isinstance(x, (list, tuple)):
66
+ return [int(v) for v in x]
67
+ raise ValueError(f"Formato labels non gestito: {type(x)} -> {x!r}")
68
+
69
+ df["labels"] = df["labels"].apply(_parse_one)
70
+ return df
codecommentclassification/predictor.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Prediction helpers for different model types.
2
+
3
+ This module provides `ModelPredictor`, a lightweight wrapper that unifies
4
+ inference for SetFit, scikit-learn RandomForest pipelines, and HuggingFace
5
+ transformer sequence classification models. It standardizes inputs/outputs
6
+ to a NumPy array of shape (n_samples, n_labels).
7
+ """
8
+
9
+ import os
10
+ from typing import List, Union
11
+
12
+ import joblib
13
+ import numpy as np
14
+ from setfit import SetFitModel
15
+ import torch
16
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
17
+
18
+ TextInput = Union[str, List[str]]
19
+
20
+
21
+ class ModelPredictor:
22
+ """Unified predictor for SetFit, Random Forest and Transformer models.
23
+
24
+ Expected directory layout:
25
+
26
+ models/
27
+ ├── java/
28
+ │ ├── setfit/ # SetFit saved model directory
29
+ │ ├── random_forest.joblib # sklearn pipeline
30
+ │ └── transformer/ # HF model + tokenizer (config.json, etc.)
31
+ ├── python/
32
+ │ ├── setfit/
33
+ │ ├── random_forest.joblib
34
+ │ └── transformer/
35
+ └── pharo/
36
+ ├── setfit/
37
+ ├── random_forest.joblib
38
+ └── transformer/
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ lang: str,
44
+ model_type: str,
45
+ model_root: str = "models",
46
+ threshold: float = 0.5,
47
+ max_length: int = 128,
48
+ ) -> None:
49
+ """Parameters
50
+
51
+ ----------
52
+ lang : str
53
+ One of {"java", "python", "pharo"}.
54
+ model_type : str
55
+ One of {"setfit", "random_forest", "transformer"}.
56
+ model_root : str
57
+ Root directory where models are stored.
58
+ threshold : float
59
+ Decision threshold for multi-label Transformer predictions.
60
+ Ignored for SetFit and Random Forest (they already output labels).
61
+ max_length : int
62
+ Max sequence length for Transformer tokenization.
63
+
64
+ """
65
+ self.lang = lang
66
+ self.model_type = model_type
67
+ self.model_root = model_root
68
+ self.threshold = float(threshold)
69
+ self.max_length = int(max_length)
70
+
71
+ # device only matters for Transformer
72
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
73
+
74
+ if model_type == "setfit":
75
+ model_path = os.path.join(self.model_root, self.lang, "setfit")
76
+ if not os.path.isdir(model_path):
77
+ raise FileNotFoundError(f"SetFit model not found at: {model_path}")
78
+ self.model = SetFitModel.from_pretrained(model_path)
79
+
80
+ elif model_type == "random_forest":
81
+ model_path = os.path.join(self.model_root, self.lang, "random_forest.joblib")
82
+ if not os.path.isfile(model_path):
83
+ raise FileNotFoundError(f"Random Forest model not found at: {model_path}")
84
+ self.model = joblib.load(model_path)
85
+
86
+ elif model_type == "transformer":
87
+ model_path = os.path.join(self.model_root, self.lang, "transformer")
88
+ if not os.path.isdir(model_path):
89
+ raise FileNotFoundError(f"Transformer model not found at: {model_path}")
90
+
91
+ # load tokenizer and model from the same directory used during training
92
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
93
+ self.model = AutoModelForSequenceClassification.from_pretrained(model_path).to(
94
+ self.device
95
+ )
96
+ self.model.eval()
97
+
98
+ else:
99
+ raise ValueError(f"Unsupported model_type: {model_type}")
100
+
101
+ def predict(self, texts: TextInput) -> np.ndarray:
102
+ """Run prediction on one or many text samples.
103
+
104
+ Parameters
105
+ ----------
106
+ texts : str | list[str]
107
+ A single text or a list of texts.
108
+
109
+ Returns
110
+ -------
111
+ np.ndarray
112
+ Array of shape (n_samples, n_labels) with integer (typically binary) values.
113
+
114
+ """
115
+ if isinstance(texts, str):
116
+ texts = [texts]
117
+
118
+ if self.model_type == "setfit":
119
+ raw_outputs = self.model(texts)
120
+ outputs = np.array(list(raw_outputs), dtype=int)
121
+
122
+ elif self.model_type == "random_forest":
123
+ raw_outputs = self.model.predict(texts)
124
+ outputs = np.array(list(raw_outputs), dtype=int)
125
+
126
+ elif self.model_type == "transformer":
127
+ enc = self.tokenizer(
128
+ texts,
129
+ padding=True,
130
+ truncation=True,
131
+ max_length=self.max_length,
132
+ return_tensors="pt",
133
+ )
134
+ enc = {k: v.to(self.device) for k, v in enc.items()}
135
+
136
+ with torch.no_grad():
137
+ logits = self.model(**enc).logits
138
+ probs = torch.sigmoid(logits)
139
+ preds = (probs > self.threshold).long().cpu().numpy()
140
+
141
+ outputs = preds.astype(int)
142
+ else:
143
+ raise ValueError(f"Unsupported model_type: {self.model_type}")
144
+
145
+ # Ensure 2D shape (n_samples, n_labels)
146
+ if outputs.ndim == 1:
147
+ outputs = outputs.reshape(1, -1)
148
+
149
+ return outputs
requirements.txt ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==1.11.0
2
+ aiohappyeyeballs==2.6.1
3
+ aiohttp==3.13.1
4
+ aiosignal==1.4.0
5
+ alembic==1.17.0
6
+ annotated-doc==0.0.3
7
+ annotated-types==0.7.0
8
+ anyio==4.11.0
9
+ appdirs==1.4.4
10
+ argon2-cffi==25.1.0
11
+ argon2-cffi-bindings==25.1.0
12
+ arrow==1.4.0
13
+ asttokens==3.0.0
14
+ async-lru==2.0.5
15
+ attrs==25.4.0
16
+ babel==2.17.0
17
+ backoff==2.2.1
18
+ beautifulsoup4==4.14.2
19
+ bleach==6.2.0
20
+ blinker==1.9.0
21
+ boto3==1.40.60
22
+ botocore==1.40.60
23
+ cachetools==5.5.2
24
+ certifi==2025.10.5
25
+ cffi==2.0.0
26
+ charset-normalizer==3.4.4
27
+ click==8.3.0
28
+ cloudpickle==3.1.1
29
+ cmake==4.1.2
30
+ comm==0.2.3
31
+ contourpy==1.3.3
32
+ cryptography==46.0.3
33
+ cycler==0.12.1
34
+ dacite==1.6.0
35
+ dagshub==0.6.3
36
+ dagshub-annotation-converter==0.1.15
37
+ databricks-sdk==0.70.0
38
+ dataclasses-json==0.6.7
39
+ datasets==3.6.0
40
+ debugpy==1.8.17
41
+ decorator==5.2.1
42
+ deepchecks[nlp]==0.19.1
43
+ defusedxml==0.7.1
44
+ dill==0.3.8
45
+ docker==7.1.0
46
+ dvc==3.63.0
47
+ dvc-data==3.16.12
48
+ dvc-http==2.32.0
49
+ dvc-objects==5.1.2
50
+ dvc-render==1.0.2
51
+ dvc-s3==3.2.2
52
+ dvc-studio-client==0.22.0
53
+ dvc-task==0.40.2
54
+ evaluate==0.4.6
55
+ executing==2.2.1
56
+ fastapi[standard]==0.120.1
57
+ fastjsonschema==2.21.2
58
+ filelock==3.20.0
59
+ Flask==3.1.2
60
+ flask-cors==6.0.1
61
+ fonttools==4.60.1
62
+ fqdn==1.5.1
63
+ frozenlist==1.8.0
64
+ fsspec==2025.3.0
65
+ ghp-import==2.1.0
66
+ gitdb==4.0.12
67
+ GitPython==3.1.45
68
+ google-auth==2.41.1
69
+ gql==4.0.0
70
+ graphene==3.4.3
71
+ graphql-core==3.2.6
72
+ graphql-relay==3.2.0
73
+ great-expectations==1.9.0
74
+ greenlet==3.2.4
75
+ gunicorn==23.0.0
76
+ h11==0.16.0
77
+ hf-xet==1.2.0
78
+ httpcore==1.0.9
79
+ httpx==0.28.1
80
+ huggingface-hub==0.36.0
81
+ idna==3.11
82
+ importlib_metadata==8.7.0
83
+ iniconfig==2.3.0
84
+ ipykernel==7.1.0
85
+ ipython==9.6.0
86
+ ipython_pygments_lexers==1.1.1
87
+ isoduration==20.11.0
88
+ itsdangerous==2.2.0
89
+ jedi==0.19.2
90
+ Jinja2==3.1.6
91
+ jmespath==1.0.1
92
+ joblib==1.5.2
93
+ json5==0.12.1
94
+ jsonpointer==3.0.0
95
+ jsonschema==4.25.1
96
+ jsonschema-specifications==2025.9.1
97
+ jupyter-events==0.12.0
98
+ jupyter-lsp==2.3.0
99
+ jupyter_client==8.6.3
100
+ jupyter_core==5.9.1
101
+ jupyter_server==2.17.0
102
+ jupyter_server_terminals==0.5.3
103
+ jupyterlab==4.4.10
104
+ jupyterlab_pygments==0.3.0
105
+ jupyterlab_server==2.28.0
106
+ kiwisolver==1.4.9
107
+ lark==1.3.0
108
+ lit==18.1.8
109
+ lxml==6.0.2
110
+ Mako==1.3.10
111
+ Markdown==3.9
112
+ markdown-it-py==4.0.0
113
+ MarkupSafe==3.0.3
114
+ marshmallow==3.26.1
115
+ matplotlib==3.10.7
116
+ matplotlib-inline==0.2.1
117
+ mdurl==0.1.2
118
+ mergedeep==1.3.4
119
+ mistune==3.1.4
120
+ mkdocs==1.6.1
121
+ mkdocs-get-deps==0.2.0
122
+ mlflow==2.22.2
123
+ mlflow-skinny==2.22.2
124
+ mlflow-tracing==3.5.1
125
+ mpmath==1.3.0
126
+ multidict==6.7.0
127
+ multiprocess==0.70.16
128
+ mypy_extensions==1.1.0
129
+ nbclient==0.10.2
130
+ nbconvert==7.16.6
131
+ nbformat==5.10.4
132
+ nest-asyncio==1.6.0
133
+ networkx==3.5
134
+ nltk==3.9.2
135
+ notebook==7.4.7
136
+ notebook_shim==0.2.4
137
+ numpy==2.3.4
138
+ opentelemetry-api==1.38.0
139
+ opentelemetry-proto==1.38.0
140
+ opentelemetry-sdk==1.38.0
141
+ opentelemetry-semantic-conventions==0.59b0
142
+ overrides==7.7.0
143
+ packaging==24.2
144
+ pandas==2.3.3
145
+ pandocfilters==1.5.1
146
+ parso==0.8.5
147
+ pathspec==0.12.1
148
+ pathvalidate==3.3.1
149
+ pexpect==4.9.0
150
+ pillow==12.0.0
151
+ platformdirs==4.5.0
152
+ pluggy==1.6.0
153
+ pre-commit==4.4.0
154
+ prometheus_client==0.23.1
155
+ prompt_toolkit==3.0.52
156
+ propcache==0.4.1
157
+ protobuf==6.33.0
158
+ psutil==7.1.2
159
+ ptyprocess==0.7.0
160
+ pure_eval==0.2.3
161
+ pyarrow==19.0.1
162
+ pyasn1==0.6.1
163
+ pyasn1_modules==0.4.2
164
+ pycparser==2.23
165
+ pydantic==2.12.3
166
+ pydantic_core==2.41.4
167
+ Pygments==2.19.2
168
+ pyparsing==3.2.5
169
+ pytest==8.4.2
170
+ python-dateutil==2.9.0.post0
171
+ python-dotenv==1.2.1
172
+ python-json-logger==4.0.0
173
+ pytz==2025.2
174
+ PyYAML==6.0.3
175
+ pyyaml_env_tag==1.1
176
+ pyzmq==27.1.0
177
+ referencing==0.37.0
178
+ regex==2025.10.23
179
+ requests==2.32.5
180
+ requests-toolbelt==1.0.0
181
+ rfc3339-validator==0.1.4
182
+ rfc3986-validator==0.1.1
183
+ rfc3987-syntax==1.1.0
184
+ rich==14.2.0
185
+ rpds-py==0.28.0
186
+ rsa==4.9.1
187
+ ruff==0.14.2
188
+ s3transfer==0.14.0
189
+ safetensors==0.6.2
190
+ scikit-learn==1.7.2
191
+ scipy==1.16.2
192
+ semver==3.0.4
193
+ Send2Trash==1.8.3
194
+ sentence-transformers==5.1.2
195
+ setfit==1.1.2
196
+ six==1.17.0
197
+ smmap==5.0.2
198
+ sniffio==1.3.1
199
+ soupsieve==2.8
200
+ SQLAlchemy==2.0.44
201
+ sqlparse==0.5.3
202
+ stack-data==0.6.3
203
+ starlette==0.48.0
204
+ sympy==1.14.0
205
+ tenacity==9.1.2
206
+ terminado==0.18.1
207
+ threadpoolctl==3.6.0
208
+ tinycss2==1.4.0
209
+ tokenizers==0.22.1
210
+ torch==2.7.1
211
+ torchaudio==2.7.1
212
+ torchvision==0.22.1
213
+ tornado==6.5.2
214
+ tqdm==4.67.1
215
+ traitlets==5.14.3
216
+ transformers==4.57.1
217
+ treelib==1.8.0
218
+ triton==3.3.1
219
+ typing-inspect==0.9.0
220
+ typing-inspection==0.4.2
221
+ typing_extensions==4.15.0
222
+ tzdata==2025.2
223
+ uri-template==1.3.0
224
+ urllib3==2.5.0
225
+ uvicorn==0.38.0
226
+ watchdog==6.0.0
227
+ wcwidth==0.2.14
228
+ webcolors==24.11.1
229
+ webencodings==0.5.1
230
+ websocket-client==1.9.0
231
+ Werkzeug==3.1.3
232
+ xxhash==3.6.0
233
+ yarl==1.22.0
234
+ zipp==3.23.0