|
|
import gradio as gr |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import pickle |
|
|
import pandas as pd |
|
|
from transformers import RobertaTokenizerFast, RobertaModel |
|
|
|
|
|
|
|
|
# Load the lookup tables used to turn predicted class indices into names.
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only deploy a label_mappings.pkl that comes from a trusted source.
with open("label_mappings.pkl", "rb") as f:
    label_mappings = pickle.load(f)

# A missing key falls back to an empty mapping instead of raising here.
label_to_team = label_mappings.get("label_to_team", {})
label_to_email = label_mappings.get("label_to_email", {})

# Tokenizer for the "roberta-base" checkpoint, the same name the classifier
# uses for its encoder.
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
|
|
|
|
|
|
|
|
class RoBertaClassifier(nn.Module):
    """RoBERTa encoder with two linear heads: one for teams, one for emails.

    Both heads read the hidden state at sequence position 0 of the encoder's
    last layer and each produces one logit per class.
    """

    def __init__(self, num_teams, num_emails):
        """Build the encoder and the two classification heads.

        Args:
            num_teams: Number of output classes for the team head.
            num_emails: Number of output classes for the email head.
        """
        super().__init__()
        self.roberta = RobertaModel.from_pretrained("roberta-base")

        # Both heads consume the same encoder width; read it once.
        hidden_size = self.roberta.config.hidden_size
        self.team_classifier = nn.Linear(hidden_size, num_teams)
        self.email_classifier = nn.Linear(hidden_size, num_emails)

    def forward(self, input_ids, attention_mask):
        """Return ``(team_logits, email_logits)`` for a batch of token ids.

        Args:
            input_ids: Token id tensor passed straight to the encoder.
            attention_mask: Attention mask tensor passed straight to the encoder.

        Returns:
            Tuple of two tensors, the outputs of the team head and the email
            head applied to the position-0 hidden state.
        """
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)

        # Position 0 of the last hidden layer is used as the sequence summary.
        cls_output = outputs.last_hidden_state[:, 0, :]

        team_logits = self.team_classifier(cls_output)
        email_logits = self.email_classifier(cls_output)
        return team_logits, email_logits
|
|
|
|
|
|
|
|
# Size each classification head from its label table.
num_teams = len(label_to_team)
num_emails = len(label_to_email)
model = RoBertaClassifier(num_teams, num_emails)

# NOTE(security): torch.load unpickles the file, so only load a checkpoint
# from a trusted source. The file is treated as a flat name -> tensor mapping.
checkpoint = torch.load("ticket_classification_model.pth", map_location=torch.device("cpu"))

# Keep only the entries the model actually owns. The state dict is built once
# here instead of once per checkpoint key.
model_state = model.state_dict()
filtered_checkpoint = {k: v for k, v in checkpoint.items() if k in model_state}
load_result = model.load_state_dict(filtered_checkpoint, strict=False)

# strict=False hides a checkpoint whose key names do not line up with the
# model; surface that instead of silently serving untrained weights.
if load_result.missing_keys:
    print(
        f"Warning: {len(load_result.missing_keys)} model parameter(s) were not "
        "found in ticket_classification_model.pth and keep their initial values."
    )

# Use the GPU when one is available and switch to inference mode.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
model.eval()
|
|
|
|
|
|
|
|
def predict_tickets(ticket_descriptions):
    """Classify each ticket description and export the results to CSV.

    Args:
        ticket_descriptions: Iterable of ticket description strings.

    Returns:
        A ``(markdown, csv_path)`` tuple: the per-ticket results joined into
        one Markdown string, and the path of the CSV file written with the
        columns Index, Description, Assigned Team and Team Email.
    """
    predictions = []
    csv_data = []

    # Inference only: one no_grad scope covers the whole loop.
    with torch.no_grad():
        for idx, description in enumerate(ticket_descriptions, start=1):
            inputs = tokenizer(
                description,
                return_tensors="pt",
                truncation=True,
                padding="max_length",
                max_length=128,
            ).to(device)
            team_logits, email_logits = model(inputs.input_ids, inputs.attention_mask)

            # One description per forward pass, so each argmax is a single index.
            predicted_team_index = team_logits.argmax(dim=-1).item()
            predicted_email_index = email_logits.argmax(dim=-1).item()

            # Indices missing from the label tables get a visible placeholder.
            predicted_team = label_to_team.get(predicted_team_index, "Unknown Team")
            predicted_email = label_to_email.get(predicted_email_index, "Unknown Email")

            predictions.append(
                f"**{idx}. {description}**\n - **Assigned Team:** {predicted_team}\n - **Team Email:** {predicted_email}\n"
            )
            csv_data.append([idx, description, predicted_team, predicted_email])

    df = pd.DataFrame(csv_data, columns=["Index", "Description", "Assigned Team", "Team Email"])
    csv_file = "ticket-predictions.csv"
    df.to_csv(csv_file, index=False)
    return "\n".join(predictions), csv_file
