File size: 6,277 Bytes
b01a119 5f47787 b01a119 5f47787 b01a119 5f47787 5442884 5f47787 5442884 5f47787 5442884 5f47787 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import os
import pickle
import tempfile
import warnings

import gradio as gr
import pandas as pd
import torch
import torch.nn as nn
from transformers import RobertaTokenizerFast, RobertaModel
# Load label mappings
# The pickle is expected to hold a dict; each lookup table falls back to an
# empty dict when its key is absent.
with open("label_mappings.pkl", "rb") as mappings_file:
    label_mappings = pickle.load(mappings_file)
label_to_team, label_to_email = (
    label_mappings.get(table_name, {})
    for table_name in ("label_to_team", "label_to_email")
)

# Load the tokenizer
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
# Define RoBERTa Model
class RoBertaClassifier(nn.Module):
    """RoBERTa encoder with two linear heads: one for teams, one for emails."""

    def __init__(self, num_teams, num_emails):
        """Build the encoder and size both heads from its hidden width.

        Args:
            num_teams: Number of output classes for the team head.
            num_emails: Number of output classes for the email head.
        """
        super().__init__()
        self.roberta = RobertaModel.from_pretrained("roberta-base")
        hidden_size = self.roberta.config.hidden_size
        self.team_classifier = nn.Linear(hidden_size, num_teams)
        self.email_classifier = nn.Linear(hidden_size, num_emails)

    def forward(self, input_ids, attention_mask):
        """Return ``(team_logits, email_logits)`` for a batch of token ids.

        Both heads read the same vector: the encoder's last hidden state at
        sequence position 0.
        """
        encoded = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        first_token_state = encoded.last_hidden_state[:, 0, :]
        return (
            self.team_classifier(first_token_state),
            self.email_classifier(first_token_state),
        )
# Load Model
# Head sizes come from the label tables, so the pickle and the checkpoint
# must describe the same label set.
num_teams = len(label_to_team)
num_emails = len(label_to_email)
model = RoBertaClassifier(num_teams, num_emails)

# NOTE: torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load("ticket_classification_model.pth", map_location=torch.device("cpu"))

# Snapshot the model's parameter names once; calling model.state_dict()
# inside the comprehension would rebuild the whole dict for every key.
model_keys = set(model.state_dict())
filtered_checkpoint = {k: v for k, v in checkpoint.items() if k in model_keys}

# strict=False tolerates a partial checkpoint, so surface what was skipped
# instead of silently serving a model with untrained weights.
load_result = model.load_state_dict(filtered_checkpoint, strict=False)
if load_result.missing_keys:
    warnings.warn(
        "Checkpoint is missing weights for: "
        f"{', '.join(load_result.missing_keys)}; these keep their initial values."
    )
dropped_keys = sorted(set(checkpoint) - model_keys)
if dropped_keys:
    warnings.warn(
        f"Ignored checkpoint entries not present in the model: {', '.join(dropped_keys)}"
    )

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
model.eval()
# Prediction Function
def predict_tickets(ticket_descriptions):
    """Classify each ticket description and export the predictions to a CSV.

    Args:
        ticket_descriptions: Iterable of ticket texts. Each one is tokenized
            (truncated/padded to 128 tokens) and scored individually.

    Returns:
        A ``(markdown, csv_path)`` tuple. ``markdown`` is one numbered entry
        per ticket with its assigned team and email; ``csv_path`` is the path
        of a CSV file holding the same rows.
    """
    predictions = []
    csv_data = []

    # Inference only: a single no_grad context covers every forward pass.
    with torch.no_grad():
        for idx, description in enumerate(ticket_descriptions, start=1):
            inputs = tokenizer(
                description,
                return_tensors="pt",
                truncation=True,
                padding="max_length",
                max_length=128,
            ).to(device)
            team_logits, email_logits = model(inputs.input_ids, inputs.attention_mask)

            predicted_team_index = team_logits.argmax(dim=-1).cpu().item()
            predicted_email_index = email_logits.argmax(dim=-1).cpu().item()
            # Indices without an entry in the label tables get a placeholder.
            predicted_team = label_to_team.get(predicted_team_index, "Unknown Team")
            predicted_email = label_to_email.get(predicted_email_index, "Unknown Email")

            predictions.append(
                f"**{idx}. {description}**\n - **Assigned Team:** {predicted_team}\n - **Team Email:** {predicted_email}\n"
            )
            csv_data.append([idx, description, predicted_team, predicted_email])

    df = pd.DataFrame(csv_data, columns=["Index", "Description", "Assigned Team", "Team Email"])

    # Write into a fresh directory per call so simultaneous requests cannot
    # overwrite each other's export; the file name itself stays the same.
    output_dir = tempfile.mkdtemp(prefix="ticket-predictions-")
    csv_file = os.path.join(output_dir, "ticket-predictions.csv")
    df.to_csv(csv_file, index=False)
    return "\n".join(predictions), csv_file
