Spaces:
Sleeping
Sleeping
jl committed on
Commit ·
e85f4ac
1
Parent(s): 95c1631
fix: apply batched processing
Browse files- src/app.py +2 -2
- src/hatespeech_model.py +148 -0
src/app.py
CHANGED
|
@@ -2,7 +2,7 @@ import gc
|
|
| 2 |
import re
|
| 3 |
|
| 4 |
import streamlit as st
|
| 5 |
-
from hatespeech_model import predict_hatespeech, load_model_from_hf, predict_hatespeech_from_file, get_rationale_from_mistral, preprocess_rationale_mistral
|
| 6 |
import plotly.graph_objects as go
|
| 7 |
import plotly.express as px
|
| 8 |
import pandas as pd
|
|
@@ -413,7 +413,7 @@ if classify_button:
|
|
| 413 |
# Run both models on the file
|
| 414 |
# base_result = predict_hatespeech_from_file(...) # Base model
|
| 415 |
# enhanced_result = predict_hatespeech_from_file(...) # Enhanced model
|
| 416 |
-
enhanced_result =
|
| 417 |
text_list=file_content['text'].tolist(),
|
| 418 |
rationale_list=file_content['CF_Rationales'].tolist(),
|
| 419 |
true_label=file_content['label'].tolist(),
|
|
|
|
| 2 |
import re
|
| 3 |
|
| 4 |
import streamlit as st
|
| 5 |
+
from hatespeech_model import predict_hatespeech, load_model_from_hf, predict_hatespeech_from_file, get_rationale_from_mistral, preprocess_rationale_mistral, predict_hatespeech_from_file_batched
|
| 6 |
import plotly.graph_objects as go
|
| 7 |
import plotly.express as px
|
| 8 |
import pandas as pd
|
|
|
|
| 413 |
# Run both models on the file
|
| 414 |
# base_result = predict_hatespeech_from_file(...) # Base model
|
| 415 |
# enhanced_result = predict_hatespeech_from_file(...) # Enhanced model
|
| 416 |
+
enhanced_result = predict_hatespeech_from_file_batched(
|
| 417 |
text_list=file_content['text'].tolist(),
|
| 418 |
rationale_list=file_content['CF_Rationales'].tolist(),
|
| 419 |
true_label=file_content['label'].tolist(),
|
src/hatespeech_model.py
CHANGED
|
@@ -624,6 +624,154 @@ def predict_hatespeech_from_file(
|
|
| 624 |
'runtime': runtime,
|
| 625 |
'all_probabilities': all_probs.tolist()
|
| 626 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 627 |
|
| 628 |
def predict_hatespeech(text, rationale, model, tokenizer_hatebert, tokenizer_rationale, config, device, model_type="altered"):
|
| 629 |
|
|
|
|
| 624 |
'runtime': runtime,
|
| 625 |
'all_probabilities': all_probs.tolist()
|
| 626 |
}
|
| 627 |
+
|
| 628 |
+
def predict_hatespeech_from_file_batched(
    text_list,
    rationale_list,
    true_label,
    model,
    tokenizer_hatebert,
    tokenizer_rationale,
    config,
    device,
    model_type="altered",
    batch_size=16
):
    """Run batched hate-speech inference over a whole file of examples.

    Tokenizes ``text_list`` (and ``rationale_list``, falling back to the text
    when a rationale is falsy) in chunks of ``batch_size``, runs each chunk
    through ``model`` in a single forward pass, and aggregates predictions,
    class probabilities, and resource metrics against ``true_label``.

    Args:
        text_list: list[str] of input texts.
        rationale_list: list of rationales aligned with ``text_list``; falsy
            entries fall back to the corresponding text.
        true_label: list of gold labels aligned with ``text_list``.
        model: trained classifier. For ``model_type="base"`` the forward call
            returns logits directly; otherwise a 4-tuple whose first element
            is the logits is expected.
        tokenizer_hatebert: tokenizer for the main text stream.
        tokenizer_rationale: tokenizer for the rationale stream.
        config: dict; ``max_length`` (default 128) controls padding/truncation.
        device: torch device the model lives on.
        model_type: "base" for the logits-only model, anything else for the
            altered (tuple-returning) model.
        batch_size: number of examples per forward pass.

    Returns:
        dict with classification metrics (f1/accuracy/precision/recall,
        confusion matrix), CPU/memory usage, runtime, and per-example
        probabilities under ``'all_probabilities'``.

    Raises:
        ValueError: if ``text_list`` is empty, or the altered model returns an
            unexpected number of outputs.
    """
    # Fail fast with a clear message instead of an opaque IndexError from the
    # warmup call below.
    if not text_list:
        raise ValueError("text_list is empty: nothing to classify")

    print(f"\nStarting batched inference for model: {type(model).__name__}")

    predictions = []
    all_probs = []
    cpu_percent_list = []
    memory_percent_list = []

    process = psutil.Process(os.getpid())
    max_length = config.get('max_length', 128)

    if torch.cuda.is_available():
        torch.cuda.synchronize()

    # Warmup pass so one-time lazy initialization (e.g. CUDA kernels) is
    # excluded from the timed region.
    with torch.no_grad():
        _ = predict_text(
            text=text_list[0],
            rationale=rationale_list[0],
            model=model,
            tokenizer_hatebert=tokenizer_hatebert,
            tokenizer_rationale=tokenizer_rationale,
            device=device,
            max_length=max_length,
            model_type=model_type
        )

    if torch.cuda.is_available():
        torch.cuda.synchronize()

    start_time = time()

    # Process in batches
    for batch_start in range(0, len(text_list), batch_size):
        batch_end = min(batch_start + batch_size, len(text_list))
        batch_texts = text_list[batch_start:batch_end]
        batch_rationales = rationale_list[batch_start:batch_end]

        # Tokenize batch; fixed-length padding keeps tensor shapes uniform.
        main_batch_inputs = tokenizer_hatebert(
            batch_texts,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # Missing/empty rationales fall back to the raw text.
        rationale_batch_inputs = tokenizer_rationale(
            [r if r else t for r, t in zip(batch_rationales, batch_texts)],
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # Move to device
        batch_input_ids = main_batch_inputs["input_ids"].to(device)
        batch_attention_mask = main_batch_inputs["attention_mask"].to(device)
        batch_add_input_ids = rationale_batch_inputs["input_ids"].to(device)
        batch_add_attention_mask = rationale_batch_inputs["attention_mask"].to(device)

        with torch.no_grad():
            if model_type.lower() == "base":
                batch_logits = model(
                    batch_input_ids,
                    batch_attention_mask,
                    batch_add_input_ids,
                    batch_add_attention_mask
                )
            else:
                batch_outputs = model(
                    batch_input_ids,
                    batch_attention_mask,
                    batch_add_input_ids,
                    batch_add_attention_mask
                )

                if isinstance(batch_outputs, tuple) and len(batch_outputs) == 4:
                    # Only the logits matter for file-level metrics; the
                    # rationale probabilities are discarded here.
                    batch_logits, _, _, _ = batch_outputs
                else:
                    raise ValueError(f"Unexpected number of outputs from model: {len(batch_outputs)}")

            batch_probs = F.softmax(batch_logits, dim=1)

            # Degenerate logits (NaN/Inf) -> uniform distribution so the
            # metrics below never see invalid probabilities.
            if torch.isnan(batch_probs).any() or torch.isinf(batch_probs).any():
                batch_probs = torch.ones_like(batch_logits) / batch_logits.size(1)

            batch_predictions = batch_logits.argmax(dim=1).cpu().numpy()
            batch_probabilities = batch_probs.cpu().numpy()

        # Collect batch results
        predictions.extend(batch_predictions.tolist())
        all_probs.extend(batch_probabilities.tolist())

        # Log resource metrics periodically, and always on the final batch.
        if batch_end % max(10, batch_size) == 0 or batch_end == len(text_list):
            cpu_percent_list.append(process.cpu_percent())
            memory_percent_list.append(process.memory_info().rss / 1024 / 1024)

    if torch.cuda.is_available():
        torch.cuda.synchronize()

    runtime = time() - start_time

    print(f"Batched inference completed for {type(model).__name__}")
    print(f"Total runtime: {runtime:.4f} seconds")

    all_probs = np.array(all_probs)

    f1 = f1_score(true_label, predictions, zero_division=0)
    accuracy = accuracy_score(true_label, predictions)
    precision = precision_score(true_label, predictions, zero_division=0)
    recall = recall_score(true_label, predictions, zero_division=0)
    cm = confusion_matrix(true_label, predictions).tolist()

    avg_cpu = sum(cpu_percent_list) / len(cpu_percent_list) if cpu_percent_list else 0
    avg_memory = sum(memory_percent_list) / len(memory_percent_list) if memory_percent_list else 0
    peak_memory = max(memory_percent_list) if memory_percent_list else 0
    peak_cpu = max(cpu_percent_list) if cpu_percent_list else 0

    return {
        'model_name': type(model).__name__,
        'f1_score': f1,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'confusion_matrix': cm,
        'cpu_usage': avg_cpu,
        'memory_usage': avg_memory,
        'peak_cpu_usage': peak_cpu,
        'peak_memory_usage': peak_memory,
        'runtime': runtime,
        'all_probabilities': all_probs.tolist()
    }
|
| 775 |
|
| 776 |
def predict_hatespeech(text, rationale, model, tokenizer_hatebert, tokenizer_rationale, config, device, model_type="altered"):
|
| 777 |
|