ZunYin commited on
Commit
90fdda0
Β·
verified Β·
1 Parent(s): 83015cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -83
app.py CHANGED
@@ -1,19 +1,8 @@
1
  import torch
2
  import joblib
3
  import gradio as gr
4
- import matplotlib.pyplot as plt
5
- import numpy as np
6
- import shap
7
- import lime
8
- import lime.lime_text
9
- import logging
10
- import os
11
 
12
- # Configure logging for debugging
13
- LOG_FILE = "prediction_logs.txt"
14
- logging.basicConfig(filename=LOG_FILE, level=logging.INFO, format="%(asctime)s - %(message)s")
15
-
16
- # Load Model & Preprocessors
17
  tfidf_vectorizer = joblib.load("tfidf_vectorizer.pkl")
18
  category_encoder = joblib.load("category_encoder.pkl")
19
  team_encoder = joblib.load("team_encoder.pkl")
@@ -21,7 +10,7 @@ team_encoder = joblib.load("team_encoder.pkl")
21
  multi_label_classifier = torch.load("multi_label_classifier.pth")
22
  multi_label_classifier.eval()
23
 
24
- # Dummy function to get keywords per category
25
  def get_top_keywords_per_category(category, n=5):
26
  keywords_dict = {
27
  "UX Issue": ["mobile", "responsive", "alignment", "css", "layout"],
@@ -30,83 +19,31 @@ def get_top_keywords_per_category(category, n=5):
30
  }
31
  return keywords_dict.get(category, ["No keywords found"])
32
 
33
- # Function to explain predictions using SHAP
34
- def explain_with_shap(phrase):
35
- background_data = ["UI button is not working", "Server error while processing", "Page alignment issue"]
36
- feature_data = tfidf_vectorizer.transform(background_data + [phrase])
37
-
38
- explainer = shap.Explainer(multi_label_classifier, feature_data[:-1])
39
- shap_values = explainer(feature_data[-1])
40
-
41
- # Generate SHAP force plot
42
- shap_html = shap.plots.text(shap_values, display=False)
43
- return shap_html
44
-
45
- # Function to predict, explain, and generate logs
46
- def predict_with_visuals(phrase):
47
  text_features = tfidf_vectorizer.transform([phrase])
48
-
49
- predictions = multi_label_classifier.predict_proba(text_features)
50
- predicted_labels = np.argmax(predictions, axis=1)
51
-
52
- predicted_category = category_encoder.inverse_transform([predicted_labels[0]])[0]
53
- predicted_team = team_encoder.inverse_transform([predicted_labels[0]])[0]
54
- team_email = f"support@{predicted_team.replace(' ', '').lower()}.com"
55
 
 
 
 
56
  keywords = get_top_keywords_per_category(predicted_category, n=5)
57
 
58
- # Create Bar Chart for Confidence Scores
59
- category_names = category_encoder.classes_
60
- category_probs = predictions[0]
61
-
62
- fig, ax = plt.subplots(facecolor='#121212')
63
- ax.barh(category_names, category_probs, color=["#ff9999", "#66b3ff", "#99ff99"])
64
- ax.set_xlabel("Confidence Score", color="white")
65
- ax.set_title("Prediction Confidence by Category", color="white")
66
- ax.tick_params(colors="white")
67
-
68
- # Log the prediction
69
- log_entry = f"Input: {phrase} | Predicted Category: {predicted_category} | Confidence: {category_probs.max():.2f}"
70
- logging.info(log_entry)
71
-
72
- # SHAP explanation
73
- shap_explanation = explain_with_shap(phrase)
74
-
75
- result = f"""
76
- <div style='font-size: 18px; font-family: Arial; color: white; background-color: #121212; padding: 10px; border-radius: 8px;'>
77
- <strong>πŸ“Œ Predicted Category:</strong> <span style='color:#4CAF50;'>{predicted_category}</span><br>
78
- <strong>πŸ‘¨β€πŸ’» Assigned Team:</strong> <span style='color:#2196F3;'>{predicted_team}</span><br>
79
- <strong>πŸ“§ Team Email:</strong> <span style='color:#FF5722;'>{team_email}</span><br>
80
- <strong>πŸ”‘ Top Keywords:</strong> <span style='color:#FFEB3B;'>{', '.join(keywords)}</span>
81
- </div>
82
  """
