Spaces:
Runtime error
Runtime error
File size: 1,025 Bytes
eb6d478 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
# main.py
import gradio as gr
from utils import load_model, predict_category
from news_dataset import NewsDataset # Importez NewsDataset depuis news_dataset.py
def launch_app(
    *,
    csv_file: str = "./inshort_news_data.csv",
    model_path: str = "./models/trained_model1.pth",
    max_length: int = 100,
    share: bool = True,
) -> None:
    """Build and launch the Gradio news-category classification interface.

    Within this function the dataset is used only for its ``labels_dict``.
    That mapping supplies the number of classes handed to ``load_model`` and
    is passed to ``predict_category`` on every request.

    All parameters are keyword-only and default to the values that were
    previously hard-coded, so ``launch_app()`` behaves exactly as before.

    Args:
        csv_file: Path to the CSV file used to construct ``NewsDataset``.
        model_path: Path to the trained model checkpoint given to ``load_model``.
        max_length: Value forwarded as ``max_length`` to ``NewsDataset``.
        share: Forwarded to ``Interface.launch``; ``True`` asks Gradio for a
            public share link.
    """
    dataset = NewsDataset(csv_file=csv_file, max_length=max_length)
    labels_dict = dataset.labels_dict

    # load_model needs the number of output classes; it is taken from the
    # dataset's label mapping. NOTE(review): assumes this mapping is the same
    # one the checkpoint was trained with -- confirm.
    model = load_model(model_path, len(labels_dict))

    def predict_function(headline, article):
        """Classify one (headline, article) pair with the loaded model."""
        return predict_category(headline, article, model, labels_dict)

    iface = gr.Interface(
        fn=predict_function,
        inputs=["text", "text"],
        outputs="text",
        title="News Category Classification",
        description="Enter a headline and an article to classify its category.",
    )
    iface.launch(share=share)
if __name__ == "__main__":
    # Start the web UI only when this file is run as a script,
    # so importing the module has no side effects.
    launch_app()
|