Vera-ZWY committed on
Commit
792cbf1
·
verified ·
1 Parent(s): 96c3d05

Upload 2 files

Browse files
Files changed (2) hide show
  1. config.json +11 -0
  2. hydra_model.py +632 -0
config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "HydraForSequenceClassification"
4
+ ],
5
+ "backbone_model_name": "answerdotai/ModernBERT-base",
6
+ "model_type": "hydra",
7
+ "num_of_heads": 7,
8
+ "output_size": 1,
9
+ "torch_dtype": "float32",
10
+ "transformers_version": "4.48.3"
11
+ }
hydra_model.py ADDED
@@ -0,0 +1,632 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import json
import os
from typing import Optional, List, Dict, Union, Tuple

import pandas as pd
import torch
import torch.nn as nn
from huggingface_hub import HfApi
from transformers import PreTrainedModel, PretrainedConfig, AutoModel, AutoTokenizer
from transformers import AutoModelForSequenceClassification
from transformers.modeling_outputs import SequenceClassifierOutput
10
+
11
+
class HydraConfig(PretrainedConfig):
    """Configuration for the multi-headed ("Hydra") sequence classifier.

    Holds the backbone checkpoint name, the number of per-label heads,
    and the label/threshold metadata used to decode logits into labels.

    Args:
        backbone_model_name: Hub id of the shared encoder backbone.
        num_of_heads: Number of independent classification heads.
        hidden_size: Embedding width of the backbone (head input size).
        output_size: Logits emitted by each head (1 per binary label).
        label_dict: Mapping of label name -> logit column index.
        threshold: Sigmoid-probability cut-off used at decode time.
    """

    model_type = "hydra"

    def __init__(
        self,
        backbone_model_name: str = "answerdotai/ModernBERT-base",
        num_of_heads: int = 7,
        hidden_size: int = 768,
        output_size: int = 1,
        label_dict: Dict[str, int] = None,
        threshold: float = 0.5,
        **kwargs
    ):
        # Unknown keys (torch_dtype, architectures, ...) go to the base class.
        super().__init__(**kwargs)
        self.backbone_model_name = backbone_model_name
        self.num_of_heads = num_of_heads
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.threshold = threshold
        # Normalise None to an empty mapping so .items() iteration is always safe.
        self.label_dict = label_dict or {}
