import torch
import logging
import re
from typing import Dict, List, Any
from transformers import pipeline, AutoModelForQuestionAnswering, AutoTokenizer

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
def __init__(self, path=""):
"""
Initialize the RECCON emotional trigger extraction model using native transformers.
Args:
path: Path to model directory (provided by HuggingFace Inference Endpoints)
"""
logger.info("Initializing RECCON Trigger Extraction endpoint...")
# Detect device (CUDA/CPU)
cuda_available = torch.cuda.is_available()
if not cuda_available:
logger.warning("GPU not detected. Running on CPU. Inference will be slower.")
# In 'pipeline', device is an integer (-1 for CPU, 0+ for GPU)
self.device_id = 0 if cuda_available else -1
# Determine model path
model_path = path if path and path != "." else "."
logger.info(f"Loading model from {model_path}...")
        try:
            # Load tokenizer and model explicitly to ensure correct loading
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model, loading_info = AutoModelForQuestionAnswering.from_pretrained(
                model_path,
                output_loading_info=True
            )
            logger.warning("RECCON load info - missing_keys: %s", loading_info.get("missing_keys"))
            logger.warning("RECCON load info - unexpected_keys: %s", loading_info.get("unexpected_keys"))
            logger.warning("RECCON load info - error_msgs: %s", loading_info.get("error_msgs"))
            logger.warning("Loaded model class: %s", model.__class__.__name__)
            logger.warning("Loaded model name_or_path: %s", getattr(model.config, "_name_or_path", None))
            # Initialize the pipeline
            # top_k=20 mirrors the earlier n_best_size=20 setting
            self.pipe = pipeline(
                "question-answering",
                model=model,
                tokenizer=tokenizer,
                device=self.device_id,
                top_k=20,
                handle_impossible_answer=False
            )
            logger.info("Model loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise
        # Question template (must match training)
        self.question_template = (
            "Extract the exact short phrase (<= 8 words) from the target "
            "utterance that most strongly signals the emotion {emotion}. "
            "Return only a substring of the target utterance."
        )
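        # Illustrative formatted question (example output of the template
        # above) for emotion="anger":
        #   "Extract the exact short phrase (<= 8 words) from the target
        #   utterance that most strongly signals the emotion anger. Return
        #   only a substring of the target utterance."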
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process an inference request.
        """
        # Extract inputs
        inputs = data.pop("inputs", data)

        # Normalize to list format
        if isinstance(inputs, dict):
            inputs = [inputs]
        if not inputs:
            return [{"error": "No inputs provided", "triggers": []}]

        # Validate and format inputs for the pipeline
        pipeline_inputs = []
        valid_indices = []
        for i, item in enumerate(inputs):
            utterance = item.get("utterance", "").strip()
            emotion = item.get("emotion", "")
            if not utterance:
                logger.warning(f"Empty utterance at index {i}")
                continue
            # Format as a QA task
            question = self.question_template.format(emotion=emotion)
            # The pipeline expects a list of dicts with 'question' and 'context'
            pipeline_inputs.append({
                'question': question,
                'context': utterance
            })
            valid_indices.append(i)
        # Run prediction
        results = []
        if not pipeline_inputs:
            # All inputs were invalid
            for item in inputs:
                results.append({
                    "utterance": item.get("utterance", ""),
                    "emotion": item.get("emotion", ""),
                    "error": "Missing or empty utterance",
                    "triggers": []
                })
            return results
        try:
            # Run inference (batch_size helps with multiple inputs)
            predictions = self.pipe(pipeline_inputs, batch_size=8)

            # Normalize the pipeline output to one list of predictions per
            # input. With top_k > 1 the pipeline returns a flat list of dicts
            # for a single input and a list of lists for multiple inputs, so
            # wrap the single-input case.
            if isinstance(predictions, dict):
                predictions = [predictions]
            elif isinstance(predictions, list) and len(predictions) > 0 and isinstance(predictions[0], dict):
                if len(pipeline_inputs) == 1:
                    predictions = [predictions]
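            # Example shapes (illustrative, based on the transformers QA
            # pipeline): each prediction dict looks like
            #   {'answer': 'so angry', 'score': 0.91, 'start': 5, 'end': 13}
            # and with top_k=20 each input maps to a list of up to 20 of them.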
logger.debug(f"Raw predictions: {predictions}")
# Post-process results
pred_idx = 0
            for i, item in enumerate(inputs):
                utterance = item.get("utterance", "").strip()
                emotion = item.get("emotion", "")
                if i not in valid_indices:
                    results.append({
                        "utterance": utterance,
                        "emotion": emotion,
                        "error": "Missing or empty utterance",
                        "triggers": []
                    })
                else:
                    # Because top_k=20, 'current_preds' is a list of dicts:
                    # [{'answer': '...', 'score': ...}, ...]
                    current_preds = predictions[pred_idx]
                    # Ensure it is a list
                    if isinstance(current_preds, dict):
                        current_preds = [current_preds]
                    logger.info(
                        "RECCON raw spans (answer, score): %s",
                        [(p.get("answer"), round(p.get("score", 0.0), 3)) for p in current_preds[:5]]
                    )
                    def is_good_span(ans: str) -> bool:
                        if not ans:
                            return False
                        a = ans.strip()
                        if len(a) < 3:
                            return False
                        # Reject pure punctuation
                        if all(ch in ".,!?;:-—'\"()[]{}" for ch in a):
                            return False
                        # Require at least one letter
                        if not any(ch.isalpha() for ch in a):
                            return False
                        return True
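                    # Illustrative filtering (assumed examples):
                    #   is_good_span("!!")    -> False  (too short)
                    #   is_good_span("12:30") -> False  (no letters)
                    #   is_good_span("upset") -> True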
                    raw_answers = [p.get("answer", "") for p in current_preds]
                    raw_answers = [a for a in raw_answers if is_good_span(a)]
                    triggers = self._clean_spans(raw_answers, utterance)
                    results.append({
                        "utterance": utterance,
                        "emotion": emotion,
                        "triggers": triggers
                    })
                    pred_idx += 1

            logger.debug(f"Cleaned results: {results}")
            return results
        except Exception as e:
            logger.error(f"Model prediction failed: {e}")
            return [{
                "utterance": item.get("utterance", ""),
                "emotion": item.get("emotion", ""),
                "error": str(e),
                "triggers": []
            } for item in inputs]
    def _clean_spans(self, spans: List[str], target_text: str) -> List[str]:
        """
        Clean and filter extracted trigger spans.
        """
        target_text = target_text or ""
        target_lower = target_text.lower()

        def _norm(s: str) -> str:
            # Lowercase, collapse whitespace, and trim leading/trailing punctuation
            s = (s or "").strip().lower()
            s = re.sub(r"\s+", " ", s)
            s = re.sub(r"^[^\w]+|[^\w]+$", "", s)
            return s

        def _extract_from_target(target: str, phrase_lower: str) -> str:
            # Recover the original-cased substring of the target, if present
            idx = target.lower().find(phrase_lower)
            if idx >= 0:
                return target[idx:idx + len(phrase_lower)]
            return phrase_lower
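        # Illustrative helper behavior (assumed examples):
        #   _norm("  So   ANGRY!! ")                          -> "so angry"
        #   _extract_from_target("I am SO angry", "so angry") -> "SO angry"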
        STOP = {
            "a", "an", "the", "and", "or", "but", "so", "to", "of", "in", "on", "at",
            "with", "for", "from", "is", "am", "are", "was", "were", "be", "been",
            "being", "i", "you", "he", "she", "it", "we", "they", "my", "your", "his",
            "her", "their", "our", "me", "him", "them", "this", "that", "these",
            "those"
        }
        candidates = []
        for s in spans:
            s = (s or "").strip()
            if not s:
                continue
            s_norm = _norm(s)
            if not s_norm:
                continue
            if target_text and s_norm not in target_lower:
                continue
            tokens = s_norm.split()
            if len(tokens) > 8 or len(s_norm) > 80:
                continue
            if len(tokens) == 1 and (tokens[0] in STOP or len(tokens[0]) <= 2):
                continue
            candidates.append({
                "norm": s_norm,
                "tokens": tokens,
                "tok_len": len(tokens),
                "char_len": len(s_norm)
            })

        # Prioritize short, focused emotional keywords (1-3 words)
        short_candidates = [c for c in candidates if 1 <= c["tok_len"] <= 3]
        if short_candidates:
            candidates = short_candidates

        # Sort shortest spans first (most focused keywords)
        candidates.sort(key=lambda x: (x["tok_len"], x["char_len"]))
        kept_norms = []
        for c in candidates:
            n = c["norm"]
            # Skip spans that duplicate or overlap (substring either way) a kept span
            if any(n in kn or kn in n for kn in kept_norms):
                continue
            kept_norms.append(n)
        cleaned = [_extract_from_target(target_text, n) for n in kept_norms]

        if not cleaned and spans:
            # Fallback: return the longest word sequence from any raw span
            # that also appears verbatim in the target utterance
            tt_tokens = target_lower.split()
            best = None
            for s in spans:
                words = [w for w in (s or '').lower().strip().split() if w]
                for L in range(min(8, len(words)), 0, -1):
                    for i in range(len(words) - L + 1):
                        phrase = words[i:i + L]
                        for j in range(len(tt_tokens) - L + 1):
                            if tt_tokens[j:j + L] == phrase:
                                best = " ".join(phrase)
                                break
                        if best:
                            break
                    if best:
                        break
                if best:
                    return [_extract_from_target(target_text, best)]

        return cleaned[:3]
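

if __name__ == "__main__":
    # Minimal local smoke test: a sketch, not part of the Inference Endpoints
    # contract. It assumes a fine-tuned QA model and tokenizer are present in
    # the current directory; the payload keys are this handler's own convention.
    handler = EndpointHandler(path=".")
    sample = {
        "inputs": [
            {"utterance": "I can't believe you lied to me again!", "emotion": "anger"}
        ]
    }
    for result in handler(sample):
        print(result.get("emotion"), "->", result.get("triggers"))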