# Page header captured together with the file (not part of the program):
#   mintheinwin's picture
#   update app
#   120cd79
#   raw
#   history blame
#   5.13 kB
import streamlit as st
import torch
import torch.nn as nn
import pickle
import pandas as pd
from transformers import RobertaTokenizerFast, RobertaModel
# -------------------------------
# Load label mappings
def _read_label_mappings(path):
    """Unpickle and return the label-mapping object stored at *path*.

    NOTE: unpickling can execute arbitrary code, so the file must come from a
    trusted source.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


label_mappings = _read_label_mappings("label_mappings.pkl")

# Lookup tables used to turn predicted class indices into names; each falls
# back to an empty table when its key is absent from the pickle.
label_to_team, label_to_email = (
    label_mappings.get(key, {}) for key in ("label_to_team", "label_to_email")
)

# -------------------------------
# Load tokenizer
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
# -------------------------------
# Define RoBERTa Model for multi-task classification
class RoBertaClassifier(nn.Module):
    """RoBERTa encoder feeding two independent linear classification heads.

    One head scores teams and the other scores emails; both read the same
    encoder output, so a single forward pass yields both predictions.
    """

    def __init__(self, num_teams, num_emails):
        """Build the encoder and size each head by its number of classes.

        Args:
            num_teams: Output width of the team head.
            num_emails: Output width of the email head.
        """
        super().__init__()
        self.roberta = RobertaModel.from_pretrained("roberta-base")
        hidden_size = self.roberta.config.hidden_size
        self.team_classifier = nn.Linear(hidden_size, num_teams)
        self.email_classifier = nn.Linear(hidden_size, num_emails)

    def forward(self, input_ids, attention_mask):
        """Encode the inputs and return ``(team_logits, email_logits)``."""
        encoded = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state at sequence position 0 as the summary vector
        # for the whole input.
        first_token_state = encoded.last_hidden_state[:, 0, :]
        return (
            self.team_classifier(first_token_state),
            self.email_classifier(first_token_state),
        )
# -------------------------------
# Initialize model and load checkpoint
def _load_classifier(num_teams, num_emails, checkpoint_path="ticket_classification_model.pth"):
    """Build the classifier, load its checkpoint and return ``(model, device)``.

    Only checkpoint entries whose names exist in the model are loaded. Model
    parameters the checkpoint does not provide keep their initial values; they
    are logged so that a partially loaded model does not go unnoticed.

    Args:
        num_teams: Number of team classes (width of the team head).
        num_emails: Number of email classes (width of the email head).
        checkpoint_path: Path of the saved state dict.

    Returns:
        The model in eval mode, already moved to the selected device, and that
        device (CUDA when available, otherwise CPU).
    """
    import logging

    classifier = RoBertaClassifier(num_teams, num_emails)
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))

    # Build the model's key set once instead of calling state_dict() for every
    # checkpoint entry.
    expected_keys = classifier.state_dict().keys()
    filtered_checkpoint = {k: v for k, v in checkpoint.items() if k in expected_keys}

    load_result = classifier.load_state_dict(filtered_checkpoint, strict=False)
    if load_result.missing_keys:
        logging.getLogger(__name__).warning(
            "Checkpoint %s provides no value for %d model parameter(s): %s",
            checkpoint_path,
            len(load_result.missing_keys),
            ", ".join(load_result.missing_keys),
        )

    target_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    classifier.to(target_device)
    classifier.eval()
    return classifier, target_device


# Streamlit re-runs the script on user interaction; caching the loader avoids
# rebuilding the model each time. Older Streamlit releases without
# `cache_resource` simply load the model on every run, as before.
_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is not None:
    _load_classifier = _cache_resource(_load_classifier)

num_teams = len(label_to_team)
num_emails = len(label_to_email)
model, device = _load_classifier(num_teams, num_emails)
# -------------------------------
# Prediction function
def predict_tickets(ticket_descriptions):
    """Classify each ticket description and report the assigned team and email.

    Args:
        ticket_descriptions: Iterable of ticket texts; each one is tokenized
            and classified on its own.

    Returns:
        A tuple ``(report, frame)``: ``report`` is a numbered plain-text
        summary of all predictions, and ``frame`` is a DataFrame with the
        columns "Index", "Description", "Assigned Team" and "Team Email".
    """

    def _top_label(logits, table, fallback):
        # Map the highest-scoring class index to its name, or to the fallback
        # when the index is not in the table.
        return table.get(logits.argmax(dim=-1).cpu().item(), fallback)

    report_entries = []
    table_rows = []

    for number, text in enumerate(ticket_descriptions, start=1):
        # Every input is truncated or padded to exactly 128 tokens.
        encoded = tokenizer(
            text,
            max_length=128,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).to(device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            team_logits, email_logits = model(encoded.input_ids, encoded.attention_mask)

        team = _top_label(team_logits, label_to_team, "Unknown Team")
        email = _top_label(email_logits, label_to_email, "Unknown Email")

        report_entries.append(
            f"{number}. {text}\n - Assigned Team: {team}\n - Team Email: {email}\n"
        )
        table_rows.append([number, text, team, email])

    results_frame = pd.DataFrame(
        table_rows,
        columns=["Index", "Description", "Assigned Team", "Team Email"],
    )
    return "\n".join(report_entries), results_frame
# -------------------------------
# Streamlit UI
#st.title("AI Solution for Defect Ticket Classification")
def _descriptions_from_csv(uploaded_file):
    """Return the non-empty ticket descriptions found in an uploaded CSV.

    Shows a Streamlit error and returns an empty list when the file cannot be
    parsed or has no 'Description' column.
    """
    try:
        frame = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
        st.error(f"⚠️ Error: Could not read the uploaded CSV ({exc}).")
        return []

    if "Description" not in frame.columns:
        st.error("⚠️ Error: CSV must contain a 'Description' column.")
        return []

    # Cells can be parsed as non-string values (e.g. numbers). Coerce them to
    # text and drop blanks so CSV input is cleaned the same way as typed input.
    cleaned = (str(value).strip() for value in frame["Description"].dropna())
    return [text for text in cleaned if text]


st.markdown("<h2 style='text-align: center;'>AI Solution for Defect Ticket Classification</h2>", unsafe_allow_html=True)
st.markdown("""
**Supports:** Multi-line text input & CSV upload.
**Output:** Text results & downloadable CSV file.
**Model:** Fine-tuned **RoBERTa** for classification.
""")

# Choose input method
# NOTE(review): the label's emoji was restored from mis-encoded bytes; the
# memo emoji is assumed - confirm against the intended label.
option = st.radio("📝 Choose Input Method", ["Enter Text", "Upload CSV"])

if option == "Enter Text":
    text_input = st.text_area(
        "Enter Ticket Description/Comment/Summary (One per line)",
        placeholder="Example:\n - Database performance issue\n - Login fails for admin users..."
    )
    # One ticket per non-blank line.
    descriptions = [line.strip() for line in text_input.split("\n") if line.strip()]
else:
    file_input = st.file_uploader("Upload CSV", type=["csv"])
    descriptions = []
    if file_input is not None:
        descriptions = _descriptions_from_csv(file_input)

# Run the prediction only when the button is clicked
if st.button("PREDICT"):
    if not descriptions:
        st.error("⚠️ Please provide valid input.")
    else:
        with st.spinner("Predicting..."):
            results, df_results = predict_tickets(descriptions)

        st.markdown("## Prediction Results")
        st.text(results)

        csv_data = df_results.to_csv(index=False).encode('utf-8')
        st.download_button(
            label="📥 Download Predictions CSV",
            data=csv_data,
            file_name="ticket-predictions.csv",
            mime="text/csv"
        )

# Clear button: re-runs the script.
# NOTE(review): a rerun by itself may not reset widget values - confirm whether
# CLEAR is also expected to empty the inputs.
if st.button("CLEAR"):
    # Prefer `st.rerun`; fall back to `st.experimental_rerun` on Streamlit
    # releases that do not provide it yet.
    _rerun = getattr(st, "rerun", None) or st.experimental_rerun
    _rerun()

st.markdown("---")
st.markdown(
    "<p style='text-align: center;color: gray;'>Developed by NYP student @ Min Thein Win: Student ID: 3907578Y</p>",
    unsafe_allow_html=True
)