|
|
|
|
|
|
|
|
def gradio_predict(option, text_input, file_input):
    """Collect ticket descriptions from the selected input and classify them.

    Args:
        option: Selected input mode, "Enter Text" or "Upload CSV".
        text_input: Textbox content with one ticket description per line.
        file_input: Path (or file-like object) of the uploaded CSV, which must
            have a "Description" column; ``None`` when nothing was uploaded.

    Returns:
        A ``(markdown, csv_path)`` tuple from ``predict_tickets``. When the
        input is missing, empty or unreadable, the first item is a warning
        message and the second is ``None``.
    """
    no_input = ("⚠️ Please provide input.", None)

    if option == "Enter Text":
        # Tolerate a None textbox value and drop blank lines.
        raw_lines = (text_input or "").split("\n")
        descriptions = [line.strip() for line in raw_lines if line.strip()]
    elif option == "Upload CSV" and file_input is not None:
        try:
            df = pd.read_csv(file_input)
        except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
            return f"⚠️ Error: could not read CSV ({exc}).", None

        if "Description" not in df.columns:
            return "⚠️ Error: CSV must contain a 'Description' column.", None

        # Empty cells load as NaN and numeric cells as numbers. Each description
        # is handed to the tokenizer downstream, so keep only non-blank text.
        cells = df["Description"].dropna().astype(str)
        descriptions = [cell.strip() for cell in cells if cell.strip()]
    else:
        return no_input

    # Nothing left to classify: report it instead of running the model on zero
    # tickets and offering an empty CSV.
    if not descriptions:
        return no_input

    results, csv_file = predict_tickets(descriptions)
    return results, csv_file
|
|
|
|
|
def clear_inputs():
    """Return the reset values for the five components wired to CLEAR.

    Returns:
        A 5-tuple, in the order of the button's outputs: input mode
        ("Enter Text"), empty textbox, no uploaded file, empty results text,
        no download file.
    """
    return "Enter Text", "", None, "", None
|
|
|
|
|
|
|
|
|
|
|
# Stylesheet for the Gradio page. `#title` and `#results-box` match elem_id
# values assigned to components in the layout further down this file.
custom_css = """
.gradio-container { max-width: 1100px !important; margin: auto !important; }
#title { text-align: center; font-size: 24px !important; font-weight: bold; }
#predict-button, #clear-button { width: 100% !important; height: 50px !important; font-size: 18px !important; }
#results-box { height: 320px !important; overflow-y: auto !important; background: #f9f9f9; padding: 10px; border-radius: 8px; }
"""
|
|
|
|
|
|
|
|
# Assemble the Gradio page: header, input column, results column, action
# buttons and footer. The stylesheet is the `custom_css` string defined earlier
# in this file; its selectors are matched by the elem_id values set below.
with gr.Blocks(css=custom_css) as app:
    gr.Markdown(
        """
        # Multi-Ticket AI Classification System

        **Supports:** Multi-line text input & CSV upload.
        **Output:** Text results & downloadable CSV file.
        **Model:** Fine-tuned **RoBERTa** for classification.

        Enter ticket Description/Comment/Summary or upload a **CSV file** to predict Assigned Teams & Team Emails.

        """,
        elem_id="title",
    )

    with gr.Row():
        # Both columns share the row equally; integer scale values keep the
        # same 1:1 ratio the original equal values expressed.
        with gr.Column(scale=1):
            # NOTE(review): the leading emoji of this label and of the upload
            # label were mis-encoded beyond recovery in the source; the glyphs
            # used here are a best guess, confirm with the author.
            option = gr.Radio(
                ["Enter Text", "Upload CSV"],
                label="📌 Choose Input Method",
                value="Enter Text",
            )

            text_input = gr.Textbox(
                label="Enter Ticket Description/Comment/Summary (One per line)",
                lines=6,
                placeholder="Example:\n - Database performance issue\n - Login fails for admin users...",
            )

            # Hidden until the "Upload CSV" mode is selected (see toggle_input).
            file_input = gr.File(label="📂 Upload CSV (Optional)", type="filepath", visible=False)

        with gr.Column(scale=1):
            gr.Markdown("## Prediction Results")
            results_output = gr.Markdown(elem_id="results-box", visible=True)
            download_csv = gr.File(label="📥 Download Predictions CSV", interactive=False)

    with gr.Row():
        # elem_id values match the #predict-button / #clear-button CSS rules.
        predict_btn = gr.Button("PREDICT", variant="primary", elem_id="predict-button")
        clear_btn = gr.Button("CLEAR", variant="secondary", elem_id="clear-button")

    def toggle_input(selected_option):
        """Show the textbox or the file uploader, whichever matches the mode.

        Returns two visibility updates, for the textbox and the file input
        respectively.
        """
        return (
            gr.update(visible=(selected_option == "Enter Text")),
            gr.update(visible=(selected_option == "Upload CSV")),
        )

    option.change(fn=toggle_input, inputs=[option], outputs=[text_input, file_input])

    predict_btn.click(
        fn=gradio_predict,
        inputs=[option, text_input, file_input],
        outputs=[results_output, download_csv],
    )
    clear_btn.click(
        fn=clear_inputs,
        inputs=[],
        outputs=[option, text_input, file_input, results_output, download_csv],
    )

    gr.Markdown("---")
    gr.Markdown(
        "<p style='text-align: center;color: gray;'>"
        "Developed by NYP student @ Min Thein Win: Student ID: 3907578Y</p>"
    )

# Start the server; share=True also requests a public share link.
app.launch(share=True)
|
|
|