tdickson17 commited on
Commit
8c8d570
·
verified ·
1 Parent(s): 4df1b82

Upload README.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +64 -0
README.md CHANGED
@@ -86,6 +86,70 @@ Evaluation signals: ROUGE for summaries; Accuracy/Precision/Recall/F1 for classi
86
 
87
  This setup lets one checkpoint handle both analysis (populism flag) and explanation (summary) with simple instruction prefixes.
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  ## Citation:
90
 
91
  @article{dickson2024going,
 
86
 
87
  This setup lets one checkpoint handle both analysis (populism flag) and explanation (summary) with simple instruction prefixes.
88
 
89
+ ## Usage:
90
+
91
+ Install the dependency:
92
+ ```bash
+ pip install transformers
+ ```
93
+
94
+ Then run:
95
+
96
+ import torch
97
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
98
+
99
+ MODEL_ID = "tdickson17/Populism_detection"
100
+ device = "cuda" if torch.cuda.is_available() else "cpu"
101
+
102
+ tok = AutoTokenizer.from_pretrained(MODEL_ID)
103
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID).to(device).eval()
104
+
105
+ MAX_SRC, MAX_SUM = 1024, 128
106
+ DEC_START = model.config.decoder_start_token_id
107
+ ID0 = tok("0", add_special_tokens=False)["input_ids"][0]
108
+ ID1 = tok("1", add_special_tokens=False)["input_ids"][0]
109
+
110
+ THRESHOLD = 0.5 # raise for higher precision, lower for higher recall
111
+ POSITIVE_MSG = "This text DOES contain populist sentiment.\n"
112
+ NEGATIVE_MSG = "Populist sentiment is NOT detected in this text.\n"
113
+
114
+ GEN_SUM = dict(
115
+ do_sample=False, num_beams=5,
116
+ max_new_tokens=MAX_SUM, min_new_tokens=16,
117
+ length_penalty=1.1, no_repeat_ngram_size=3
118
+ )
119
+
120
@torch.no_grad()
def summarize(text: str) -> str:
    """Generate a beam-search summary of *text* with the shared seq2seq model."""
    inputs = tok(
        "summarize: " + text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_SRC,
    ).to(device)
    generated = model.generate(**inputs, **GEN_SUM)
    summary = tok.decode(generated[0], skip_special_tokens=True).strip()
    # Drop an echoed task prefix if the model repeated it in the output.
    if summary.lower().startswith("summarize:"):
        _, _, remainder = summary.partition(":")
        summary = remainder.strip()
    return summary
129
+
130
@torch.no_grad()
def classify_populism_prob(text: str) -> float:
    """Return the probability of the positive ("populist") class for *text*.

    Runs a single decoder step and softmaxes over just the two label-token
    logits ("0" vs "1"), so the result is a proper two-class probability.
    """
    inputs = tok(
        "classify_populism: " + text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_SRC,
    ).to(device)
    decoder_start = torch.tensor([[DEC_START]], device=device)
    logits = model(**inputs, decoder_input_ids=decoder_start, use_cache=False).logits[:, -1, :]

    # Restrict the distribution to the two class tokens; index 1 = positive.
    pair = torch.stack([logits[:, ID0], logits[:, ID1]], dim=-1)
    return torch.softmax(pair, dim=-1)[0, 1].item()
140
+
141
def classify_populism_label(text: str, threshold: float = THRESHOLD, include_probability: bool = True) -> str:
    """Classify *text* and return a human-readable verdict string.

    Args:
        text: The document to classify.
        threshold: Decision boundary on P(populist); probabilities at or
            above it yield the positive message.
        include_probability: If True, append the model confidence.

    Returns:
        POSITIVE_MSG or NEGATIVE_MSG, optionally suffixed with the confidence.
    """
    p1 = classify_populism_prob(text)
    msg = POSITIVE_MSG if p1 >= threshold else NEGATIVE_MSG
    # Fix: p1 is a fraction in [0, 1]; the old "{p1:.3f}%" printed e.g.
    # "0.873%" for an 87.3% confidence. The "%" format type scales by 100.
    return f"{msg} Confidence={p1:.1%}" if include_probability else msg
145
+
146
# Example usage — replace the placeholder with your own document.
sample_text = """<Insert Text here>"""
print(classify_populism_label(sample_text))
print("\nSummary:\n", summarize(sample_text))
150
+
151
+
152
+
153
  ## Citation:
154
 
155
  @article{dickson2024going,