34
+
35
+
36
+ # We'll use the standard SequenceClassifierOutput instead of a custom output class
37
+
38
+
class HydraForSequenceClassification(PreTrainedModel):
    """Multi-headed ("Hydra") model for multi-label sequence classification.

    A shared transformer backbone produces token embeddings; masked mean
    pooling yields one sentence vector, which is scored by
    ``config.num_of_heads`` independent MLP heads.  Each head emits
    ``config.output_size`` logits, concatenated into one logits tensor.

    The class is registered so it can be loaded via
    ``AutoModelForSequenceClassification``.
    """

    config_class = HydraConfig
    _auto_class = "AutoModelForSequenceClassification"

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Shared encoder backbone (fetched from the Hub by name).
        self.backbone = AutoModel.from_pretrained(config.backbone_model_name)

        # One small MLP per label.
        self.heads = nn.ModuleList([
            self.get_classifier(config.hidden_size, config.output_size)
            for _ in range(config.num_of_heads)
        ])

        # Run the transformers weight-init / tying hooks.
        self.post_init()

    def weights_init(self, m):
        # Kaiming-uniform initialisation for the classifier linears.
        if isinstance(m, nn.Linear):
            nn.init.kaiming_uniform_(m.weight.data)

    def get_classifier(self, input_size, output_size):
        """Build one head: Linear(in -> in) followed by Linear(in -> out)."""
        mlp = nn.Sequential(
            nn.Linear(in_features=input_size, out_features=input_size, bias=True),
            nn.Linear(in_features=input_size, out_features=output_size, bias=True),
        )
        for module in mlp:
            if isinstance(module, nn.Linear):
                self.weights_init(module)
        return mlp

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the backbone, mean-pool, and score with every head.

        Args:
            labels: Optional multi-hot tensor broadcastable to the logits
                shape ``(batch, num_of_heads * output_size)``; when given,
                a ``BCEWithLogitsLoss`` (independent sigmoid per label) is
                computed.

        Returns:
            ``SequenceClassifierOutput`` (or the equivalent tuple when
            ``return_dict`` is False) with ``loss`` and ``logits``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

        # Masked mean pooling over real (non-padding) tokens.
        token_embeddings = outputs[0]
        if attention_mask is None:
            # No mask supplied: treat every position as a real token.
            attention_mask = torch.ones(
                token_embeddings.shape[:2], device=token_embeddings.device
            )
        mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * mask_expanded, 1)
        # Clamp so an all-padding row cannot divide by zero.
        sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
        mean_embeddings = sum_embeddings / sum_mask

        # One logit column (per output_size) per head.
        head_outputs = [head(mean_embeddings) for head in self.heads]
        logits = torch.cat(head_outputs, dim=-1)

        loss = None
        if labels is not None:
            # Multi-label objective: independent sigmoid per label, which is
            # what get_labels_from_logits assumes at decode time.
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(logits, labels.float().view_as(logits))

        if not return_dict:
            # Standard HF tuple layout: (loss?, logits, *backbone extras).
            output = (logits,) + tuple(outputs[1:])
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=getattr(outputs, "hidden_states", None),
            attentions=getattr(outputs, "attentions", None),
        )

    @classmethod
    def convert_checkpoint_to_hf_model(cls,
                                       checkpoint_path,
                                       backbone_model_name="answerdotai/ModernBERT-base",
                                       label_dict=None,
                                       threshold=0.5,
                                       save_directory=None):
        """Convert a raw state-dict checkpoint into a Hugging Face model.

        Args:
            checkpoint_path: Path to a torch checkpoint containing a state
                dict that matches this architecture.
            backbone_model_name: Hub id of the backbone encoder.
            label_dict: Mapping of label name -> logit index; its length
                determines the number of heads.
            threshold: Classification threshold stored in the config.
            save_directory: When given, the model, tokenizer and label_dict
                are saved there.

        Returns:
            HydraForSequenceClassification: the populated model.
        """
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

        # The backbone's hidden size fixes the heads' input width.
        backbone = AutoModel.from_pretrained(backbone_model_name)
        hidden_size = backbone.config.hidden_size

        config = HydraConfig(
            backbone_model_name=backbone_model_name,
            num_of_heads=len(label_dict) if label_dict else 1,
            hidden_size=hidden_size,
            output_size=1,
            label_dict=label_dict,
            threshold=threshold
        )

        model = cls(config)
        model.load_state_dict(checkpoint)

        if save_directory:
            model.save_pretrained(save_directory)

            # Ship the matching tokenizer alongside the weights.
            tokenizer = AutoTokenizer.from_pretrained(backbone_model_name)
            tokenizer.save_pretrained(save_directory)

            # Persist the label mapping separately for easy inspection.
            if label_dict:
                with open(os.path.join(save_directory, "label_dict.json"), "w") as f:
                    json.dump(label_dict, f)

        return model

    def get_labels_from_logits(self, logits):
        """Decode logits into label-name lists using the config threshold.

        Args:
            logits: Tensor of shape (batch_size, num_labels).

        Returns:
            list[list[str]]: predicted label names for each sample.
        """
        # Independent sigmoid per label (multi-label decoding).
        probabilities = torch.sigmoid(logits)
        predictions = (probabilities >= self.config.threshold).int()

        predicted_labels = []
        for i in range(predictions.shape[0]):
            sample_labels = [
                label for label, idx in self.config.label_dict.items()
                if predictions[i, idx] == 1
            ]

            # Post-processing for the "negative" class of each task family:
            if len(sample_labels) == 0:
                # Nothing fired -> fall back to the task's explicit "none" label.
                for none_label in ["Emotionless", "Not Anxiety", "No Anger", "Not Anger"]:
                    if none_label in self.config.label_dict:
                        sample_labels.append(none_label)
                        break
            elif len(sample_labels) > 1:
                # A positive label fired -> drop the contradictory "none" label.
                for none_label in ["Emotionless", "Not Anxiety", "No Anger", "Not Anger"]:
                    if none_label in sample_labels:
                        sample_labels.remove(none_label)
                        break

            predicted_labels.append(sample_labels)

        return predicted_labels
