douglasgoodwin commited on
Commit
74759ef
·
verified ·
1 Parent(s): 0ef87ae

try ChatGPT edit

Browse files
Files changed (1) hide show
  1. app.py +16 -31
app.py CHANGED
@@ -17,48 +17,36 @@ logger = logging.getLogger(__name__)
17
  logger.info("Initializing emotion classification pipeline...")
18
  classifier = pipeline(
19
  "text-classification",
20
- model="bhadresh-savani/distilbert-base-uncased-emotion",
21
- top_k=None # Return all scores
22
  )
23
  logger.info("Pipeline initialized successfully")
24
 
 
 
 
 
 
 
 
 
25
  def predict_emotion(text):
26
  """Predict emotions from text and return formatted results."""
27
  try:
28
- logger.info(f"Received input text: {text}")
29
-
30
  if not text:
31
  logger.warning("Empty text received")
32
- return gr.BarPlot.update(
33
- value=pd.DataFrame({
34
- 'label': ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise'],
35
- 'score': [0, 0, 0, 0, 0, 0]
36
- })
37
- )
38
 
39
- # Get predictions
40
  logger.info("Running prediction...")
41
  predictions = classifier(text)[0]
42
- logger.info(f"Raw predictions: {predictions}")
43
-
44
- # Sort predictions by score
45
- sorted_predictions = sorted(predictions, key=lambda x: x['score'], reverse=True)
46
-
47
- # Create DataFrame with the correct format
48
- df = pd.DataFrame(sorted_predictions)
49
- df.columns = ['label', 'score'] # Rename columns to match expected format
50
  logger.info(f"Processed scores:\n{df}")
51
 
52
  return gr.BarPlot.update(value=df)
53
 
54
  except Exception as e:
55
  logger.error(f"Error in prediction: {str(e)}")
56
- return gr.BarPlot.update(
57
- value=pd.DataFrame({
58
- 'label': ['error'],
59
- 'score': [1.0]
60
- })
61
- )
62
 
63
  # Create the Gradio interface
64
  demo = gr.Interface(
@@ -69,14 +57,11 @@ demo = gr.Interface(
69
  lines=4
70
  ),
71
  outputs=gr.BarPlot(
72
- value=pd.DataFrame({
73
- 'label': ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise'],
74
- 'score': [0, 0, 0, 0, 0, 0]
75
- }),
76
  x="label",
77
  y="score",
78
  title="Emotion Probabilities",
79
- color="#2563eb",
80
  height=400,
81
  vertical=True
82
  ),
@@ -93,4 +78,4 @@ demo = gr.Interface(
93
  )
94
 
95
  if __name__ == "__main__":
96
- demo.launch(debug=True)
 
17
  logger.info("Initializing emotion classification pipeline...")
18
  classifier = pipeline(
19
  "text-classification",
20
+ model="bhadresh-savani/distilbert-base-uncased-emotion"
 
21
  )
22
  logger.info("Pipeline initialized successfully")
23
 
24
+ # Default DataFrame for BarPlot
25
+ emotion_labels = [pred['label'] for pred in classifier("test")]
26
+ default_scores = [0] * len(emotion_labels)
27
+ default_df = pd.DataFrame({'label': emotion_labels, 'score': default_scores})
28
+
29
+ # Error DataFrame for BarPlot
30
+ error_df = pd.DataFrame({'label': ['error'], 'score': [1.0]})
31
+
32
  def predict_emotion(text):
33
  """Predict emotions from text and return formatted results."""
34
  try:
 
 
35
  if not text:
36
  logger.warning("Empty text received")
37
+ return gr.BarPlot.update(value=default_df)
 
 
 
 
 
38
 
39
+ # Get predictions and create DataFrame
40
  logger.info("Running prediction...")
41
  predictions = classifier(text)[0]
42
+ df = pd.DataFrame(predictions)
 
 
 
 
 
 
 
43
  logger.info(f"Processed scores:\n{df}")
44
 
45
  return gr.BarPlot.update(value=df)
46
 
47
  except Exception as e:
48
  logger.error(f"Error in prediction: {str(e)}")
49
+ return gr.BarPlot.update(value=error_df)
 
 
 
 
 
50
 
51
  # Create the Gradio interface
52
  demo = gr.Interface(
 
57
  lines=4
58
  ),
59
  outputs=gr.BarPlot(
60
+ value=default_df,
 
 
 
61
  x="label",
62
  y="score",
63
  title="Emotion Probabilities",
64
+ color=["#ff6f61", "#6b5b95", "#88b04b", "#f7cac9", "#92a8d1", "#f7786b"], # Dynamic color
65
  height=400,
66
  vertical=True
67
  ),
 
78
  )
79
 
80
  if __name__ == "__main__":
81
+ demo.launch(debug=True)