|
|
"""Pydantic schemas for API request and response validation. |
|
|
|
|
|
This module defines the data models used for validating incoming API requests |
|
|
and structuring outgoing responses. Using Pydantic ensures automatic validation, |
|
|
serialization, and OpenAPI schema generation. |
|
|
|
|
|
Classes: |
|
|
PredictionRequest: Schema for the /predict endpoint request body |
|
|
ModelInfo: Schema for model information responses |
|
|
PredictionResult: Schema for individual prediction results |
|
|
PredictionResponse: Schema for the complete prediction response |
|
|
""" |
|
|
|
|
|
from typing import List, Literal

from pydantic import BaseModel, ConfigDict, Field, model_validator
|
|
|
|
|
|
|
|
class PredictionRequest(BaseModel):
    """Request model for the /predict endpoint.

    This schema validates the input data for code comment classification requests.
    All fields are validated automatically by Pydantic, with clear error messages
    for invalid inputs.

    Attributes:
        texts: List of code comments to classify. Must contain at least one item.
        class_names: List of class names, one per comment: ``class_names[i]`` is
            paired with ``texts[i]``. Must contain at least one item and have the
            same length as ``texts``.
        language: The programming language context for classification.
            Currently supports "java", "python", and "pharo".
        model_type: The ML model to use for classification.
            Currently only "catboost" is supported (it is also the default).

    Raises:
        pydantic.ValidationError: On construction, if a field is missing or
            invalid, an unknown field is supplied (``extra="forbid"``), or
            ``texts`` and ``class_names`` differ in length.

    Example:
        ```python
        request = PredictionRequest(
            texts=["This method calculates fibonacci", "TODO: Fix bug"],
            class_names=["MathUtils", "Calculator"],
            language="python",
            model_type="catboost"
        )
        ```
    """

    texts: List[str] = Field(
        ...,
        description="List of code comments to classify",
        min_length=1,
        json_schema_extra={"example": ["This method calculates fibonacci"]},
    )
    class_names: List[str] = Field(
        ...,
        description="List of class names corresponding to each comment",
        min_length=1,
        json_schema_extra={"example": ["MathUtils"]},
    )
    language: Literal["java", "python", "pharo"] = Field(
        ..., description="Programming language context (java, python, or pharo)"
    )
    model_type: Literal["catboost"] = Field(
        default="catboost", description="ML model type to use for classification"
    )

    model_config = ConfigDict(
        extra="forbid",
        # Pydantic v2 (< 2.10) reserves the "model_" prefix by default and emits
        # a UserWarning at class creation for the ``model_type`` field. The field
        # does not shadow any BaseModel attribute, so opt out of that check.
        protected_namespaces=(),
        json_schema_extra={
            "example": {
                "texts": ["This method calculates the fibonacci sequence", "TODO: Fix this bug"],
                "class_names": ["MathUtils", "Calculator"],
                "language": "python",
                "model_type": "catboost",
            }
        },
    )

    @model_validator(mode="after")
    def check_texts_and_class_names_match(self) -> "PredictionRequest":
        """Ensure every comment in ``texts`` has exactly one class name.

        Returns:
            The validated model instance, unchanged.

        Raises:
            ValueError: If ``texts`` and ``class_names`` differ in length.
                Pydantic surfaces this as a ``ValidationError``.
        """
        # Each PredictionResult pairs one text with one class_name, so the two
        # request lists must line up element-wise. Rejecting a mismatch here
        # reports it at the request boundary instead of letting it reach the
        # code that pairs them.
        if len(self.texts) != len(self.class_names):
            raise ValueError(
                "texts and class_names must have the same length "
                f"(got {len(self.texts)} texts and {len(self.class_names)} class_names)"
            )
        return self
|
|
|
|
|
|
|
|
class ModelInfo(BaseModel):
    """Describe which model types can serve a given programming language.

    Instances of this schema are intended for API responses that advertise
    model availability.

    Attributes:
        language: Identifier of the programming language.
        available_types: Names of the model types usable for ``language``.
    """

    # ``Field`` without a default is required in Pydantic v2, exactly like
    # ``Field(..., ...)``.
    language: str = Field(
        description="Programming language identifier",
    )
    available_types: List[str] = Field(
        description="List of available model types (e.g., ['catboost'])",
    )
|
|
|
|
|
|
|
|
class PredictionResult(BaseModel):
    """Classification outcome for one code comment.

    Attributes:
        text: The input comment that was classified.
        class_name: The class name associated with ``text`` in the request.
        labels: Category labels predicted for ``text``.
    """

    # All three fields are required: ``Field`` is given no default.
    text: str = Field(description="The original input comment")
    class_name: str = Field(
        description="The class name corresponding to the input comment",
    )
    labels: List[str] = Field(
        description="List of predicted category labels",
    )
|
|
|
|
|
|
|
|
class PredictionResponse(BaseModel):
    """Schema for the complete prediction response data.

    Contains the full response payload from a prediction request.

    Attributes:
        model_used: The model type that was used for prediction.
        language: The language context used for classification.
        results: One prediction result per input text.
    """

    # ``model_used`` starts with the "model_" prefix that Pydantic v2 (< 2.10)
    # reserves by default, which emits a UserWarning at class creation. The field
    # does not shadow any BaseModel attribute, so disable the namespace check.
    model_config = ConfigDict(protected_namespaces=())

    model_used: str = Field(..., description="Model type used for prediction")
    language: str = Field(..., description="Language context for classification")
    results: List[PredictionResult] = Field(..., description="List of classification results")
|
|
|