| from transformers import pipeline | |
| import gradio as gr | |
# Hugging Face Hub checkpoint to load. Judging by its name it is a DistilBERT
# model fine-tuned on AG News — presumably; confirm on the model card.
model_checkpoint = "MuntasirHossain/distilbert-finetuned-ag-news"

# Build the text-classification pipeline once, at module import, so every
# call to predict() reuses the already-loaded model.
model = pipeline(task="text-classification", model=model_checkpoint)
def predict(prompt):
    """Classify a news article and return its predicted category label.

    Args:
        prompt: Article text submitted from the Gradio textbox.

    Returns:
        The ``"label"`` of the first (top-scoring) prediction produced by the
        module-level ``model`` pipeline, or an empty string when ``prompt`` is
        empty, ``None``, or whitespace-only.
    """
    # Blank input has nothing to classify; don't report an arbitrary label.
    if not prompt or not prompt.strip():
        return ""
    # truncation=True clips articles longer than the model's maximum sequence
    # length (typically 512 tokens for DistilBERT checkpoints) instead of
    # failing at inference time on long inputs.
    completion = model(prompt, truncation=True)[0]["label"]
    return completion
description = "This AI model is trained to classify news articles into four categories: World, Sports, Business and Science/Tech."
title = "Classify Your Articles"
# NOTE(review): "peach" is a legacy Gradio 2.x theme name; newer Gradio
# releases may not recognize it and fall back to the default theme — confirm
# against the pinned gradio version.
theme = "peach"
# Sample inputs shown under the textbox: one inner list per example, with one
# entry per input component (here a single textbox).
examples = [
    [
        "Global Retail Giants Gear Up for Record-Breaking Holiday Sales Season Amidst Supply Chain Challenges and Rising Consumer Demand."
    ]
]

# Keep a module-level handle to the app so it can be reused (for example by
# Gradio's reload mode, which looks for `demo`) without starting a server
# merely because the module was imported.
demo = gr.Interface(
    fn=predict,
    inputs="textbox",
    outputs="text",
    title=title,
    theme=theme,
    description=description,
    examples=examples,
)

if __name__ == "__main__":
    demo.launch()