# Gradio Functions
def gradio_predict(option, text_input, file_input):
    """Collect ticket descriptions from the chosen input and classify them.

    Args:
        option: ``"Enter Text"`` or ``"Upload CSV"``.
        text_input: Multi-line text, one ticket per line (may be ``None``).
        file_input: Path of the uploaded CSV, or ``None`` if nothing was
            uploaded. The CSV must have a ``Description`` column.

    Returns:
        ``(markdown, csv_path)`` from ``predict_tickets`` on success, or
        ``(message, None)`` when there is nothing usable to classify.
    """
    if option == "Enter Text":
        # `or ""` guards against a cleared/unset textbox delivering None.
        lines = (text_input or "").split("\n")
        descriptions = [line.strip() for line in lines if line.strip()]
    elif option == "Upload CSV" and file_input is not None:
        try:
            df = pd.read_csv(file_input)
        except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
            return f"⚠️ Error: could not read CSV ({exc}).", None
        if "Description" not in df.columns:
            return "⚠️ Error: CSV must contain a 'Description' column.", None
        # Empty cells load as NaN and numeric cells as numbers; the tokenizer
        # only accepts text, so drop the former and stringify the latter.
        cleaned = (str(value).strip() for value in df["Description"].dropna())
        descriptions = [text for text in cleaned if text]
    else:
        return "⚠️ Please provide input.", None

    # Blank text or a CSV with no usable rows is treated as missing input.
    if not descriptions:
        return "⚠️ Please provide input.", None

    results, csv_file = predict_tickets(descriptions)
    return results, csv_file
def clear_inputs():
    """Return reset values for the five components the CLEAR button updates.

    The order matches the CLEAR handler's outputs: input-method radio, text
    box, uploaded file, results markdown, download file.
    """
    default_option = "Enter Text"
    empty_text = ""
    no_file = None
    return default_option, empty_text, no_file, empty_text, no_file
# Custom CSS to Fix UI Scaling Issues
custom_css = """
.gradio-container { max-width: 1100px !important; margin: auto !important; }
#title { text-align: center; font-size: 24px !important; font-weight: bold; }
#predict-button, #clear-button { width: 100% !important; height: 50px !important; font-size: 18px !important; }
#results-box { height: 320px !important; overflow-y: auto !important; background: #f9f9f9; padding: 10px; border-radius: 8px; }
"""

# Gradio App UI
# Pass the stylesheet defined above; previously an inline string was used, so
# the #title, #results-box and button rules were never applied.
with gr.Blocks(css=custom_css) as app:
    gr.Markdown(
        """
# Multi-Ticket AI Classification System
**Supports:** Multi-line text input & CSV upload.
**Output:** Text results & downloadable CSV file.
**Model:** Fine-tuned **RoBERTa** for classification.
Enter ticket Description/Comment/Summary or upload a **CSV file** to predict Assigned Teams & Team Emails.
""",
        elem_id="title",
    )

    # NOTE(review): the leading emoji in the three component labels below were
    # restored from mis-encoded characters; confirm the exact glyphs.
    with gr.Row():
        # Gradio expects integer `scale` values; 1:1 keeps both columns equal.
        with gr.Column(scale=1):
            option = gr.Radio(
                ["Enter Text", "Upload CSV"],
                label="📌 Choose Input Method",
                value="Enter Text",
            )
            text_input = gr.Textbox(
                label="Enter Ticket Description/Comment/Summary (One per line)",
                lines=6,
                placeholder="Example:\n - Database performance issue\n - Login fails for admin users...",
            )
            # Hidden until the user switches the radio to "Upload CSV".
            file_input = gr.File(label="📂 Upload CSV (Optional)", type="filepath", visible=False)
        with gr.Column(scale=1):
            gr.Markdown("## Prediction Results")
            results_output = gr.Markdown(elem_id="results-box", visible=True)
            download_csv = gr.File(label="📥 Download Predictions CSV", interactive=False)

    with gr.Row():
        # elem_ids match the #predict-button / #clear-button CSS selectors.
        predict_btn = gr.Button("PREDICT", variant="primary", elem_id="predict-button")
        clear_btn = gr.Button("CLEAR", variant="secondary", elem_id="clear-button")

    # Logic for Showing/ Hiding Input Fields
    def toggle_input(selected_option):
        """Show the text box or the file picker, whichever matches the radio."""
        show_text = selected_option == "Enter Text"
        show_file = selected_option == "Upload CSV"
        return gr.update(visible=show_text), gr.update(visible=show_file)

    option.change(fn=toggle_input, inputs=[option], outputs=[text_input, file_input])
    predict_btn.click(
        fn=gradio_predict,
        inputs=[option, text_input, file_input],
        outputs=[results_output, download_csv],
    )
    clear_btn.click(
        fn=clear_inputs,
        inputs=[],
        outputs=[option, text_input, file_input, results_output, download_csv],
    )

    # Footer view
    gr.Markdown("---")
    gr.Markdown(
        "<p style='text-align: center;color: gray;'>"
        "Developed by NYP student @ Min Thein Win: Student ID: 3907578Y</p>"
    )

# Launch App
app.launch(share=True)
|