mintheinwin's picture
update app file and extra other file
5f47787
raw
history blame
5.82 kB
import os
import pickle
import tempfile
import warnings

import gradio as gr
import pandas as pd
import torch
import torch.nn as nn
from transformers import RobertaTokenizerFast, RobertaModel
# Load label mappings.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# ever point this at a trusted, locally shipped file.
with open("label_mappings.pkl", "rb") as f:
    label_mappings = pickle.load(f)
# Both lookups are queried with the predicted class index in predict_tickets;
# a missing top-level key falls back to an empty mapping.
label_to_team, label_to_email = (
    label_mappings.get("label_to_team", {}),
    label_mappings.get("label_to_email", {}),
)

# Load the tokenizer for the same "roberta-base" checkpoint the model uses.
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
# Define RoBERTa Model
class RoBertaClassifier(nn.Module):
    """A RoBERTa encoder with two linear heads: one for the team, one for the email.

    Both heads read the final-layer hidden state at token position 0.
    """

    def __init__(self, num_teams, num_emails):
        super().__init__()
        self.roberta = RobertaModel.from_pretrained("roberta-base")
        hidden_size = self.roberta.config.hidden_size
        self.team_classifier = nn.Linear(hidden_size, num_teams)
        self.email_classifier = nn.Linear(hidden_size, num_emails)

    def forward(self, input_ids, attention_mask):
        """Return ``(team_logits, email_logits)`` computed from the first-token state."""
        hidden_states = self.roberta(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        first_token = hidden_states[:, 0]
        return self.team_classifier(first_token), self.email_classifier(first_token)
# Load Model
# Head sizes come from the label mappings so they match the trained checkpoint.
num_teams = len(label_to_team)
num_emails = len(label_to_email)
model = RoBertaClassifier(num_teams, num_emails)
checkpoint = torch.load("ticket_classification_model.pth", map_location=torch.device("cpu"))
# Checkpoints saved from an nn.DataParallel wrapper prefix every key with
# "module."; strip it so those keys are not silently discarded by the filter below.
checkpoint = {
    (k[len("module."):] if k.startswith("module.") else k): v
    for k, v in checkpoint.items()
}
# Keep only keys the model defines (unexpected keys are dropped).
model_keys = model.state_dict().keys()
filtered_checkpoint = {k: v for k, v in checkpoint.items() if k in model_keys}
load_result = model.load_state_dict(filtered_checkpoint, strict=False)
# strict=False tolerates gaps, but every missing key keeps its random
# initialization, which makes predictions meaningless -- surface that loudly.
if load_result.missing_keys:
    warnings.warn(
        f"{len(load_result.missing_keys)} model weights were not found in the "
        f"checkpoint and remain randomly initialized, e.g. {load_result.missing_keys[:5]}"
    )
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
model.eval()
# Prediction Function
def predict_tickets(ticket_descriptions):
    """Classify each ticket description into a team and team email.

    Args:
        ticket_descriptions: Iterable of description strings, one per ticket.

    Returns:
        A tuple ``(markdown, csv_path)``: a Markdown report with one numbered
        entry per ticket, and the path to a CSV with the columns
        "Index", "Description", "Assigned Team", "Team Email".
        Predicted indices missing from the label mappings are reported as
        "Unknown Team" / "Unknown Email".
    """
    predictions = []
    csv_data = []
    # Inference only: one no_grad context for the whole loop.
    with torch.no_grad():
        for idx, description in enumerate(ticket_descriptions, start=1):
            inputs = tokenizer(
                description,
                return_tensors="pt",
                truncation=True,
                padding="max_length",
                max_length=128,
            ).to(device)
            team_logits, email_logits = model(inputs.input_ids, inputs.attention_mask)
            # .item() copies the scalar to the host from any device.
            predicted_team_index = team_logits.argmax(dim=-1).item()
            predicted_email_index = email_logits.argmax(dim=-1).item()
            predicted_team = label_to_team.get(predicted_team_index, "Unknown Team")
            predicted_email = label_to_email.get(predicted_email_index, "Unknown Email")
            predictions.append(
                f"**{idx}. {description}**\n - **Assigned Team:** {predicted_team}\n - **Team Email:** {predicted_email}\n"
            )
            csv_data.append([idx, description, predicted_team, predicted_email])
    df = pd.DataFrame(csv_data, columns=["Index", "Description", "Assigned Team", "Team Email"])
    # Each call gets its own temp directory. A single fixed path in the working
    # directory would be shared by every request, letting concurrent users of the
    # publicly shared app overwrite (or download) each other's results.
    # Keeping the basename preserves the file name users see when downloading.
    csv_file = os.path.join(tempfile.mkdtemp(), "ticket-predictions.csv")
    df.to_csv(csv_file, index=False)
    return "\n".join(predictions), csv_file
# Gradio Functions
def gradio_predict(option, text_input, file_input):
    """Collect ticket descriptions from the selected input and run predictions.

    Args:
        option: The radio value, either "Enter Text" or "Upload CSV".
        text_input: Textbox contents; one description per line. May be None.
        file_input: Path to the uploaded CSV (anything ``pd.read_csv`` accepts),
            or None when nothing was uploaded.

    Returns:
        ``(markdown_results, csv_path)`` from ``predict_tickets`` on success, or
        ``(message, None)`` when the input is missing, unreadable, lacks a
        'Description' column, or contains no non-blank descriptions.
    """
    if option == "Enter Text":
        # `or ""` guards against the textbox delivering None.
        descriptions = [line.strip() for line in (text_input or "").split("\n") if line.strip()]
    elif option == "Upload CSV" and file_input is not None:
        try:
            df = pd.read_csv(file_input)
        except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError):
            return "⚠️ Error: Could not read the uploaded CSV file.", None
        if "Description" not in df.columns:
            return "⚠️ Error: CSV must contain a 'Description' column.", None
        # Empty cells load as NaN (a float) and numeric cells as numbers; the
        # tokenizer needs real text, so drop blanks and coerce the rest to str.
        descriptions = [
            str(value).strip()
            for value in df["Description"].dropna()
            if str(value).strip()
        ]
    else:
        return "⚠️ Please provide input.", None
    if not descriptions:
        # Blank-only input: prompt instead of returning an empty report and CSV.
        return "⚠️ Please provide input.", None
    results, csv_file = predict_tickets(descriptions)
    return results, csv_file