83
 
84
- return result, fig, shap_explanation
85
-
86
- # Function to download logs
87
- def download_logs():
88
- with open(LOG_FILE, "r") as f:
89
- return f.read()
90
-
91
- # Gradio Interface with SHAP, Logs, and Faster Inference
92
  interface = gr.Interface(
93
- fn=predict_with_visuals,
94
  inputs=gr.Textbox(lines=2, placeholder="Enter defect description..."),
95
- outputs=["html", "plot", "html"],
96
- title="πŸ”Ž AI Defect Ticket Classifier",
97
- description="Enter a defect description to predict its **Category, Assigned Team, and relevant Keywords**. See **SHAP explanations, logs, and confidence scores!**",
98
- theme="dark"
99
- )
100
-
101
- # Add logs download button
102
- download_interface = gr.Interface(
103
- fn=download_logs,
104
- inputs=[],
105
- outputs="text",
106
- title="πŸ“‚ Download Prediction Logs",
107
- description="Click the button to download all logs.",
108
  )
109
 
110
- # Launch both interfaces
111
- interface.launch(share=True)
112
- download_interface.launch(share=True)
 
1
  import torch
2
  import joblib
3
  import gradio as gr
 
 
 
 
 
 
 
4
 
5
+ # Load the Model and Dependencies
 
 
 
 
6
  tfidf_vectorizer = joblib.load("tfidf_vectorizer.pkl")
7
  category_encoder = joblib.load("category_encoder.pkl")
8
  team_encoder = joblib.load("team_encoder.pkl")
 
10
  multi_label_classifier = torch.load("multi_label_classifier.pth")
11
  multi_label_classifier.eval()
12
 
13
+ # Function to get keywords (dummy implementation)
14
  def get_top_keywords_per_category(category, n=5):
15
  keywords_dict = {
16
  "UX Issue": ["mobile", "responsive", "alignment", "css", "layout"],
 
19
  }
20
  return keywords_dict.get(category, ["No keywords found"])
21
 
22
+ # Prediction function
23
+ def predict_with_keywords(phrase):
 
 
 
 
 
 
 
 
 
 
 
 
24
  text_features = tfidf_vectorizer.transform([phrase])
25
+ predicted_labels = multi_label_classifier.predict(text_features)
 
 
 
 
 
 
26
 
27
+ predicted_category = category_encoder.inverse_transform([predicted_labels[0][0]])[0]
28
+ predicted_team = team_encoder.inverse_transform([predicted_labels[0][1]])[0]
29
+ team_email = "support@" + predicted_team.replace(" ", "").lower() + ".com"
30
  keywords = get_top_keywords_per_category(predicted_category, n=5)
31
 
32
+ return f"""
33
+ **Predicted Category:** {predicted_category}
34
+ **Predicted Assigned Team:** {predicted_team}
35
+ **Team Email:** {team_email}
36
+ **Top Keywords:** {', '.join(keywords)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  """
38
 
39
+ # Gradio Interface
 
 
 
 
 
 
 
40
  interface = gr.Interface(
41
+ fn=predict_with_keywords,
42
  inputs=gr.Textbox(lines=2, placeholder="Enter defect description..."),
43
+ outputs="markdown",
44
+ title="Defect Ticket Classifier",
45
+ description="Enter a defect description to predict its Category, Assigned Team, and relevant Keywords."
 
 
 
 
 
 
 
 
 
 
46
  )
47
 
48
+ # Launch app
49
+ interface.launch()