"""Streamlit app: classify defect tickets to an owning team and contact email.

Loads a fine-tuned two-head RoBERTa checkpoint and serves predictions via
multi-line text input or CSV upload, with a downloadable CSV of results.
"""

import pickle

import pandas as pd
import streamlit as st
import torch
import torch.nn as nn
from transformers import RobertaModel, RobertaTokenizerFast


# -------------------------------
# Model definition
# -------------------------------
class RoBertaClassifier(nn.Module):
    """RoBERTa encoder with two linear heads (multi-task classification).

    One head predicts the assigned team, the other the team email; both
    read the [CLS] token embedding.
    """

    def __init__(self, num_teams: int, num_emails: int):
        super().__init__()
        self.roberta = RobertaModel.from_pretrained("roberta-base")
        hidden_size = self.roberta.config.hidden_size
        self.team_classifier = nn.Linear(hidden_size, num_teams)
        self.email_classifier = nn.Linear(hidden_size, num_emails)

    def forward(self, input_ids, attention_mask):
        """Return ``(team_logits, email_logits)`` for a tokenized batch."""
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        # First token position is the [CLS] summary embedding.
        cls_output = outputs.last_hidden_state[:, 0, :]
        return self.team_classifier(cls_output), self.email_classifier(cls_output)


@st.cache_resource(show_spinner=False)
def _load_resources():
    """Load label mappings, tokenizer, model, and device exactly once.

    Cached with ``st.cache_resource`` so Streamlit reruns (every widget
    interaction re-executes the script) do not reload the checkpoint.

    Returns:
        (label_to_team, label_to_email, tokenizer, model, device)
    """
    # NOTE(review): pickle/torch.load execute arbitrary code — these files
    # must come from a trusted source (they are local model artifacts here).
    with open("label_mappings.pkl", "rb") as f:
        label_mappings = pickle.load(f)
    label_to_team = label_mappings.get("label_to_team", {})
    label_to_email = label_mappings.get("label_to_email", {})

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    model = RoBertaClassifier(len(label_to_team), len(label_to_email))
    checkpoint = torch.load(
        "ticket_classification_model.pth", map_location=torch.device("cpu")
    )
    # Keep only weights that match the current architecture; strict=False
    # tolerates any heads missing from the checkpoint.
    filtered_checkpoint = {
        k: v for k, v in checkpoint.items() if k in model.state_dict()
    }
    model.load_state_dict(filtered_checkpoint, strict=False)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)
    model.eval()
    return label_to_team, label_to_email, tokenizer, model, device


label_to_team, label_to_email, tokenizer, model, device = _load_resources()


# -------------------------------
# Prediction function
# -------------------------------
def predict_tickets(ticket_descriptions):
    """Classify each ticket description.

    Args:
        ticket_descriptions: iterable of ticket texts (one per ticket).

    Returns:
        A tuple ``(report_text, results_df)`` — a human-readable summary
        string and a DataFrame with columns Index / Description /
        Assigned Team / Team Email.
    """
    predictions = []
    csv_rows = []
    for idx, description in enumerate(ticket_descriptions, start=1):
        inputs = tokenizer(
            str(description),  # CSV cells may be numeric — tokenizer needs str
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=128,
        ).to(device)
        with torch.no_grad():
            team_logits, email_logits = model(inputs.input_ids, inputs.attention_mask)
        team_index = team_logits.argmax(dim=-1).cpu().item()
        email_index = email_logits.argmax(dim=-1).cpu().item()
        team = label_to_team.get(team_index, "Unknown Team")
        email = label_to_email.get(email_index, "Unknown Email")
        predictions.append(
            f"{idx}. {description}\n - Assigned Team: {team}\n - Team Email: {email}\n"
        )
        csv_rows.append([idx, description, team, email])

    df = pd.DataFrame(
        csv_rows, columns=["Index", "Description", "Assigned Team", "Team Email"]
    )
    return "\n".join(predictions), df


# -------------------------------
# Streamlit UI
# -------------------------------
st.markdown(
    "<h1 style='text-align: center;'>AI Solution for Defect Ticket Classification</h1>",
    unsafe_allow_html=True,
)
st.markdown("""
**Supports:** Multi-line text input & CSV upload.
**Output:** Text results & downloadable CSV file.
**Model:** Fine-tuned **RoBERTa** for classification.
""")

# Choose input method
option = st.radio("📝 Choose Input Method", ["Enter Text", "Upload CSV"])

if option == "Enter Text":
    text_input = st.text_area(
        "Enter Ticket Description/Comment/Summary (One per line)",
        placeholder="Example:\n - Database performance issue\n - Login fails for admin users...",
    )
    descriptions = [line.strip() for line in text_input.split("\n") if line.strip()]
else:
    file_input = st.file_uploader("Upload CSV", type=["csv"])
    descriptions = []
    if file_input is not None:
        df_input = pd.read_csv(file_input)
        if "Description" not in df_input.columns:
            st.error("⚠️ Error: CSV must contain a 'Description' column.")
        else:
            descriptions = df_input["Description"].dropna().astype(str).tolist()

# Trigger prediction when the button is clicked
if st.button("PREDICT"):
    if not descriptions:
        st.error("⚠️ Please provide valid input.")
    else:
        with st.spinner("Predicting..."):
            results, df_results = predict_tickets(descriptions)
        st.markdown("## Prediction Results")
        st.text(results)
        csv_data = df_results.to_csv(index=False).encode("utf-8")
        st.download_button(
            label="📥 Download Predictions CSV",
            data=csv_data,
            file_name="ticket-predictions.csv",
            mime="text/csv",
        )

# Clear button: simply reloads the app.
if st.button("CLEAR"):
    # st.experimental_rerun was removed in newer Streamlit; prefer st.rerun.
    if hasattr(st, "rerun"):
        st.rerun()
    else:
        st.experimental_rerun()

st.markdown("---")
st.markdown(
    "<p style='text-align: center;'>Developed by NYP student @ Min Thein Win: Student ID: 3907578Y</p>",
    unsafe_allow_html=True,
)