gregmerritt commited on
Commit
e66a248
Β·
verified Β·
1 Parent(s): 3d7fdee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -3
app.py CHANGED
@@ -80,16 +80,49 @@ templates = {
80
  }
81
 
82
  # Prediction logic
 
 
 
 
 
 
 
 
 
 
 
83
  def predict_emotion_and_icebreaker(image, tone):
84
  image = Image.fromarray(image).convert("RGB")
85
  image = transform(image).unsqueeze(0).to(device)
86
-
87
  with torch.no_grad():
88
  output = model(image)
89
  pred = output.argmax(dim=1).item()
90
-
91
  emotion = classes[pred]
92
- prompt = templates[tone][emotion]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  # Generate icebreaker with better sampling
95
  generated = gen(
 
80
  }
81
 
82
  # Prediction logic
83
+ # Emotion to Emoji mapping
84
+ emotion_emojis = {
85
+ "Angry": "😠",
86
+ "Disgust": "🀒",
87
+ "Fear": "😨",
88
+ "Happy": "πŸ˜„",
89
+ "Sad": "😒",
90
+ "Surprise": "😲",
91
+ "Neutral": "😐"
92
+ }
93
+
94
  def predict_emotion_and_icebreaker(image, tone):
95
  image = Image.fromarray(image).convert("RGB")
96
  image = transform(image).unsqueeze(0).to(device)
97
+
98
  with torch.no_grad():
99
  output = model(image)
100
  pred = output.argmax(dim=1).item()
101
+
102
  emotion = classes[pred]
103
+ emoji = emotion_emojis.get(emotion, "πŸ™‚")
104
+
105
+ # Refined prompt for short, relevant output with emoji
106
+ prompt = templates[tone][emotion] + " Respond in one short sentence with an emoji."
107
+
108
+ result = gen(
109
+ prompt,
110
+ max_length=30,
111
+ num_return_sequences=1,
112
+ do_sample=True,
113
+ temperature=0.8,
114
+ top_p=0.9,
115
+ pad_token_id=50256
116
+ )[0]['generated_text']
117
+
118
+ # Clean generated text by stripping prompt
119
+ response = result[len(prompt):].strip()
120
+
121
+ # Ensure minimal valid output
122
+ if not response or len(response) < 5:
123
+ response = "Let's get the conversation going! πŸŽ‰"
124
+
125
+ return f"🧠 Emotion Detected: {emotion} {emoji}\nπŸ’¬ Icebreaker ({tone}):\n{response}"
126
 
127
  # Generate icebreaker with better sampling
128
  generated = gen(