Agnist commited on
Commit
bb3f86e
·
verified ·
1 Parent(s): 93debb2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -43
app.py CHANGED
@@ -14,51 +14,45 @@ import plotly.express as px
14
  import plotly.graph_objects as go
15
  import warnings
16
 
17
- # Suppress warnings
18
  warnings.filterwarnings("ignore")
19
 
20
- # Load dataset
21
  print("Loading dataset...")
22
  ds = load_dataset("uhoui/text-tone-classifier")
23
 
24
- # Convert to pandas DataFrame
25
  df = pd.DataFrame(ds["train"])
26
 
27
- # Print dataset statistics
28
  print(f"Dataset size: {len(df)} entries")
29
  print(f"Columns: {df.columns}")
30
 
31
- # Check class distribution
32
  label_counts = df['label'].value_counts()
33
  print("\nClass distribution:")
34
  print(label_counts)
35
 
36
- # Encode labels
37
  label_encoder = LabelEncoder()
38
  df['label_encoded'] = label_encoder.fit_transform(df['label'])
39
  num_classes = len(label_encoder.classes_)
40
 
41
- # Split the data - fix: remove stratify for classes with few samples
42
  X_train, X_test, y_train, y_test = train_test_split(
43
  df['text'],
44
  df['label_encoded'],
45
  test_size=0.2,
46
  random_state=42,
47
- # Only use stratify if we have enough samples
48
- stratify=None # Removed stratification to fix the error
49
  )
50
 
51
- # Feature extraction using TF-IDF
52
  print("Creating TF-IDF features...")
53
  tfidf = TfidfVectorizer(max_features=5000)
54
  X_train_tfidf = tfidf.fit_transform(X_train)
55
  X_test_tfidf = tfidf.transform(X_test)
56
 
57
- # Handle class imbalance using SMOTE - Fix for SMOTE error
58
- print("Applying SMOTE to handle class imbalance...")
59
  try:
60
- # Modify the SMOTE parameters to handle small sample sizes
61
- # Use k_neighbors=min(5, n_samples-1) for classes with few samples
62
  smallest_class_size = min(np.bincount(y_train)[np.bincount(y_train) > 0])
63
  k_neighbors = min(5, smallest_class_size - 1)
64
 
@@ -73,46 +67,44 @@ except ValueError as e:
73
  print(f"SMOTE error: {e}. Using original data.")
74
  X_train_resampled, y_train_resampled = X_train_tfidf, y_train
75
 
76
- # Train a logistic regression model
77
  print("Training model...")
78
  model = LogisticRegression(C=10, max_iter=1000, n_jobs=-1, solver='lbfgs', multi_class='multinomial')
79
  model.fit(X_train_resampled, y_train_resampled)
80
 
81
- # Evaluate model
82
  y_pred = model.predict(X_test_tfidf)
83
  accuracy = accuracy_score(y_test, y_pred)
84
  print(f"Model accuracy: {accuracy:.4f}")
85
 
86
- # Function to predict tone with probabilities
87
  def predict_tone(text):
88
- # Vectorize the input text
89
  text_tfidf = tfidf.transform([text])
90
 
91
  # Get prediction probabilities
92
  probs = model.predict_proba(text_tfidf)[0]
93
 
94
- # Get the predicted class and its probability
95
  pred_class_idx = np.argmax(probs)
96
  pred_class = label_encoder.inverse_transform([pred_class_idx])[0]
97
 
98
- # Get the labels used in training
99
- trained_labels = model.classes_ # These are encoded label indices
100
 
101
- # Convert encoded labels back to original string labels
102
  trained_label_names = label_encoder.inverse_transform(trained_labels)
103
 
104
- # Create results dictionary with only trained labels
105
  results = {label: float(prob) for label, prob in zip(trained_label_names, probs)}
106
 
107
  # Sort results by probability (descending)
108
  sorted_results = {k: v for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)}
109
 
110
  # Create visualization
111
- top_n = 5 # Show top 5 emotions
112
  top_labels = list(sorted_results.keys())[:top_n]
113
  top_probs = list(sorted_results.values())[:top_n]
114
 
115
- # Generate colors based on probability (higher probability = more intense color)
116
  colors = ["rgba(64, 128, 255, " + str(min(1.0, p + 0.3)) + ")" for p in top_probs]
117
 
118
  fig = go.Figure()
@@ -134,12 +126,11 @@ def predict_tone(text):
134
  xaxis=dict(range=[0, 1])
135
  )
136
 
137
- # Get example texts for the predicted emotion
138
  example_texts = df[df['label'] == pred_class]['text'].sample(min(3, len(df[df['label'] == pred_class]))).tolist()
