Upload README.md with huggingface_hub
Browse files
README.md
CHANGED
|
@@ -86,6 +86,70 @@ Evaluation signals: ROUGE for summaries; Accuracy/Precision/Recall/F1 for classi
|
|
| 86 |
|
| 87 |
This setup lets one checkpoint handle both analysis (populism flag) and explanation (summary) with simple instruction prefixes.
|
| 88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
## Citation:
|
| 90 |
|
| 91 |
@article{dickson2024going,
|
|
|
|
| 86 |
|
| 87 |
This setup lets one checkpoint handle both analysis (populism flag) and explanation (summary) with simple instruction prefixes.
|
| 88 |
|
| 89 |
+
## Usage:
|
| 90 |
+
|
| 91 |
+
Install the required dependencies:
|
| 92 |
+
Bash: pip install torch transformers
|
| 93 |
+
|
| 94 |
+
Then run the following Python code:
|
| 95 |
+
|
| 96 |
+
# Inference setup: load the fine-tuned checkpoint once at import time.
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

MODEL_ID = "tdickson17/Populism_detection"
device = "cuda" if torch.cuda.is_available() else "cpu"

tok = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID).to(device).eval()

# Token budgets: source truncation length and summary length cap.
MAX_SRC, MAX_SUM = 1024, 128
DEC_START = model.config.decoder_start_token_id
# Token ids for the literal "0" / "1" class-label tokens the model emits.
ID0 = tok("0", add_special_tokens=False)["input_ids"][0]
ID1 = tok("1", add_special_tokens=False)["input_ids"][0]

THRESHOLD = 0.5  # raise for higher precision, lower for higher recall

POSITIVE_MSG = "This text DOES contain populist sentiment.\n"
NEGATIVE_MSG = "Populist sentiment is NOT detected in this text.\n"

# Deterministic beam-search settings used for summary generation.
GEN_SUM = {
    "do_sample": False,
    "num_beams": 5,
    "max_new_tokens": MAX_SUM,
    "min_new_tokens": 16,
    "length_penalty": 1.1,
    "no_repeat_ngram_size": 3,
}
|
| 119 |
+
|
| 120 |
+
@torch.no_grad()
def summarize(text: str) -> str:
    """Generate an abstractive summary of *text* via beam search.

    The input is truncated to MAX_SRC tokens; the "summarize: " task
    prefix routes the multi-task checkpoint to its summarization head.
    """
    batch = tok("summarize: " + text, return_tensors="pt",
                truncation=True, max_length=MAX_SRC).to(device)
    generated = model.generate(**batch, **GEN_SUM)
    summary = tok.decode(generated[0], skip_special_tokens=True).strip()
    # The model occasionally echoes the task prefix; drop it if present.
    if summary.lower().startswith("summarize:"):
        summary = summary.split(":", 1)[1].strip()
    return summary
|
| 129 |
+
|
| 130 |
+
@torch.no_grad()
def classify_populism_prob(text: str) -> float:
    """Return the probability in [0, 1] that *text* is populist.

    Rather than free-form generation, this runs a single decoder step
    from the decoder-start token and renormalizes the logits of just
    the "0" and "1" label tokens with a two-way softmax.
    """
    batch = tok("classify_populism: " + text, return_tensors="pt",
                truncation=True, max_length=MAX_SRC).to(device)
    decoder_ids = torch.tensor([[DEC_START]], device=device)
    step_logits = model(**batch, decoder_input_ids=decoder_ids,
                        use_cache=False).logits[:, -1, :]
    # Softmax over only the two label-token logits; index 1 is P(populist).
    pair = torch.stack([step_logits[:, ID0], step_logits[:, ID1]], dim=-1)
    return torch.softmax(pair, dim=-1)[0, 1].item()
|
| 140 |
+
|
| 141 |
+
def classify_populism_label(text: str, threshold: float = THRESHOLD, include_probability: bool = True) -> str:
    """Classify *text* and return a human-readable verdict.

    Args:
        text: Input document to score.
        threshold: Decision boundary on P(populist); raise for higher
            precision, lower for higher recall.
        include_probability: If True, append the model confidence.

    Returns:
        POSITIVE_MSG or NEGATIVE_MSG, optionally suffixed with the
        confidence as a percentage.
    """
    p1 = classify_populism_prob(text)
    msg = POSITIVE_MSG if p1 >= threshold else NEGATIVE_MSG
    # BUG FIX: p1 is a probability in [0, 1]; the original ":.3f" plus a
    # literal "%" printed e.g. "Confidence=0.750%" — off by a factor of
    # 100. The ".1%" presentation type scales by 100 and appends "%".
    return f"{msg} Confidence={p1:.1%}" if include_probability else msg
|
| 145 |
+
|
| 146 |
+
# Example usage: replace the placeholder with the document to analyze.
sample_text = """<Insert Text here>"""
print(classify_populism_label(sample_text))
print("\nSummary:\n", summarize(sample_text))
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
|
| 153 |
## Citation:
|
| 154 |
|
| 155 |
@article{dickson2024going,
|