SamanthaStorm commited on
Commit
a239b18
·
verified ·
1 Parent(s): b3ef4a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -15
app.py CHANGED
@@ -90,7 +90,14 @@ boundary_model.eval()
90
 
91
  @spaces.GPU
92
  def predict_boundary_health(text):
93
- """Predict boundary health for given text - returns 1 for healthy, 0 for unhealthy"""
 
 
 
 
 
 
 
94
  try:
95
  inputs = boundary_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
96
  inputs = {k: v.to(device) for k, v in inputs.items()}
@@ -98,32 +105,50 @@ def predict_boundary_health(text):
98
  outputs = boundary_model(**inputs)
99
  predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
100
 
101
- # Return the actual prediction (0 or 1)
102
  predicted_class = torch.argmax(predictions, dim=-1).item()
103
- return predicted_class # Returns 1 for healthy, 0 for unhealthy
 
104
 
105
  except Exception as e:
106
  logger.error(f"Error in boundary prediction: {e}")
107
- return 0 # Return unhealthy on error
108
 
109
- def get_boundary_assessment(text, prediction):
110
- """Get boundary assessment based on binary prediction"""
111
- if prediction == 1: # healthy
112
  return {
113
  'assessment': 'healthy',
114
- 'label': 'Healthy Boundary',
115
- 'confidence': 1.0,
116
- 'description': 'This communication shows healthy boundary setting',
117
  'recommendations': ['Continue this respectful communication approach']
118
  }
119
- else: # unhealthy (prediction == 0)
 
 
 
 
 
 
 
 
120
  return {
121
  'assessment': 'unhealthy',
122
- 'label': 'Unhealthy Boundary',
123
- 'confidence': 1.0,
124
- 'description': 'Communication shows unhealthy boundary patterns',
125
- 'recommendations': ['Use "I" statements instead of accusations', 'Focus on respectful communication']
126
  }
 
 
 
 
 
 
 
 
 
127
  # Constants and Labels
128
  LABELS = [
129
  "recovery phase", "control", "gaslighting", "guilt tripping", "dismissiveness",
 
90
 
91
  @spaces.GPU
92
  def predict_boundary_health(text):
93
+ """
94
+ Predict boundary health for given text
95
+ Returns:
96
+ - 0 for Respected (healthy)
97
+ - 1 for Violated (unhealthy)
98
+ - 2 for Dismissed (unhealthy)
99
+ - 3 for Manipulative (unhealthy)
100
+ """
101
  try:
102
  inputs = boundary_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
103
  inputs = {k: v.to(device) for k, v in inputs.items()}
 
105
  outputs = boundary_model(**inputs)
106
  predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
107
 
108
+ # Return the actual prediction (0, 1, 2, or 3)
109
  predicted_class = torch.argmax(predictions, dim=-1).item()
110
+ confidence = predictions[0][predicted_class].item()
111
+ return predicted_class, confidence
112
 
113
  except Exception as e:
114
  logger.error(f"Error in boundary prediction: {e}")
115
+ return 1, 1.0 # Return Violated (unhealthy) with full confidence on error
116
 
117
+ def get_boundary_assessment(text, prediction, confidence=1.0):
118
+ """Get boundary assessment based on the 4-category prediction"""
119
+ if prediction == 0: # Respected (healthy)
120
  return {
121
  'assessment': 'healthy',
122
+ 'label': 'Respected Boundary',
123
+ 'confidence': confidence,
124
+ 'description': 'This communication shows healthy boundary setting with mutual respect',
125
  'recommendations': ['Continue this respectful communication approach']
126
  }
127
+ elif prediction == 1: # Violated (unhealthy)
128
+ return {
129
+ 'assessment': 'unhealthy',
130
+ 'label': 'Violated Boundary',
131
+ 'confidence': confidence,
132
+ 'description': 'Communication shows boundary violation patterns',
133
+ 'recommendations': ['Acknowledge the boundary violation', 'Use "I" statements instead of accusations', 'Focus on respectful communication']
134
+ }
135
+ elif prediction == 2: # Dismissed (unhealthy)
136
  return {
137
  'assessment': 'unhealthy',
138
+ 'label': 'Dismissed Boundary',
139
+ 'confidence': confidence,
140
+ 'description': 'Communication shows boundary dismissal patterns',
141
+ 'recommendations': ['Recognize and validate boundaries', 'Avoid minimizing others\' concerns', 'Practice active listening']
142
  }
143
+ else: # Manipulative (unhealthy) - prediction == 3
144
+ return {
145
+ 'assessment': 'unhealthy',
146
+ 'label': 'Manipulative Boundary',
147
+ 'confidence': confidence,
148
+ 'description': 'Communication shows manipulative boundary patterns',
149
+ 'recommendations': ['Avoid manipulation tactics', 'Communicate needs directly', 'Respect others\' autonomy']
150
+ }
151
+
152
  # Constants and Labels
153
  LABELS = [
154
  "recovery phase", "control", "gaslighting", "guilt tripping", "dismissiveness",