139
 
140
  return pred_class, sorted_results, fig, example_texts
141
 
142
- # Function to handle the example display
143
  def get_tone_examples(tone):
144
  examples = df[df['label'] == tone]['text'].sample(min(5, len(df[df['label'] == tone]))).tolist()
145
  return examples
@@ -147,40 +138,37 @@ def get_tone_examples(tone):
147
  # Gradio interface
148
  def analyze_tone(text, selected_tone=None):
149
  if not text:
150
- return "Please enter some text to analyze.", {}, None, []
151
 
152
- # If a tone is selected from the dropdown, show examples
153
  if selected_tone and not text:
154
  examples = get_tone_examples(selected_tone)
155
  return f"Examples of '{selected_tone}' tone:", {}, None, examples
156
 
157
- # Otherwise, analyze the text
158
  predicted_tone, all_probs, fig, examples = predict_tone(text)
159
 
160
- # Format the result message
161
- message = f"The predicted tone is: **{predicted_tone}**"
162
 
163
  return message, all_probs, fig, examples
164
 
165
- # Create the Gradio interface
166
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
167
  gr.Markdown("# Text Tone Analyzer")
168
- gr.Markdown("Enter text to analyze its emotional tone.")
169
 
170
  with gr.Row():
171
  with gr.Column(scale=3):
172
  text_input = gr.Textbox(
173
  label="Enter your text here",
174
- placeholder="Type something to analyze its emotional tone...",
175
  lines=5
176
  )
177
  analyze_button = gr.Button("Analyze Tone", variant="primary")
178
 
179
  with gr.Column(scale=2):
180
- # Dropdown to select example tones
181
  tone_dropdown = gr.Dropdown(
182
  choices=sorted(df['label'].unique().tolist()),
183
- label="Or select a tone to see examples"
184
  )
185
 
186
  with gr.Row():
@@ -200,21 +188,19 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
200
  label="Example texts with similar tone"
201
  )
202
 
203
- # Fix for the click event handler - properly list the inputs
204
  analyze_button.click(
205
  fn=analyze_tone,
206
- inputs=[text_input, tone_dropdown], # Fixed: explicitly list both inputs
207
  outputs=[result_message, all_probs_output, plot_output, examples_output]
208
  )
209
 
210
- # Fix for tone_dropdown event handler
211
  tone_dropdown.change(
212
  fn=get_tone_examples,
213
- inputs=tone_dropdown, # This also needs to be fixed to be a list
214
  outputs=examples_output
215
  )
216
 
217
- # Add example inputs
218
  examples = [
219
  ["I'm so excited about this new project!"],
220
  ["I'm feeling quite down today and nothing seems to work."],
@@ -223,6 +209,6 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
223
  ]
224
  gr.Examples(examples=examples, inputs=text_input)
225
 
226
- # Launch the app
227
  if __name__ == "__main__":
228
  demo.launch()
 
14
  import plotly.graph_objects as go
15
  import warnings
16
 
 
17
  warnings.filterwarnings("ignore")
18
 
19
+ # Hugging face dataset import
20
  print("Loading dataset...")
21
  ds = load_dataset("uhoui/text-tone-classifier")
22
 
 
23
  df = pd.DataFrame(ds["train"])
24
 
25
+ # Console Log dataset and class
26
  print(f"Dataset size: {len(df)} entries")
27
  print(f"Columns: {df.columns}")
28
 
 
29
  label_counts = df['label'].value_counts()
30
  print("\nClass distribution:")
31
  print(label_counts)
32
 
33
+ # Labels
34
  label_encoder = LabelEncoder()
35
  df['label_encoded'] = label_encoder.fit_transform(df['label'])
36
  num_classes = len(label_encoder.classes_)
37
 
38
+ # Train testsplit
39
  X_train, X_test, y_train, y_test = train_test_split(
40
  df['text'],
41
  df['label_encoded'],
42
  test_size=0.2,
43
  random_state=42,
44
+ stratify=None
 
45
  )
46
 
47
+ # TFIDF Feature extraction
48
  print("Creating TF-IDF features...")
49
  tfidf = TfidfVectorizer(max_features=5000)
50
  X_train_tfidf = tfidf.fit_transform(X_train)
51
  X_test_tfidf = tfidf.transform(X_test)
52
 
53
+ # SMOTE
54
+ print("Handling class imbalance (via SNOTE)...")
55
  try:
 
 
56
  smallest_class_size = min(np.bincount(y_train)[np.bincount(y_train) > 0])
57
  k_neighbors = min(5, smallest_class_size - 1)
