Daizzyy committed on
Commit
92a9089
·
verified ·
1 Parent(s): df4f89b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -41
app.py CHANGED
@@ -1,53 +1,109 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
3
  import os
 
 
4
 
5
 
6
  def load_model():
7
- """Load model from root directory"""
8
- model_path = "."
9
-
10
  try:
11
- print(f"Loading model from root directory...")
12
- tokenizer = AutoTokenizer.from_pretrained(model_path)
13
- model = AutoModelForSequenceClassification.from_pretrained(model_path)
14
- classifier = pipeline(
15
- "text-classification",
16
- model=model,
17
- tokenizer=tokenizer,
18
- return_all_scores=True
19
- )
20
- print(f"✅ Successfully loaded model")
21
- return classifier
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  except Exception as e:
23
- print(f"❌ Error loading model: {str(e)}")
24
- classifier = pipeline(
25
- "text-classification",
26
- model="bert-base-uncased",
27
- return_all_scores=True
28
- )
29
- return classifier
30
 
31
- classifier = load_model()
32
 
33
- # ============================================================
34
- # PREDICTION LOGIC WITH CORRECT LABELS
35
- # ============================================================
36
 
37
  def predict(text):
38
- """Predict cyberbullying category"""
39
  if not text.strip():
40
  return "<div class='warn'>⚠️ Please enter some text.</div>"
41
 
42
  try:
43
- results = classifier(text)[0]
44
- best = max(results, key=lambda x: x["score"])
45
- label = best["label"]
46
- score = best["score"]
 
 
 
 
 
 
 
 
 
 
 
47
 
48
- print(f"Label: {label}, Score: {score}")
 
 
 
 
 
49
 
50
- # Your model labels
 
 
 
 
 
 
 
 
 
 
 
51
  cyberbullying_types = {
52
  "age": {"emoji": "👶", "color": "#ff6b6b", "text": "Age-Based Cyberbullying"},
53
  "gender": {"emoji": "⚥️", "color": "#ff8c42", "text": "Gender-Based Cyberbullying"},
@@ -57,11 +113,12 @@ def predict(text):
57
  "not_cyberbullying": {"emoji": "✅", "color": "#00ff64", "text": "Safe Message"}
58
  }
59
 
60
- # Get the category info
61
- category = cyberbullying_types.get(label.lower(), cyberbullying_types["not_cyberbullying"])
 
62
 
63
  # Safe message
64
- if label.lower() == "not_cyberbullying":
65
  return f"""
66
  <div class='safe'>
67
  <div class='checkmark'>{category['emoji']}</div>
@@ -88,13 +145,12 @@ def predict(text):
88
 
89
  except Exception as e:
90
  import traceback
 
91
  print(f"ERROR: {str(e)}")
92
- print(traceback.format_exc())
93
  return f"<div class='warn'>❌ Error: {str(e)}</div>"
94
 
95
- # ============================================================
96
- # GRADIO INTERFACE WITH PURPLE-TO-BLUE GRADIENT
97
- # ============================================================
98
 
99
  with gr.Blocks(theme=gr.themes.Soft(), css="""
100
  <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/animate.css/4.1.1/animate.min.css"/>
@@ -456,6 +512,6 @@ textarea:focus {
456
  outputs=output
457
  )
458
 
459
- # Launch the app
460
  if __name__ == "__main__":
461
  demo.launch()
 
1
  import gradio as gr
2
+ import joblib
3
  import os
4
+ import numpy as np
5
+
6
 
7
 
8
  def load_model():
9
+ """Load joblib model and components"""
 
 
10
  try:
