jl committed on
Commit
e85f4ac
·
1 Parent(s): 95c1631

fix: apply batched processing

Browse files
Files changed (2) hide show
  1. src/app.py +2 -2
  2. src/hatespeech_model.py +148 -0
src/app.py CHANGED
@@ -2,7 +2,7 @@ import gc
2
  import re
3
 
4
  import streamlit as st
5
- from hatespeech_model import predict_hatespeech, load_model_from_hf, predict_hatespeech_from_file, get_rationale_from_mistral, preprocess_rationale_mistral
6
  import plotly.graph_objects as go
7
  import plotly.express as px
8
  import pandas as pd
@@ -413,7 +413,7 @@ if classify_button:
413
  # Run both models on the file
414
  # base_result = predict_hatespeech_from_file(...) # Base model
415
  # enhanced_result = predict_hatespeech_from_file(...) # Enhanced model
416
- enhanced_result = predict_hatespeech_from_file(
417
  text_list=file_content['text'].tolist(),
418
  rationale_list=file_content['CF_Rationales'].tolist(),
419
  true_label=file_content['label'].tolist(),
 
2
  import re
3
 
4
  import streamlit as st
5
+ from hatespeech_model import predict_hatespeech, load_model_from_hf, predict_hatespeech_from_file, get_rationale_from_mistral, preprocess_rationale_mistral, predict_hatespeech_from_file_batched
6
  import plotly.graph_objects as go
7
  import plotly.express as px
8
  import pandas as pd
 
413
  # Run both models on the file
414
  # base_result = predict_hatespeech_from_file(...) # Base model
415
  # enhanced_result = predict_hatespeech_from_file(...) # Enhanced model
416
+ enhanced_result = predict_hatespeech_from_file_batched(
417
  text_list=file_content['text'].tolist(),
418
  rationale_list=file_content['CF_Rationales'].tolist(),
419
  true_label=file_content['label'].tolist(),
src/hatespeech_model.py CHANGED
@@ -624,6 +624,154 @@ def predict_hatespeech_from_file(
624
  'runtime': runtime,
625
  'all_probabilities': all_probs.tolist()
626
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
627
 
628
  def predict_hatespeech(text, rationale, model, tokenizer_hatebert, tokenizer_rationale, config, device, model_type="altered"):
629
 
 
624
  'runtime': runtime,
625
  'all_probabilities': all_probs.tolist()
626
  }
627
+
628
def predict_hatespeech_from_file_batched(
    text_list,
    rationale_list,
    true_label,
    model,
    tokenizer_hatebert,
    tokenizer_rationale,
    config,
    device,
    model_type="altered",
    batch_size=16
):
    """Run batched hate-speech inference over a list of texts and report metrics.

    Args:
        text_list: List of input texts to classify. Must be non-empty.
        rationale_list: Rationale strings aligned with ``text_list``; a falsy
            entry (None/empty) falls back to the corresponding text during
            tokenization.
        true_label: Ground-truth labels aligned with ``text_list``.
        model: The classifier. For ``model_type == "base"`` its forward call is
            expected to return logits directly; otherwise a 4-tuple whose first
            element is the logits tensor.
        tokenizer_hatebert: Tokenizer for the main text input.
        tokenizer_rationale: Tokenizer for the rationale input.
        config: Dict-like config; ``max_length`` (default 128) bounds
            tokenization length.
        device: Torch device inference runs on.
        model_type: ``"base"`` or ``"altered"`` (default); selects how model
            outputs are unpacked.
        batch_size: Number of samples per forward pass (default 16).

    Returns:
        Dict with classification metrics (f1/accuracy/precision/recall,
        confusion matrix), average/peak CPU and memory usage, total runtime in
        seconds, and per-sample class probabilities.

    Raises:
        ValueError: If ``text_list`` is empty, or a non-base model returns an
            unexpected number of outputs.
    """
    if not text_list:
        # Fail fast with a clear message instead of an opaque IndexError at
        # the warmup call below.
        raise ValueError("text_list must not be empty")

    print(f"\nStarting batched inference for model: {type(model).__name__}")

    predictions = []
    all_probs = []
    cpu_percent_list = []
    memory_percent_list = []

    process = psutil.Process(os.getpid())
    max_length = config.get('max_length', 128)

    if torch.cuda.is_available():
        torch.cuda.synchronize()

    # Warmup pass (result discarded) so one-time costs (e.g. CUDA kernel
    # initialization) are not attributed to the timed loop below.
    # NOTE(review): relies on sibling helper predict_text defined elsewhere
    # in this module.
    with torch.no_grad():
        _ = predict_text(
            text=text_list[0],
            rationale=rationale_list[0],
            model=model,
            tokenizer_hatebert=tokenizer_hatebert,
            tokenizer_rationale=tokenizer_rationale,
            device=device,
            max_length=max_length,
            model_type=model_type
        )

    if torch.cuda.is_available():
        torch.cuda.synchronize()

    start_time = time()

    # Process in batches
    for batch_start in range(0, len(text_list), batch_size):
        batch_end = min(batch_start + batch_size, len(text_list))
        batch_texts = text_list[batch_start:batch_end]
        batch_rationales = rationale_list[batch_start:batch_end]

        # Tokenize batch
        main_batch_inputs = tokenizer_hatebert(
            batch_texts,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # Falsy rationales fall back to the text itself so the rationale
        # branch always receives a non-empty input.
        rationale_batch_inputs = tokenizer_rationale(
            [r if r else t for r, t in zip(batch_rationales, batch_texts)],
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # Move to device
        batch_input_ids = main_batch_inputs["input_ids"].to(device)
        batch_attention_mask = main_batch_inputs["attention_mask"].to(device)
        batch_add_input_ids = rationale_batch_inputs["input_ids"].to(device)
        batch_add_attention_mask = rationale_batch_inputs["attention_mask"].to(device)

        with torch.no_grad():
            if model_type.lower() == "base":
                batch_logits = model(
                    batch_input_ids,
                    batch_attention_mask,
                    batch_add_input_ids,
                    batch_add_attention_mask
                )
            else:
                batch_outputs = model(
                    batch_input_ids,
                    batch_attention_mask,
                    batch_add_input_ids,
                    batch_add_attention_mask
                )

                # Only the logits are used here; the rationale probabilities
                # and remaining outputs are deliberately discarded.
                if isinstance(batch_outputs, tuple) and len(batch_outputs) == 4:
                    batch_logits, _, _, _ = batch_outputs
                else:
                    raise ValueError(f"Unexpected number of outputs from model: {len(batch_outputs)}")

        batch_probs = F.softmax(batch_logits, dim=1)

        # Guard against numerically broken logits: substitute a uniform
        # distribution so downstream metric computation still succeeds.
        if torch.isnan(batch_probs).any() or torch.isinf(batch_probs).any():
            batch_probs = torch.ones_like(batch_logits) / batch_logits.size(1)

        batch_predictions = batch_logits.argmax(dim=1).cpu().numpy()
        batch_probabilities = batch_probs.cpu().numpy()

        # Collect batch results
        predictions.extend(batch_predictions.tolist())
        all_probs.extend(batch_probabilities.tolist())

        # Sample resource usage periodically, and always on the final batch.
        if batch_end % max(10, batch_size) == 0 or batch_end == len(text_list):
            cpu_percent_list.append(process.cpu_percent())
            memory_percent_list.append(process.memory_info().rss / 1024 / 1024)  # RSS in MB

    if torch.cuda.is_available():
        torch.cuda.synchronize()

    runtime = time() - start_time

    print(f"Batched inference completed for {type(model).__name__}")
    print(f"Total runtime: {runtime:.4f} seconds")

    all_probs = np.array(all_probs)

    # zero_division=0 keeps the metrics defined when a class is never
    # predicted (e.g. degenerate model output).
    f1 = f1_score(true_label, predictions, zero_division=0)
    accuracy = accuracy_score(true_label, predictions)
    precision = precision_score(true_label, predictions, zero_division=0)
    recall = recall_score(true_label, predictions, zero_division=0)
    cm = confusion_matrix(true_label, predictions).tolist()

    avg_cpu = sum(cpu_percent_list) / len(cpu_percent_list) if cpu_percent_list else 0
    avg_memory = sum(memory_percent_list) / len(memory_percent_list) if memory_percent_list else 0
    peak_memory = max(memory_percent_list) if memory_percent_list else 0
    peak_cpu = max(cpu_percent_list) if cpu_percent_list else 0

    return {
        'model_name': type(model).__name__,
        'f1_score': f1,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'confusion_matrix': cm,
        'cpu_usage': avg_cpu,
        'memory_usage': avg_memory,
        'peak_cpu_usage': peak_cpu,
        'peak_memory_usage': peak_memory,
        'runtime': runtime,
        'all_probabilities': all_probs.tolist()
    }
775
 
776
  def predict_hatespeech(text, rationale, model, tokenizer_hatebert, tokenizer_rationale, config, device, model_type="altered"):
777