58
 
 
67
  print(f"SMOTE error: {e}. Using original data.")
68
  X_train_resampled, y_train_resampled = X_train_tfidf, y_train
69
 
70
+ # Logistic Regression Model
71
  print("Training model...")
72
  model = LogisticRegression(C=10, max_iter=1000, n_jobs=-1, solver='lbfgs', multi_class='multinomial')
73
  model.fit(X_train_resampled, y_train_resampled)
74
 
75
+ # Evaluate Model
76
  y_pred = model.predict(X_test_tfidf)
77
  accuracy = accuracy_score(y_test, y_pred)
78
  print(f"Model accuracy: {accuracy:.4f}")
79
 
 
80
  def predict_tone(text):
81
+ # Vectorize
82
  text_tfidf = tfidf.transform([text])
83
 
84
  # Get prediction probabilities
85
  probs = model.predict_proba(text_tfidf)[0]
86
 
87
+ # Get predicted class and its probability
88
  pred_class_idx = np.argmax(probs)
89
  pred_class = label_encoder.inverse_transform([pred_class_idx])[0]
90
 
91
+ # Get the labels used during training
92
+ trained_labels = model.classes_
93
 
94
+ # Decode to string (Labels)
95
  trained_label_names = label_encoder.inverse_transform(trained_labels)
96
 
 
97
  results = {label: float(prob) for label, prob in zip(trained_label_names, probs)}
98
 
99
  # Sort results by probability (descending)
100
  sorted_results = {k: v for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)}
101
 
102
  # Create visualization
103
+ top_n = 5 # Top 5, adjust later if needed
104
  top_labels = list(sorted_results.keys())[:top_n]
105
  top_probs = list(sorted_results.values())[:top_n]
106
 
107
+ # OPTIONAL: color-code probabilities
108
  colors = ["rgba(64, 128, 255, " + str(min(1.0, p + 0.3)) + ")" for p in top_probs]
109
 
110
  fig = go.Figure()
 
126
  xaxis=dict(range=[0, 1])
127
  )
128
 
129
+ # Fetch examples
130
  example_texts = df[df['label'] == pred_class]['text'].sample(min(3, len(df[df['label'] == pred_class]))).tolist()
131
 
132
  return pred_class, sorted_results, fig, example_texts
133
 
 
134
  def get_tone_examples(tone):
135
  examples = df[df['label'] == tone]['text'].sample(min(5, len(df[df['label'] == tone]))).tolist()
136
  return examples
 
138
  # Gradio interface
139
  def analyze_tone(text, selected_tone=None):
140
  if not text:
141
+ return "Enter the text to analyze:", {}, None, []
142
 
 
143
  if selected_tone and not text:
144
  examples = get_tone_examples(selected_tone)
145
  return f"Examples of '{selected_tone}' tone:", {}, None, examples
146
 
 
147
  predicted_tone, all_probs, fig, examples = predict_tone(text)
148
 
149
+ message = f"The tone is: **{predicted_tone}**"
 
150
 
151
  return message, all_probs, fig, examples
152
 
153
+ # Gradio interface Creation
154
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
155
  gr.Markdown("# Text Tone Analyzer")
156
+ gr.Markdown("Enter the text to analyze:")
157
 
158
  with gr.Row():
159
  with gr.Column(scale=3):
160
  text_input = gr.Textbox(
161
  label="Enter your text here",
162
+ placeholder="Example: The satisfaction of completing a difficult puzzle is indescribable.",
163
  lines=5
164
  )
165
  analyze_button = gr.Button("Analyze Tone", variant="primary")
166
 
167
  with gr.Column(scale=2):
168
+ # Example Tones Dropdown
169
  tone_dropdown = gr.Dropdown(
170
  choices=sorted(df['label'].unique().tolist()),
171
+ label="Select a tone to view an example below."
172
  )
173
 
174
  with gr.Row():
 
188
  label="Example texts with similar tone"
189
  )
190
 
 
191
  analyze_button.click(
192
  fn=analyze_tone,
193
+ inputs=[text_input, tone_dropdown],
194
  outputs=[result_message, all_probs_output, plot_output, examples_output]
195
  )
196
 
 
197
  tone_dropdown.change(
198
  fn=get_tone_examples,
199
+ inputs=tone_dropdown,
200
  outputs=examples_output
201
  )
202
 
203
+ # Example inputs
204
  examples = [
205
  ["I'm so excited about this new project!"],
206
  ["I'm feeling quite down today and nothing seems to work."],
 
209
  ]
210
  gr.Examples(examples=examples, inputs=text_input)
211
 
212
+ # Main
213
  if __name__ == "__main__":
214
  demo.launch()