douglasgoodwin commited on
Commit
5ac858c
·
verified ·
1 Parent(s): 6962b0d

dump the bar plot

Browse files
Files changed (1) hide show
  1. app.py +9 -32
app.py CHANGED
@@ -21,41 +21,26 @@ classifier = pipeline(
21
  )
22
  logger.info("Pipeline initialized successfully")
23
 
24
- # Default DataFrame for BarPlot
25
- emotion_labels = [pred['label'] for pred in classifier("test input")]
26
- default_scores = [0] * len(emotion_labels)
27
- default_df = pd.DataFrame({'label': emotion_labels, 'score': default_scores})
28
-
29
- # Error DataFrame for BarPlot
30
- error_df = pd.DataFrame({'label': ['error'], 'score': [1.0]})
31
-
32
  def predict_emotion(text):
33
  """Predict emotions from text and return formatted results."""
34
  try:
35
  if not text:
36
  logger.warning("Empty text received")
37
- # Return default DataFrame directly
38
- return default_df
39
-
40
  # Get predictions and handle result structure
41
  logger.info("Running prediction...")
42
  predictions = classifier(text)
43
  logger.debug(f"Raw predictions output: {predictions}")
44
 
45
- # Ensure predictions are a list of dictionaries
46
- if isinstance(predictions, list) and len(predictions) > 0 and all(isinstance(item, dict) for item in predictions):
47
- df = pd.DataFrame(predictions) # Convert list of dictionaries to DataFrame
48
- else:
49
- logger.error(f"Unexpected predictions format: {predictions}")
50
- return error_df
51
-
52
- logger.info(f"Processed scores:\n{df}")
53
- return df
54
 
55
  except Exception as e:
56
  logger.error(f"Error in prediction: {str(e)}")
57
- # Return error DataFrame directly
58
- return error_df
59
 
60
  # Create the Gradio interface
61
  demo = gr.Interface(
@@ -65,16 +50,8 @@ demo = gr.Interface(
65
  label="Input Text",
66
  lines=4
67
  ),
68
- outputs=gr.BarPlot(
69
- value=default_df,
70
- x="label",
71
- y="score",
72
- title="Emotion Probabilities",
73
- color=["#ff6f61", "#6b5b95", "#88b04b", "#f7cac9", "#92a8d1", "#f7786b"], # Dynamic colors
74
- height=400,
75
- vertical=True
76
- ),
77
- title="Emotion Detection with DistilBERT",
78
  description="This app uses the DistilBERT model fine-tuned for emotion detection. Enter any text to analyze its emotional content.",
79
  examples=[
80
  "I am so happy to see you!",
 
21
  )
22
  logger.info("Pipeline initialized successfully")
23
 
 
 
 
 
 
 
 
 
24
  def predict_emotion(text):
25
  """Predict emotions from text and return formatted results."""
26
  try:
27
  if not text:
28
  logger.warning("Empty text received")
29
+ return {}
30
+
 
31
  # Get predictions and handle result structure
32
  logger.info("Running prediction...")
33
  predictions = classifier(text)
34
  logger.debug(f"Raw predictions output: {predictions}")
35
 
36
+ # Process predictions into a dict for display
37
+ scores = {item['label']: item['score'] for item in predictions}
38
+ logger.info(f"Processed scores: {scores}")
39
+ return scores
 
 
 
 
 
40
 
41
  except Exception as e:
42
  logger.error(f"Error in prediction: {str(e)}")
43
+ return {"error": "An error occurred during emotion prediction"}
 
44
 
45
  # Create the Gradio interface
46
  demo = gr.Interface(
 
50
  label="Input Text",
51
  lines=4
52
  ),
53
+ outputs=gr.JSON(), # Display the scores in a JSON format
54
+ title="CREATIVE MACHINES: Emotion Detection with DistilBERT",
 
 
 
 
 
 
 
 
55
  description="This app uses the DistilBERT model fine-tuned for emotion detection. Enter any text to analyze its emotional content.",
56
  examples=[
57
  "I am so happy to see you!",