def clear_inputs():
    """Reset the UI to its initial state.

    Returns one value per output of the CLEAR button handler, in order:
    input-method radio, textbox, CSV upload, results Markdown, download file.
    """
    default_option, blank_text, no_file = "Enter Text", "", None
    return default_option, blank_text, no_file, blank_text, no_file
# Gradio App UI
# Emoji in the labels below were previously mojibake (UTF-8 bytes decoded as
# cp1253, e.g. "πŸ“‚"); they are restored to the intended characters.
with gr.Blocks(css=".gradio-container {max-width: 1100px; margin: auto;}") as app:
    gr.Markdown(
        """
# Multi-Ticket AI Classification System
**Supports:** Multi-line text input & CSV upload.
**Output:** Text results & downloadable CSV file.
**Model:** Fine-tuned **RoBERTa** for classification.
Enter ticket Description/Comment/Summary or upload a **CSV file** to predict Assigned Teams & Team Emails.
""",
        elem_id="title"
    )
    with gr.Row():
        # Left column: input controls. Only one of text_input / file_input is
        # visible at a time (see toggle_input below).
        with gr.Column(scale=1):
            option = gr.Radio(["Enter Text", "Upload CSV"], label="📝 Choose Input Method", value="Enter Text")
            text_input = gr.Textbox(
                label="Enter Ticket Description/Comment/Summary (One per line)",
                lines=6,
                placeholder="Example:\n - Database performance issue\n - Login fails for admin users..."
            )
            file_input = gr.File(label="📂 Upload CSV (Optional)", type="filepath", visible=False)
        # Right column: prediction output.
        with gr.Column(scale=1):
            gr.Markdown("## Prediction Results")
            results_output = gr.Markdown(elem_id="results-box", visible=True)
            download_csv = gr.File(label="📥 Download Predictions CSV", interactive=False)
    with gr.Row():
        predict_btn = gr.Button("PREDICT", variant="primary")
        clear_btn = gr.Button("CLEAR", variant="secondary")

    # Logic for Showing/ Hiding Input Fields
    def toggle_input(selected_option):
        """Show the textbox for "Enter Text" and the file picker for "Upload CSV"."""
        return gr.update(visible=(selected_option == "Enter Text")), gr.update(visible=(selected_option == "Upload CSV"))

    option.change(fn=toggle_input, inputs=[option], outputs=[text_input, file_input])
    predict_btn.click(fn=gradio_predict, inputs=[option, text_input, file_input], outputs=[results_output, download_csv])
    clear_btn.click(fn=clear_inputs, inputs=[], outputs=[option, text_input, file_input, results_output, download_csv])

    # Footer view
    gr.Markdown("---")
    gr.Markdown(
        "<p style='text-align: center;color: gray;'>"
        "Developed by NYP student @ Min Thein Win: Student ID: 3907578Y</p>"
    )

# Launch App
app.launch(share=True)