Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -90,7 +90,14 @@ boundary_model.eval()
|
|
| 90 |
|
| 91 |
@spaces.GPU
|
| 92 |
def predict_boundary_health(text):
|
| 93 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
try:
|
| 95 |
inputs = boundary_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
| 96 |
inputs = {k: v.to(device) for k, v in inputs.items()}
|
|
@@ -98,32 +105,50 @@ def predict_boundary_health(text):
|
|
| 98 |
outputs = boundary_model(**inputs)
|
| 99 |
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
|
| 100 |
|
| 101 |
-
# Return the actual prediction (0 or
|
| 102 |
predicted_class = torch.argmax(predictions, dim=-1).item()
|
| 103 |
-
|
|
|
|
| 104 |
|
| 105 |
except Exception as e:
|
| 106 |
logger.error(f"Error in boundary prediction: {e}")
|
| 107 |
-
return 0 # Return unhealthy on error
|
| 108 |
|
| 109 |
-
def get_boundary_assessment(text, prediction):
|
| 110 |
-
"""Get boundary assessment based on
|
| 111 |
-
if prediction ==
|
| 112 |
return {
|
| 113 |
'assessment': 'healthy',
|
| 114 |
-
'label': '
|
| 115 |
-
'confidence':
|
| 116 |
-
'description': 'This communication shows healthy boundary setting',
|
| 117 |
'recommendations': ['Continue this respectful communication approach']
|
| 118 |
}
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
return {
|
| 121 |
'assessment': 'unhealthy',
|
| 122 |
-
'label': '
|
| 123 |
-
'confidence':
|
| 124 |
-
'description': 'Communication shows
|
| 125 |
-
'recommendations': ['
|
| 126 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
# Constants and Labels
|
| 128 |
LABELS = [
|
| 129 |
"recovery phase", "control", "gaslighting", "guilt tripping", "dismissiveness",
|
|
|
|
| 90 |
|
| 91 |
@spaces.GPU
|
| 92 |
def predict_boundary_health(text):
|
| 93 |
+
"""
|
| 94 |
+
Predict boundary health for given text
|
| 95 |
+
Returns:
|
| 96 |
+
- 0 for Respected (healthy)
|
| 97 |
+
- 1 for Violated (unhealthy)
|
| 98 |
+
- 2 for Dismissed (unhealthy)
|
| 99 |
+
- 3 for Manipulative (unhealthy)
|
| 100 |
+
"""
|
| 101 |
try:
|
| 102 |
inputs = boundary_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
| 103 |
inputs = {k: v.to(device) for k, v in inputs.items()}
|
|
|
|
| 105 |
outputs = boundary_model(**inputs)
|
| 106 |
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
|
| 107 |
|
| 108 |
+
# Return the actual prediction (0, 1, 2, or 3)
|
| 109 |
predicted_class = torch.argmax(predictions, dim=-1).item()
|
| 110 |
+
confidence = predictions[0][predicted_class].item()
|
| 111 |
+
return predicted_class, confidence
|
| 112 |
|
| 113 |
except Exception as e:
|
| 114 |
logger.error(f"Error in boundary prediction: {e}")
|
| 115 |
+
return 1, 1.0 # Return Violated (unhealthy) with full confidence on error
|
| 116 |
|
| 117 |
+
def get_boundary_assessment(text, prediction, confidence=1.0):
|
| 118 |
+
"""Get boundary assessment based on the 4-category prediction"""
|
| 119 |
+
if prediction == 0: # Respected (healthy)
|
| 120 |
return {
|
| 121 |
'assessment': 'healthy',
|
| 122 |
+
'label': 'Respected Boundary',
|
| 123 |
+
'confidence': confidence,
|
| 124 |
+
'description': 'This communication shows healthy boundary setting with mutual respect',
|
| 125 |
'recommendations': ['Continue this respectful communication approach']
|
| 126 |
}
|
| 127 |
+
elif prediction == 1: # Violated (unhealthy)
|
| 128 |
+
return {
|
| 129 |
+
'assessment': 'unhealthy',
|
| 130 |
+
'label': 'Violated Boundary',
|
| 131 |
+
'confidence': confidence,
|
| 132 |
+
'description': 'Communication shows boundary violation patterns',
|
| 133 |
+
'recommendations': ['Acknowledge the boundary violation', 'Use "I" statements instead of accusations', 'Focus on respectful communication']
|
| 134 |
+
}
|
| 135 |
+
elif prediction == 2: # Dismissed (unhealthy)
|
| 136 |
return {
|
| 137 |
'assessment': 'unhealthy',
|
| 138 |
+
'label': 'Dismissed Boundary',
|
| 139 |
+
'confidence': confidence,
|
| 140 |
+
'description': 'Communication shows boundary dismissal patterns',
|
| 141 |
+
'recommendations': ['Recognize and validate boundaries', 'Avoid minimizing others\' concerns', 'Practice active listening']
|
| 142 |
}
|
| 143 |
+
else: # Manipulative (unhealthy) - prediction == 3
|
| 144 |
+
return {
|
| 145 |
+
'assessment': 'unhealthy',
|
| 146 |
+
'label': 'Manipulative Boundary',
|
| 147 |
+
'confidence': confidence,
|
| 148 |
+
'description': 'Communication shows manipulative boundary patterns',
|
| 149 |
+
'recommendations': ['Avoid manipulation tactics', 'Communicate needs directly', 'Respect others\' autonomy']
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
# Constants and Labels
|
| 153 |
LABELS = [
|
| 154 |
"recovery phase", "control", "gaslighting", "guilt tripping", "dismissiveness",
|