242
+
243
+
# Register the Hydra config/model with the Auto* factories so that
# AutoModelForSequenceClassification.from_pretrained(...) can resolve
# model_type == "hydra".  AutoConfig.register / <AutoClass>.register is the
# public, supported registration API; the CONFIG_MAPPING /
# MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING objects are private internals
# that can change between transformers releases.
from transformers import AutoConfig

AutoConfig.register("hydra", HydraConfig)
AutoModelForSequenceClassification.register(HydraConfig, HydraForSequenceClassification)
251
+
252
+
def _push_one_hydra_model(api, checkpoint_path, label_dict, threshold,
                          target_repo, title, description, access_token):
    """Convert one checkpoint, write its model card, and upload it to the Hub.

    The temporary working directory is always removed, even on failure.
    """
    import tempfile
    import shutil

    work_dir = tempfile.mkdtemp()
    try:
        HydraForSequenceClassification.convert_checkpoint_to_hf_model(
            checkpoint_path,
            label_dict=label_dict,
            threshold=threshold,
            save_directory=work_dir,
        )

        # Model card.  The braces inside the usage example are doubled
        # ({{ }}) so they survive f-string formatting literally instead of
        # being evaluated here (evaluating them would raise NameError).
        readme = f"""# {title}

{description}

## Model Details

- **Model Type:** Hydra (Multi-headed classification model)
- **Backbone:** ModernBERT
- **Labels:** {list(label_dict.keys())}
- **Threshold:** {threshold}

## Usage

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("{target_repo}")
model = AutoModelForSequenceClassification.from_pretrained("{target_repo}")

# Preprocess text
text = "I'm feeling really happy today!"
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

# Get predictions
outputs = model(**inputs)
logits = outputs.logits

# Get labels (using the helper function)
predicted_labels = model.get_labels_from_logits(logits)
print(f"Predicted labels: {{', '.join(predicted_labels[0])}}")
```

## License

This model is available for research and commercial use.
"""
        with open(os.path.join(work_dir, "README.md"), "w") as f:
            f.write(readme)

        # Push to the Hub (create_repo also needs the token for private hubs).
        api.create_repo(target_repo, exist_ok=True, token=access_token)
        api.upload_folder(
            folder_path=work_dir,
            repo_id=target_repo,
            token=access_token,
        )
    finally:
        shutil.rmtree(work_dir)


def convert_and_push_models_to_hub(
    repo_id,
    ekman_filename,
    anxiety_filename,
    staxi_filename,
    anger_filename,
    access_token
):
    """Convert all four checkpoint models and push them to the Hub.

    Args:
        repo_id: Source repo ("username/repo-name") holding the raw checkpoints.
        ekman_filename / anxiety_filename / staxi_filename / anger_filename:
            Checkpoint filenames inside the source repo.
        access_token: Hub token with write access.

    Returns:
        dict: Mapping of model kind -> destination repo id.
    """
    # Label name -> logit index, one dict per task family.
    ekman_label_dict = {
        "Anger": 0, "Disgust": 1, "Fear": 2, "Happiness": 3,
        "Sadness": 4, "Surprise": 5, "Emotionless": 6
    }

    anxiety_label_dict = {
        "GAD": 0, "Panic Disorder": 1, "Social Anxiety Disorder": 2,
        "Specific Phobias": 3, "Agoraphobia": 4, "Separation Anxiety Disorder": 5,
        "Selective Mutism": 6, "Not Anxiety": 7
    }

    staxi_label_dict = {
        "State Anger": 0, "Trait Anger": 1, "Anger Expression-Out": 2,
        "Anger Expression-In": 3, "Anger Control-Out": 4, "Anger Control-In": 5,
        "No Anger": 6
    }

    anger_label_dict = {
        "Passive Anger": 0, "Volatile Anger": 1, "Fear-Based Anger": 2,
        "Frustration-Based Anger": 3, "Pain-Based Anger": 4, "Chronic Anger": 5,
        "Manipulative Anger": 6, "Overwhelmed Anger": 7, "Physiological Anger": 8,
        "Righteous Anger": 9, "Not Anger": 10
    }

    # Download the raw checkpoints from the source repo.
    from huggingface_hub import hf_hub_download
    ekman_path = hf_hub_download(repo_id=repo_id, filename=ekman_filename, token=access_token)
    anxiety_path = hf_hub_download(repo_id=repo_id, filename=anxiety_filename, token=access_token)
    staxi_path = hf_hub_download(repo_id=repo_id, filename=staxi_filename, token=access_token)
    anger_path = hf_hub_download(repo_id=repo_id, filename=anger_filename, token=access_token)

    # Destination repos live under the same user as the source repo.
    username = repo_id.split('/')[0]
    ekman_repo = f"{username}/hydra-ekman-emotions"
    anxiety_repo = f"{username}/hydra-anxiety-disorders"
    staxi_repo = f"{username}/hydra-staxi-anger"
    anger_repo = f"{username}/hydra-anger-types"

    api = HfApi()

    # (checkpoint, labels, threshold, repo, title, description) per model.
    jobs = [
        (ekman_path, ekman_label_dict, 0.5, ekman_repo,
         "Hydra Ekman Emotions Model",
         'This model identifies Ekman\'s 6 basic emotions plus "Emotionless" in text.'),
        (anxiety_path, anxiety_label_dict, 0.4, anxiety_repo,
         "Hydra Anxiety Disorders Model",
         "This model identifies different types of anxiety disorders in text."),
        (staxi_path, staxi_label_dict, 0.4, staxi_repo,
         "Hydra STAXI Anger Model",
         "This model identifies different types of anger based on the STAXI framework."),
        (anger_path, anger_label_dict, 0.4, anger_repo,
         "Hydra Anger Types Model",
         "This model identifies different types of anger expressions in text."),
    ]
    for checkpoint_path, label_dict, threshold, target_repo, title, description in jobs:
        _push_one_hydra_model(
            api, checkpoint_path, label_dict, threshold,
            target_repo, title, description, access_token,
        )

    # Return the destination repo names for reference.
    return {
        "ekman_model": ekman_repo,
        "anxiety_model": anxiety_repo,
        "staxi_model": staxi_repo,
        "anger_model": anger_repo
    }