11
+ print("Loading joblib model...")
12
+
13
+ # Load model (try different possible names)
14
+ model = None
15
+ model_file = None
16
+
17
+ if os.path.exists("model.safetensors"):
18
+ print("Found model.safetensors")
19
+ model = joblib.load("model.safetensors")
20
+ model_file = "model.safetensors"
21
+ elif os.path.exists("model.jobilib"):
22
+ print("Found model.jobilib")
23
+ model = joblib.load("model.jobilib")
24
+ model_file = "model.jobilib"
25
+ elif os.path.exists("tfidf_logreg_best.jobilib"):
26
+ print("Found tfidf_logreg_best.jobilib")
27
+ model = joblib.load("tfidf_logreg_best.jobilib")
28
+ model_file = "tfidf_logreg_best.jobilib"
29
+ else:
30
+ # List available files
31
+ files = os.listdir(".")
32
+ print(f"Available files: {files}")
33
+ raise FileNotFoundError("No model file found")
34
+
35
+ # Load vectorizer/tokenizer
36
+ vectorizer = None
37
+ if os.path.exists("vocab"):
38
+ print("Found vocab file")
39
+ vectorizer = joblib.load("vocab")
40
+ elif os.path.exists("tokenizer"):
41
+ print("Found tokenizer file")
42
+ vectorizer = joblib.load("tokenizer")
43
+
44
+ # Load label encoder
45
+ label_encoder = None
46
+ if os.path.exists("label_encoder.jobilib"):
47
+ print("Found label_encoder.jobilib")
48
+ label_encoder = joblib.load("label_encoder.jobilib")
49
+
50
+ print(f"✅ Model loaded successfully from {model_file}")
51
+ return {
52
+ "model": model,
53
+ "vectorizer": vectorizer,
54
+ "label_encoder": label_encoder
55
+ }
56
+
57
  except Exception as e:
58
+ print(f"❌ Error loading joblib model: {str(e)}")
59
+ return None
60
+
61
+ # Load model
62
+ model_components = load_model()
 
 
63
 
 
64
 
 
 
 
65
 
66
  def predict(text):
67
+ """Predict cyberbullying category using joblib model"""
68
  if not text.strip():
69
  return "<div class='warn'>⚠️ Please enter some text.</div>"
70
 
71
  try:
72
+ if model_components is None:
73
+ return "<div class='warn'>❌ Model not loaded properly</div>"
74
+
75
+ model = model_components["model"]
76
+ vectorizer = model_components["vectorizer"]
77
+ label_encoder = model_components["label_encoder"]
78
+
79
+ # Vectorize the text
80
+ if vectorizer is not None:
81
+ text_vector = vectorizer.transform([text])
82
+ else:
83
+ return "<div class='warn'>❌ Vectorizer not found</div>"
84
+
85
+ # Get prediction
86
+ prediction = model.predict(text_vector)[0]
87
 
88
+ # Get probability if available
89
+ try:
90
+ probabilities = model.predict_proba(text_vector)[0]
91
+ score = max(probabilities)
92
+ except:
93
+ score = 0.8 # Default score
94
 
95
+ # Decode label if encoder exists
96
+ if label_encoder is not None:
97
+ try:
98
+ label = label_encoder.inverse_transform([prediction])[0]
99
+ except:
100
+ label = str(prediction)
101
+ else:
102
+ label = str(prediction)
103
+
104
+ print(f"Prediction: {label}, Score: {score}")
105
+
106
+ # Category definitions
107
  cyberbullying_types = {
108
  "age": {"emoji": "👶", "color": "#ff6b6b", "text": "Age-Based Cyberbullying"},
109
  "gender": {"emoji": "⚥️", "color": "#ff8c42", "text": "Gender-Based Cyberbullying"},
 
113
  "not_cyberbullying": {"emoji": "✅", "color": "#00ff64", "text": "Safe Message"}
114
  }
115
 
116
+ # Get category (handle case variations)
117
+ label_lower = str(label).lower().strip()
118
+ category = cyberbullying_types.get(label_lower, cyberbullying_types.get(label, cyberbullying_types["not_cyberbullying"]))
119
 
120
  # Safe message
121
+ if label_lower == "not_cyberbullying":
122
  return f"""
123
  <div class='safe'>
124
  <div class='checkmark'>{category['emoji']}</div>
 
145
 
146
  except Exception as e:
147
  import traceback
148
+ error_msg = traceback.format_exc()
149
  print(f"ERROR: {str(e)}")
150
+ print(error_msg)
151
  return f"<div class='warn'>❌ Error: {str(e)}</div>"
152
 
153
+
 
 
154
 
155
  with gr.Blocks(theme=gr.themes.Soft(), css="""
156
  <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/animate.css/4.1.1/animate.min.css"/>
 
512
  outputs=output
513
  )
514
 
515
+
516
  if __name__ == "__main__":
517
  demo.launch()