Spaces:
Sleeping
Sleeping
Update content with the text model from Thomas's repository: https://huggingface.co/spaces/tombou/frugal-ai-challenge
42b7ac6
| from fastapi import APIRouter | |
| from datetime import datetime | |
| from datasets import load_dataset | |
| from sklearn.metrics import accuracy_score | |
| from .data.data_loaders import TextDataLoader | |
| from .models.text_classifiers import BaselineModel | |
| from .utils.evaluation import TextEvaluationRequest | |
| from .utils.emissions import get_tracker, clean_emissions_data, get_space_info, EmissionsData | |
# --- Model instances -------------------------------------------------------
# Both models are built eagerly, at import time, through ModelFactory.
# NOTE(review): embedding_ml_model is not referenced anywhere else in the
# visible code; confirm whether it is still needed before paying its load cost.
from .models.text_classifiers import ModelFactory

embedding_ml_model = ModelFactory.create_model(dict(model_type="embeddingML"))
distilbert_model = ModelFactory.create_model(
    dict(
        model_type="distilbert-pretrained",
        model_name="2025-01-27_17-00-47_DistilBERT_Model_fined-tuned_from_distilbert-base-uncased",
    )
)

# Model whose description is advertised for this route.
model_to_evaluate = distilbert_model

# --- Routing ---------------------------------------------------------------
router = APIRouter()
DESCRIPTION = model_to_evaluate.description
ROUTE = "/text"
async def evaluate_text(request: TextEvaluationRequest,
                        track_emissions: bool = True,
                        model = distilbert_model,
                        light_dataset: bool = False) -> dict:
    """
    Evaluate text classification for climate disinformation detection.

    Loads the test split, runs ``model.predict`` on every quote, optionally
    measures the energy/emissions of the inference phase only, and returns
    accuracy together with the emissions figures and the dataset config.

    Parameters:
    -----------
    request: TextEvaluationRequest
        The request object containing the dataset configuration; its
        ``dataset_name``, ``test_size`` and ``test_seed`` are echoed back.
    track_emissions: bool
        Whether to track emissions. When False, ``EmissionsData(0, 0)`` is
        reported instead of a measurement.
    model: TextClassifier
        The model to use for inference; must provide ``predict(quote)``.
        Its ``description`` attribute, if present, is reported as
        ``model_description`` (falls back to the module ``DESCRIPTION``).
    light_dataset: bool
        Forwarded as ``light=`` to ``TextDataLoader``.

    Returns:
    --------
    dict
        A dictionary containing the evaluation results: submitter info,
        timestamp, model description, accuracy, energy (Wh), emissions
        (gCO2eq), cleaned emissions data, API route and dataset config.

    Notes:
    ------
    NOTE(review): inference runs synchronously inside this coroutine (there is
    no ``await``), so it blocks the event loop for its whole duration.
    NOTE(review): no ``@router.post`` decorator is visible on this function;
    confirm the route is registered on ``router`` elsewhere.
    """
    # Get space info
    username, space_url = get_space_info()

    # Load the dataset before tracking starts so loading is not counted as
    # inference emissions.
    test_dataset = TextDataLoader(request, light=light_dataset).get_test_dataset()
    quotes = test_dataset["quote"]

    if track_emissions:
        tracker = get_tracker()
        tracker.start()
        tracker.start_task("inference")
        try:
            # model inference
            predictions = [model.predict(quote) for quote in quotes]
        finally:
            # Always close the "inference" task, even if predict() raises,
            # so the tracker (possibly shared between calls -- assumption,
            # see get_tracker) is not left with an open task.
            emissions_data = tracker.stop_task()
    else:
        # model inference (untracked)
        predictions = [model.predict(quote) for quote in quotes]
        emissions_data = EmissionsData(0, 0)

    # Calculate accuracy
    true_labels = test_dataset["label"]
    accuracy = accuracy_score(true_labels, predictions)

    # Report the description of the model actually evaluated, not the
    # module-level default, so a non-default `model` is labelled correctly.
    model_description = getattr(model, "description", DESCRIPTION)

    # Prepare results dictionary (energy/emissions are scaled by 1000)
    results = {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": model_description,
        "accuracy": float(accuracy),
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed
        }
    }
    return results