Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -235,30 +235,38 @@ def classify_emotion(text, classifier):
|
|
| 235 |
|
| 236 |
def get_embedding_for_text(text, tokenizer, model):
|
| 237 |
"""Get embedding for complete text."""
|
| 238 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
chunk_embeddings = []
|
| 240 |
|
| 241 |
-
for
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
truncation=True,
|
| 247 |
-
max_length
|
| 248 |
)
|
| 249 |
-
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
| 250 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
with torch.no_grad():
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
embedding = outputs[0][:, 0, :].cpu().numpy()
|
| 255 |
chunk_embeddings.append(embedding[0])
|
| 256 |
|
|
|
|
| 257 |
if chunk_embeddings:
|
| 258 |
-
|
| 259 |
-
weights = weights / weights.sum()
|
| 260 |
-
weighted_embedding = np.average(chunk_embeddings, axis=0, weights=weights)
|
| 261 |
-
return weighted_embedding
|
| 262 |
return np.zeros(model.config.hidden_size)
|
| 263 |
|
| 264 |
def format_topics(topic_model, topic_counts):
|
|
|
|
def get_embedding_for_text(text, tokenizer, model):
    """Return a single fixed-size embedding vector for *text* of any length.

    The text is split into chunks that fit the model's 512-token window,
    each chunk is embedded via the first ([CLS]) position of the model's
    last hidden state, and the per-chunk embeddings are averaged.

    Args:
        text: Input string; may exceed the model's maximum sequence length.
        tokenizer: HuggingFace-style tokenizer exposing ``tokenize``,
            ``convert_tokens_to_string`` and the encoding ``__call__``.
        model: Transformer encoder whose ``output[0]`` is the last hidden
            state with shape (batch, seq_len, hidden_size) —
            TODO confirm against the model actually loaded by this app.

    Returns:
        A 1-D numpy array of length ``model.config.hidden_size``; all zeros
        when the text tokenizes to nothing (e.g. the empty string).
    """
    # Tokenize once up front so chunk boundaries are exact token counts.
    tokens = tokenizer.tokenize(text)

    # 512-token window minus the two special tokens ([CLS]/[SEP]) that the
    # encoding step below adds back.
    chunk_size = 510
    chunk_embeddings = []

    for start in range(0, len(tokens), chunk_size):
        chunk = tokens[start:start + chunk_size]
        # Round-trip tokens back to text so the tokenizer re-inserts its
        # special tokens itself.  NOTE(review): re-tokenizing the joined
        # string can merge/split tokens for some tokenizers; the
        # truncation=True below guards against any resulting overflow.
        chunk_text = tokenizer.convert_tokens_to_string(chunk)
        encoded = tokenizer(
            chunk_text,
            return_tensors='pt',
            max_length=512,
            truncation=True,
            padding='max_length'
        )

        # Move every tensor to the same device as the model weights.
        encoded = {k: v.to(model.device) for k, v in encoded.items()}

        # Inference only — no gradient bookkeeping needed.
        with torch.no_grad():
            output = model(**encoded)
            # [CLS] embedding: position 0 of the last hidden state.
            embedding = output[0][:, 0, :].cpu().numpy()

        chunk_embeddings.append(embedding[0])

    # Average all chunk embeddings into one document-level vector.
    if chunk_embeddings:
        return np.mean(chunk_embeddings, axis=0)
    # Empty input: fall back to a zero vector of the correct width.
    return np.zeros(model.config.hidden_size)
| 272 |
def format_topics(topic_model, topic_counts):
|