|
|
import gradio as gr |
|
|
from transformers import AutoTokenizer, DebertaV2Config, DebertaV2Model, PreTrainedModel |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import re |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clean_teks(text):
    """Normalize raw text: lowercase it, strip URLs and punctuation,
    and collapse runs of whitespace.

    Note: the (misspelled) name is kept as-is because callers in this
    module reference it.
    """
    text = text.lower()
    # Drop URLs. A single http/www alternation suffices: "http\S+" already
    # matches "https..." too (the original's extra "https\S+" branch was dead).
    text = re.sub(r"http\S+|www\S+", "", text)
    # Text is already lowercased, so "A-Z" in the keep-set was dead as well;
    # keep only lowercase letters, digits, and whitespace.
    text = re.sub(r"[^a-z0-9\s]", "", text)
    # Collapse internal whitespace runs and trim the ends.
    text = re.sub(r"\s+", " ", text).strip()
    return text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DebertaV3ForMultiTask(PreTrainedModel):
    """DeBERTa-v2 encoder with two parallel linear classification heads.

    One head scores sentiment, the other scores the news type/category;
    both read the encoder's first-token ([CLS]) hidden state.
    """

    config_class = DebertaV2Config

    def __init__(self, config):
        super().__init__(config)
        # Head sizes come from custom fields stored on the config object.
        self.num_sentiment_labels = config.num_sentiment_labels
        self.num_type_labels = config.num_type_labels

        self.deberta = DebertaV2Model(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.sentiment_classifier = nn.Linear(config.hidden_size, self.num_sentiment_labels)
        self.type_classifier = nn.Linear(config.hidden_size, self.num_type_labels)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Encode the batch and return raw logits for both heads.

        Extra keyword arguments (e.g. labels passed by a Trainer) are
        accepted but ignored.
        """
        encoder_out = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
        # The first token's hidden state serves as the pooled representation.
        cls_state = self.dropout(encoder_out.last_hidden_state[:, 0])
        return {
            "sentiment": self.sentiment_classifier(cls_state),
            "type": self.type_classifier(cls_state),
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Directory holding the fine-tuned multi-task checkpoint (tokenizer + weights).
MODEL_PATH = "./finetuned_model_deberta_multitask"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = DebertaV3ForMultiTask.from_pretrained(MODEL_PATH)
# Inference only: disable dropout and other train-time behavior.
model.eval()

# Index order must match the label encoding used during fine-tuning —
# presumably it does; TODO confirm against the training script.
SENTIMENT_LABELS = ['negative', 'neutral', 'positive']
CATEGORY_LABELS = ['Business', 'Entertainment', 'General', 'Health', 'Science', 'Sports', 'Technology']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict(text):
    """Classify *text*, returning a 4-tuple:

    (sentiment probability dict, category probability dict,
     formatted best sentiment, formatted best category).
    """
    # Guard: skip model work entirely for empty / whitespace-only input.
    if not text or text.isspace():
        return {}, {}, "No input provided", "No input provided"

    tokens = tokenizer(
        clean_teks(text),
        return_tensors="pt",
        truncation=True,
        max_length=256,
        padding=True,
    )

    # No gradients needed at inference time.
    with torch.no_grad():
        logits = model(**tokens)

    # Single-example batch: take row 0 of each softmax.
    probs_sentiment = F.softmax(logits["sentiment"], dim=1)[0]
    probs_type = F.softmax(logits["type"], dim=1)[0]

    sentiment_scores = {
        name: round(p.item(), 4)
        for name, p in zip(SENTIMENT_LABELS, probs_sentiment)
    }
    category_scores = {
        name: round(p.item(), 4)
        for name, p in zip(CATEGORY_LABELS, probs_type)
    }

    top_sentiment = SENTIMENT_LABELS[torch.argmax(probs_sentiment)]
    top_category = CATEGORY_LABELS[torch.argmax(probs_type)]

    return (
        sentiment_scores,
        category_scores,
        f"{top_sentiment} ({sentiment_scores[top_sentiment]:.2%})",
        f"{top_category} ({category_scores[top_category]:.2%})",
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 📰 News Sentiment and Category Classification")

    # Free-text input plus a single trigger button.
    text_input = gr.Textbox(placeholder="Enter news text here...", label="Input Text", lines=5)
    submit_button = gr.Button("Analyze", variant="primary")

    # Two side-by-side result columns: sentiment (left), news category (right).
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 🔎 Predicted Sentiment")
            sentiment_label = gr.Text(label="Predicted Sentiment")
            sentiment_output = gr.Label(label="Sentiment Probabilities", num_top_classes=3)
        with gr.Column():
            gr.Markdown("### 🗂️ Predicted News Category")
            category_label = gr.Text(label="Predicted Category")
            category_output = gr.Label(label="Category Probabilities", num_top_classes=len(CATEGORY_LABELS))

    # Output order must match predict()'s 4-tuple return order.
    submit_button.click(fn=predict, inputs=text_input, outputs=[sentiment_output, category_output, sentiment_label, category_label])

    # Clickable sample inputs; selecting one only fills the textbox.
    gr.Examples(
        [
            ["Stanley Kubrick's estate has led the tributes to Shelley Duvall."],
            ["Lignetics Inc. recently acquired the fiber energy products wood pellets business unit from Revelyst."],
            ["An overcrowded California men’s prison was running on emergency generator power for a third day Tuesday."]
        ],
        inputs=text_input
    )

if __name__ == "__main__":
    demo.launch()
|
|
|