537
+
538
+
# Example helper function to use with the standard Hugging Face models
def classify_text(model_name, text, max_length=1024):
    """Classify text using the standard Hugging Face model-loading pattern.

    Args:
        model_name: Name of the model on the Hugging Face Hub.
        text: Text to classify.
        max_length: Tokenizer truncation/padding length (default 1024,
            matching the original hard-coded value).

    Returns:
        list: Predicted labels for the input text.
    """
    # Load model and tokenizer through the Auto classes.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)

    # Preprocess the input text.
    encoded_input = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    )

    # Inference only: eval mode, no gradients.
    model.eval()
    with torch.no_grad():
        outputs = model(**encoded_input)
        logits = outputs.logits

    # Decode logits with the model's threshold/label metadata.
    predicted_labels = model.get_labels_from_logits(logits)

    return predicted_labels[0]  # Return first sample's labels
576
+
577
+
# Example of how to process a batch using standard HF patterns
def process_dataframe(df, model_name, text_column1, text_column2=None, max_length=1024):
    """Classify every row of a DataFrame with a Hugging Face model.

    Requires the module-level ``import pandas as pd`` (used for null checks).

    Args:
        df: DataFrame to process.
        model_name: Name of the model on the Hugging Face Hub.
        text_column1: Name of the first text column.
        text_column2: Name of the second text column (optional; joined to
            the first with " [SEP] ").
        max_length: Tokenizer truncation/padding length (default 1024).

    Returns:
        list: Comma-joined label string per row, or None for skipped rows
        (missing values or "[removed]"/"[deleted]" placeholders).
    """
    # Load model and tokenizer once, up front.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    model.eval()  # inference only — set once, not per row

    results = []

    for _, row in df.iterrows():
        # Skip rows with missing values.
        if pd.isnull(row[text_column1]) or (text_column2 and pd.isnull(row[text_column2])):
            results.append(None)
            continue

        # Skip moderated/removed content placeholders.
        if text_column2 and row[text_column2] in ["[removed]", "[deleted]"]:
            results.append(None)
            continue

        # Prepare the text input.
        if text_column2:
            text = f"{row[text_column1]} [SEP] {row[text_column2]}"
        else:
            text = row[text_column1]

        encoded_input = tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        )

        # Run inference without gradients and decode to label names.
        with torch.no_grad():
            outputs = model(**encoded_input)
            logits = outputs.logits
            predicted_labels = model.get_labels_from_logits(logits)
            results.append(", ".join(predicted_labels[0]))

    return results