File size: 1,025 Bytes
eb6d478
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# main.py

import gradio as gr
from utils import load_model, predict_category
from news_dataset import NewsDataset  # Import NewsDataset from news_dataset.py

def launch_app(
    *,
    csv_file: str = "./inshort_news_data.csv",
    model_path: str = "./models/trained_model1.pth",
    max_length: int = 100,
    share: bool = True,
) -> None:
    """Build and launch the Gradio UI for news category classification.

    The label mapping is read from a ``NewsDataset`` built on ``csv_file``;
    its size is passed to ``load_model`` as the number of classes, and the
    mapping itself is forwarded to ``predict_category`` on every request.

    All parameters are keyword-only and default to the values that were
    previously hard-coded, so ``launch_app()`` behaves as before.

    Args:
        csv_file: Path to the CSV file handed to ``NewsDataset``.
        model_path: Path to the trained model file handed to ``load_model``.
        max_length: ``max_length`` value handed to ``NewsDataset``.
        share: Forwarded to ``Interface.launch(share=...)``. Per the Gradio
            documentation, ``True`` requests a publicly reachable share
            link; pass ``False`` to serve locally only.
    """
    dataset = NewsDataset(csv_file=csv_file, max_length=max_length)
    labels_dict = dataset.labels_dict

    # The class count is derived from the dataset's label mapping so that
    # load_model receives a value consistent with labels_dict.
    model = load_model(model_path, len(labels_dict))

    def predict_function(headline, article):
        # Closure binds the loaded model and label mapping so Gradio only
        # needs to supply the two text inputs.
        return predict_category(headline, article, model, labels_dict)

    iface = gr.Interface(
        fn=predict_function,
        inputs=["text", "text"],
        outputs="text",
        title="News Category Classification",
        description="Enter a headline and an article to classify its category.",
    )

    iface.launch(share=share)


# Start the app only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    launch_app()