LaurianeMD commited on
Commit
eb6d478
·
verified ·
1 Parent(s): 80c952e

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -1,35 +1,2 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ models/trained_model.pth filter=lfs diff=lfs merge=lfs -text
2
+ *.pth filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
LICENSE ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2024, Lauriane MBAGDJE DORENAN
4
+
5
+ Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation
13
+ and/or other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its
16
+ contributors may be used to endorse or promote products derived from
17
+ this software without specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
README.md CHANGED
@@ -1,12 +1,114 @@
1
- ---
2
- title: News Article Classification Bert
3
- emoji: 👀
4
- colorFrom: pink
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 4.37.2
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: News_article_classification_bert
3
+ app_file: main.py
4
+ sdk: gradio
5
+ sdk_version: 4.37.2
6
+ ---
7
+ # news_classification
8
+ News Article Classification: Combining Headlines and Articles to Categorize News
9
+
10
+ # **News Classification Using BERT**
11
+ This project utilizes BERT (Bidirectional Encoder Representations from Transformers) for classifying news articles into predefined categories. The model achieves an accuracy of 96% and a loss of 0.1 on the test dataset.
12
+
13
+ ## **Dataset**
14
+ The dataset used in this project is inshort_news_data.csv, containing short news articles categorized into various topics.
15
+
16
+ ## **Model Architecture**
17
+ The model architecture is based on a custom BERT model fine-tuned for sequence classification:
18
+
19
+ BERT Model: bert-base-uncased
20
+ Batch Size: 8
21
+ Optimizer: Adam with learning rate 2e-5
22
+ Loss Function: CrossEntropyLoss
23
+ Training
24
+ The model is trained for 3 epochs with the following steps:
25
+
26
+ **Data Preparation:** The dataset is tokenized using the BERT tokenizer and prepared as PyTorch DataLoader objects.
27
+
28
+ **Training:** The model is trained using stochastic gradient descent with backpropagation. During training, the loss is minimized and weights are updated iteratively.
29
+
30
+ **Evaluation:** After each epoch, the model is evaluated on a held-out validation set to measure accuracy and loss.
31
+
32
+ **Results**
33
+ Accuracy: 96%
34
+ Loss: 0.1
35
+ Usage
36
+ To use the trained model for inference:
37
+
38
+ Ensure all dependencies are installed (transformers, torch, fastapi, pydantic, etc.).
39
+ Load the model using torch.load() and the appropriate tokenizer.
40
+ Send POST requests to /predict/ endpoint with JSON containing headline and article fields to classify news articles.
41
+ How to Run
42
+ To run the FastAPI application:
43
+ uvicorn api:app --host localhost --port 8080
44
+
45
+ Navigate to http://localhost:8080/docs to interact with the API using Swagger UI.
46
+
47
+ ---------------------------------------------------------------------------------------------------
48
+ ***french***
49
+ # Classification des Catégories de News avec BERT
50
+
51
+ Ce projet vise à classifier automatiquement les catégories de nouvelles à partir des titres et du contenu des articles en utilisant un modèle BERT préalablement entraîné.
52
+
53
+ ## Contenu du Projet
54
+
55
+ - `bert_classification.py` : Contient la définition du modèle `CustomBert` utilisé pour la classification.
56
+ - `news_dataset.py` : Implémente la classe `NewsDataset` pour charger et prétraiter le dataset de nouvelles.
57
+ - `utils.py` : Fournit des fonctions utilitaires pour charger le modèle entraîné et effectuer des prédictions.
58
+ - `main.py` : charge un modèle pré-entraîné pour la classification des catégories de nouvelles, crée une interface utilisateur web avec Gradio
59
+ pour permettre aux utilisateurs de soumettre des titres et des articles, et affiche la catégorie prédite pour ces nouvelles.
60
+ - `api.py` : Implémente une API web à l'aide de FastAPI pour permettre la prédiction des catégories de nouvelles en temps réel.
61
+
62
+ ## Installation des Dépendances
63
+
64
+ Assurez-vous d'avoir Python 3.7+ installé ainsi que les packages nécessaires :
65
+
66
+ pip install -r requirements.txt
67
+
68
+ ## Entraînement du Modèle
69
+ Pour entraîner le modèle, exécutez main.py. Assurez-vous d'avoir un fichier CSV inshort_news_data.csv avec les colonnes news_headline et news_article.
70
+
71
+ python main.py
72
+
73
+
74
+ ## Détails de l'Entraînement
75
+
76
+ Batch Size : 8 (par défaut)
77
+ Epochs : 3 (par défaut)
78
+ Précision : 96%, Perte : 0.1 après l'entraînement.
79
+ Modèle sauvegardé à ./models/trained_model1.pth.
80
+
81
+ ## Utilisation de l'API Web
82
+ Pour utiliser l'API web pour la prédiction des catégories de news :
83
+
84
+ Lancez l'API avec FastAPI en exécutant api.py:
85
+
86
+ uvicorn api:app --host localhost --port 8080
87
+
88
+ Accédez à http://localhost:8080 dans votre navigateur pour vérifier que l'API est en ligne.
89
+ Envoyez des requêtes POST à [http://localhost:8080/predict/](http://localhost:8080/docs#/default/prediction_predict__post) avec les données d'entrée requises pour obtenir des prédictions de catégories de news.
90
+ Exemple de requête JSON pour la prédiction :
91
+
92
+ json
93
+
94
+ {
95
+ "headline": "50-year-old problem of biology solved by Artificial Intelligence",
96
+ "article": "DeepMind's AI system 'AlphaFold' has been recognised as a solution to \"protein folding\", a grand challenge in biology for over 50 years. DeepMind showed it can predict how proteins fold into 3D shapes, a complex process that is fundamental to understanding the biological machinery of life. AlphaFold can predict the shape of proteins within the width of an atom."
97
+ }
98
+ Exemple de réponse attendue :
99
+
100
+ json
101
+
102
+ {
103
+ "category": "Science",
104
+ "score": 94.23
105
+ }
106
+ Assurez-vous d'avoir une connexion Internet active lors de l'exécution de l'API pour permettre le chargement du tokenizer BERT.
107
+
108
+
109
+
110
+
111
+
112
+
113
+
114
+
accuracy_and_loss.PNG ADDED
api.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""FastAPI service exposing the news-category classifier.

POST /predict/ accepts a JSON body with ``headline`` and ``article``
fields and returns the predicted category plus a confidence score.

FIX: removed the 40-line commented-out duplicate of this module that
preceded the live code, and dropped unused imports
(``typing.List/Optional``, ``torch.utils.data.DataLoader``).
"""

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer

from news_dataset import NewsDataset
from utils import load_model, predict_category

# Initialize FastAPI app
app = FastAPI()

# Load dataset once at startup: it provides the label <-> index mapping
# the model was trained with.
dataset = NewsDataset(csv_file="./inshort_news_data.csv", max_length=100)
num_classes = len(dataset.labels_dict)
model_path = './models/trained_model1.pth'  # Path to your trained model
model = load_model(model_path, num_classes)
labels_dict = dataset.labels_dict

# Tokenizer initialization (kept for parity with training; prediction
# itself tokenizes inside utils.predict_category).
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')


# Define Pydantic model for input data
class RequestPost(BaseModel):
    headline: str
    article: str


@app.get("/")
def read_root():
    # Simple liveness check.
    return {"Hello": "World"}


# Define endpoint for prediction
@app.post("/predict/")
def prediction(request: RequestPost):
    """Classify one news item; returns {"category": ..., "score": ...}."""
    try:
        category, score = predict_category(request.headline, request.article, model, labels_dict)
        return {"category": category, "score": score}
    except Exception as e:
        # Surface any failure as a 500 carrying the original message.
        raise HTTPException(status_code=500, detail=str(e))
bert_classification.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset, DataLoader
3
+ import torch.nn as nn
4
+ from transformers import AutoTokenizer, BertModel
5
+ from sklearn.model_selection import train_test_split
6
+ from tqdm import tqdm
7
+ import numpy as np
8
+ import pandas as pd
9
+ from news_dataset import NewsDataset
10
+
11
class CustomBert(nn.Module):
    """BERT encoder with a linear classification head on the pooled output."""

    def __init__(self, model_name_or_path="bert-base-uncased", n_classes=2):
        super(CustomBert, self).__init__()
        # Pre-trained encoder; its pooled [CLS] representation feeds the head.
        self.bert_pretrained = BertModel.from_pretrained(model_name_or_path)
        self.classifier = nn.Linear(self.bert_pretrained.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        """Return unnormalized class logits of shape (batch, n_classes)."""
        encoded = self.bert_pretrained(input_ids=input_ids, attention_mask=attention_mask)
        logits = self.classifier(encoded.pooler_output)
        return logits
21
+
22
# Training function
def training_step(model, data_loader, loss_fn, optimizer):
    """Run one training epoch and return the mean per-batch loss.

    FIX: ``loss_fn`` (CrossEntropyLoss, default reduction) already returns
    the *mean* loss over each batch, so the epoch loss must be averaged
    over the number of batches.  The original divided by
    ``len(data_loader.dataset)`` (the sample count), under-reporting the
    loss by roughly a factor of the batch size and disagreeing with
    ``evaluation``, which averages per-batch losses via ``np.mean``.
    """
    model.train()
    total_loss = 0

    for data in tqdm(data_loader, total=len(data_loader)):
        input_ids = data['input_ids']
        attention_mask = data['attention_mask']
        labels = data['labels']

        output = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = loss_fn(output, labels)

        # Standard step: backprop, update, then clear gradients.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.item()

    return total_loss / len(data_loader)
42
+
43
# Evaluation
def evaluation(model, test_dataloader, loss_fn):
    """Evaluate on ``test_dataloader``; return (accuracy, mean per-batch loss).

    FIX: inference now runs under ``torch.no_grad()``.  The original built
    autograd graphs for every evaluation batch, wasting memory and time;
    no gradients are ever used here.
    """
    model.eval()
    correct_predictions = 0
    losses = []

    with torch.no_grad():
        for data in tqdm(test_dataloader, total=len(test_dataloader)):
            input_ids = data['input_ids']
            attention_mask = data['attention_mask']
            labels = data['labels']

            output = model(input_ids=input_ids, attention_mask=attention_mask)
            # Predicted class = argmax over logits.
            _, pred = output.max(1)
            correct_predictions += torch.sum(pred == labels)

            loss = loss_fn(output, labels)
            losses.append(loss.item())

    return correct_predictions.double() / len(test_dataloader.dataset), np.mean(losses)
62
+
63
+
64
# main
if __name__ == "__main__":
    import os

    # The dataset's category mapping determines the classifier head size.
    dataset = NewsDataset(csv_file="./inshort_news_data.csv", max_length=100)
    num_classes = len(dataset.labels_dict)

    # 80/20 split over the indexable dataset.
    train_data, test_data = train_test_split(dataset, test_size=0.2)
    train_dataloader = DataLoader(train_data, batch_size=8, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=8, shuffle=False)

    # Model and optimisation setup.
    model = CustomBert(n_classes=num_classes)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

    num_epochs = 3
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        train_loss = training_step(model, train_dataloader, loss_fn, optimizer)
        print(f"Train Loss: {train_loss:.4f}")

        val_acc, val_loss = evaluation(model, test_dataloader, loss_fn)
        print(f"Validation Accuracy: {val_acc:.4f}, Validation Loss: {val_loss:.4f}")

    # Save the model
    os.makedirs('./models', exist_ok=True)
    torch.save(model.state_dict(), './models/trained_model1.pth')
92
+
inshort_news_data.csv ADDED
The diff for this file is too large to render. See raw diff
 
main.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# main.py
#
# Gradio front-end for the news-category classifier.

import gradio as gr
from utils import load_model, predict_category
from news_dataset import NewsDataset  # NewsDataset supplies the label mapping


def launch_app():
    """Build and launch the Gradio UI for news classification."""
    dataset = NewsDataset(csv_file="./inshort_news_data.csv", max_length=100)
    num_classes = len(dataset.labels_dict)
    model_path = './models/trained_model1.pth'  # Path to the trained checkpoint
    model = load_model(model_path, num_classes)  # Load with matching class count

    labels_dict = dataset.labels_dict

    def predict_function(headline, article):
        # FIX: predict_category returns a (category, score) tuple, but this
        # interface declares a single "text" output; Gradio interprets a
        # returned tuple as multiple output values and errors out.  Format
        # the pair into one display string instead.
        category, score = predict_category(headline, article, model, labels_dict)
        return f"{category} (score: {score})"

    iface = gr.Interface(
        fn=predict_function,
        inputs=["text", "text"],
        outputs="text",
        title="News Category Classification",
        description="Enter a headline and an article to classify its category."
    )

    # share=True additionally publishes a temporary public URL.
    iface.launch(share=True)


if __name__ == "__main__":
    launch_app()
models/trained_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7949b4c99b6c2a8021bfd95d80a1fcf6567f71b7dd84a0984b80e58d94d75c36
3
+ size 438039157
models/trained_model1.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f06afe32b4012087998ccf1edbb475dc2f84c43600f61d1b4d1f9c5af1b690d
3
+ size 438039361
news_dataset.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # news_dataset.py
2
+
3
+ import torch
4
+ from torch.utils.data import Dataset
5
+ from transformers import AutoTokenizer
6
+
7
class NewsDataset(Dataset):
    """Inshort news dataset: tokenizes headline+article pairs for BERT.

    Each item is a dict with ``input_ids``, ``attention_mask`` (both of
    length ``max_length``) and an integer ``labels`` tensor.
    """

    def __init__(self, csv_file, max_length):
        import pandas as pd  # local import, as in the original module
        self.df = pd.read_csv(csv_file)
        # Map each category name to a stable integer id.
        self.labels = self.df['news_category'].unique()
        self.labels_dict = {label: index for index, label in enumerate(self.labels)}

        self.df['news_category'] = self.df['news_category'].map(self.labels_dict)
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        # Headline and body are joined into a single input sequence.
        combined_text = " ".join(
            [self.df.news_headline[index], self.df.news_article[index]]
        )
        label = self.df.news_category[index]

        encoded = self.tokenizer(
            combined_text,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )

        # squeeze(0) drops the batch dim the tokenizer adds for a single text.
        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": torch.tensor(label),
        }
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.0.0
2
+ transformers==4.30.0
3
+ scikit-learn==1.2.2
4
+ pandas==1.5.3
5
+ tqdm==4.65.0
6
+ numpy==1.23.5
7
+ gradio==3.4.1
8
+ fastapi
9
+ uvicorn[standard]
10
+ pydantic
utils.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer
3
+ import torch.nn as nn
4
+ from bert_classification import CustomBert # Importer le modèle depuis le fichier bert_classification.py
5
+
6
+ tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
7
+
8
def load_model(model_path, num_classes):
    """Load a trained CustomBert checkpoint and put it in eval mode.

    FIX: ``map_location="cpu"`` lets a checkpoint saved on a GPU load on
    CPU-only hosts; the original ``torch.load(model_path)`` raises a CUDA
    deserialization error there.  Nothing in this app moves the model to
    a GPU afterwards, so CPU is the correct target.
    """
    model = CustomBert(n_classes=num_classes)  # must match the trained class count
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()  # disable dropout for deterministic inference
    return model
13
+
14
def predict_category(headline, article, model, labels_dict, max_length=100):
    """Classify one news item; return (category_name, probability rounded to 2 dp)."""
    combined = f"{headline} {article}"
    encoded = tokenizer(
        combined,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        logits = model(input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"])
        probabilities = nn.Softmax(dim=1)(logits)
        # Highest-probability class and its confidence.
        _, pred = torch.max(probabilities, dim=1)
        score = probabilities[0][pred].item()

    # Invert the label mapping to recover the category name from its id.
    inv_labels_dict = {v: k for k, v in labels_dict.items()}
    category = inv_labels_dict[pred.item()]

    return category, round(score, 2)