|
|
import os
from datetime import datetime

from datasets import load_dataset
from fastapi import APIRouter
from sklearn.metrics import accuracy_score

from .data.data_loaders import TextDataLoader
from .models.text_classifiers import BaselineModel
from .utils.evaluation import TextEvaluationRequest
from .utils.emissions import get_tracker, clean_emissions_data, get_space_info, EmissionsData
|
|
|
|
|
|
|
|
from .models.text_classifiers import ModelFactory

# Candidate models for the /text endpoint. They are instantiated at import
# time so the selected model's description is available when the route is
# registered below.
_EMBEDDING_ML_CONFIG = {"model_type": "embeddingML"}
_DISTILBERT_CONFIG = {
    "model_type": "distilbert-pretrained",
    "model_name": "2025-01-27_17-00-47_DistilBERT_Model_fined-tuned_from_distilbert-base-uncased",
}

embedding_ml_model = ModelFactory.create_model(_EMBEDDING_ML_CONFIG)
distilbert_model = ModelFactory.create_model(_DISTILBERT_CONFIG)

# Rebind this alias to switch which model the endpoint evaluates.
model_to_evaluate = distilbert_model
|
|
|
|
|
|
|
|
# URL path under which this task is mounted.
ROUTE = "/text"

# The endpoint's OpenAPI description is taken from whichever model was
# selected as model_to_evaluate above.
DESCRIPTION = model_to_evaluate.description

router = APIRouter()
|
|
|
|
|
@router.post(ROUTE, tags=["Text Task"],
             description=DESCRIPTION)
async def evaluate_text(request: TextEvaluationRequest,
                        track_emissions: bool = True,
                        model = model_to_evaluate,
                        light_dataset: bool = False) -> dict:
    """
    Evaluate text classification for climate disinformation detection.

    Parameters:
    -----------
    request: TextEvaluationRequest
        The request object containing the dataset configuration.

    track_emissions: bool
        Whether to track emissions or not. When False, a zeroed
        EmissionsData placeholder is reported instead.

    model: TextClassifier
        The model to use for inference. Defaults to the module-level
        model_to_evaluate selection.

    light_dataset: bool
        Whether to use a light dataset or not.
        NOTE(review): currently unused — TODO implement subsetting or
        remove the parameter.

    Returns:
    --------
    dict
        A dictionary containing the evaluation results (accuracy,
        emissions figures, and the echoed dataset configuration).
    """
    username, space_url = get_space_info()

    # Load the dataset from the Hugging Face Hub; HF_TOKEN is needed for
    # gated/private datasets and may legitimately be unset otherwise.
    dataset = load_dataset(request.dataset_name, token=os.getenv("HF_TOKEN"))

    # Convert the string labels to integer class ids.
    # NOTE(review): LABEL_MAPPING is neither defined nor imported in this
    # file — confirm it is provided elsewhere, otherwise this raises
    # NameError at request time.
    dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})

    # The dataset ships with a pre-made "test" split; request.test_size and
    # request.test_seed are only echoed back in the results payload below.
    test_dataset = dataset["test"]

    # BUG FIX: the original code unconditionally set track_emissions = True
    # here, which silently overrode the caller's parameter; it now honours
    # the argument.
    if track_emissions:
        tracker = get_tracker()
        tracker.start()
        tracker.start_task("inference")

    # Run inference one quote at a time.
    predictions = [model.predict(quote) for quote in test_dataset["quote"]]

    if track_emissions:
        emissions_data = tracker.stop_task()
    else:
        # Zeroed placeholder keeps the results payload shape uniform.
        emissions_data = EmissionsData(0, 0)

    # Score against the mapped integer labels.
    true_labels = test_dataset["label"]
    accuracy = accuracy_score(true_labels, predictions)

    results = {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "accuracy": float(accuracy),
        # CodeCarbon reports kWh / kg; scale to Wh / g for the leaderboard.
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed
        }
    }

    return results
|
|
|