| """Cache zero-shot API emotion scores for SemEval-2007 Affective Text.""" |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import logging |
| import os |
| import re |
| import time |
| import urllib.error |
| import urllib.parse |
| import urllib.request |
| from pathlib import Path |
|
|
| import sys |
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
|
|
| from src.data import EMOTION_NAMES, load_affective_text, load_prediction_cache |
|
|
# Configure root logging at import time: INFO level, timestamped records.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
# Module-level logger used by main() for progress and per-headline failures.
log = logging.getLogger(__name__)
|
|
# Zero-shot prompt sent for every headline. The `{headline}` placeholder is
# filled with str.format(); the three logical lines (instruction, headline,
# answer cue) are joined with newlines.
PROMPT_TEMPLATE = "\n".join(
    [
        "Rate the following news headline on 6 emotions: anger, disgust, fear, joy, sadness, surprise. "
        "Return only 6 numbers from 0 to 100, comma-separated, in that order.",
        'Headline: "{headline}"',
        "Scores:",
    ]
)
|
|
|
|
# One score token: optional minus sign, integer part, optional decimal part.
_SCORE_NUMBER = r"-?\d+(?:\.\d+)?"
# Six score tokens in a row separated only by commas (and optional whitespace),
# i.e. the exact answer format the prompt asks for.
_SCORE_RUN = re.compile(rf"{_SCORE_NUMBER}(?:\s*,\s*{_SCORE_NUMBER}){{5}}")


def parse_scores(text: str) -> list[float]:
    """Extract six non-negative emotion scores from a model reply.

    The reply is first searched for a comma-separated run of six numbers,
    which is the format requested by the prompt. This keeps stray numbers in
    a preamble (e.g. "Here are the 6 scores: ...") from being mistaken for
    scores. If no such run exists (for example a labelled, one-per-line
    answer), the first six numbers found anywhere in the reply are used.

    Args:
        text: Raw text returned by the model.

    Returns:
        Six floats in the order they appear; negative values are clamped
        to 0.0.

    Raises:
        ValueError: If fewer than six numbers can be found, or if all six
            clamped scores sum to zero.
    """
    run = _SCORE_RUN.search(text)
    # Restrict the scan to the comma-separated run when there is one;
    # otherwise fall back to scanning the whole reply.
    nums = re.findall(_SCORE_NUMBER, run.group(0) if run else text)
    if len(nums) < 6:
        raise ValueError(f"Could not parse 6 scores from response: {text!r}")
    scores = [max(float(x), 0.0) for x in nums[:6]]
    if sum(scores) <= 0:
        raise ValueError(f"Parsed zero-sum scores from response: {text!r}")
    return scores
|
|
|
|
def call_openai_chat_completions(
    headline: str,
    model: str,
    api_key: str,
    base_url: str,
    timeout_sec: float,
) -> tuple[str, dict]:
    """Score one headline via an OpenAI-style ``/chat/completions`` endpoint.

    Args:
        headline: Headline text substituted into ``PROMPT_TEMPLATE``.
        model: Model identifier sent in the request payload.
        api_key: Secret sent as a ``Bearer`` token in the Authorization header.
        base_url: API root; a trailing slash is tolerated.
        timeout_sec: Timeout passed to ``urllib.request.urlopen``.

    Returns:
        ``(text, body)``: the first choice's message content and the full
        decoded JSON response.

    Raises:
        KeyError: If the response is not a JSON object, has no choices, or
            the first choice carries no text content. This mirrors how the
            Gemini caller reports malformed responses, so callers can handle
            both providers with the same ``except`` clause.
        urllib.error.URLError: On transport or HTTP-status failures raised
            by ``urlopen``.
        ValueError: If the response body is not valid JSON.
    """
    prompt = PROMPT_TEMPLATE.format(headline=headline)
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": "You are a precise annotation model."},
            {"role": "user", "content": prompt},
        ],
        "temperature": 0,
    }
    req = urllib.request.Request(
        url=base_url.rstrip("/") + "/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=timeout_sec) as resp:
        body = json.loads(resp.read().decode("utf-8"))

    # Validate the response shape explicitly so that an empty ``choices``
    # list or a null ``content`` surfaces as KeyError instead of an
    # IndexError/TypeError further down the line.
    if not isinstance(body, dict):
        raise KeyError(f"Unexpected OpenAI response shape: {body}")
    choices = body.get("choices") or []
    if not choices or not isinstance(choices[0], dict):
        raise KeyError(f"No OpenAI choices in response: {body}")
    message = choices[0].get("message") or {}
    text = message.get("content") if isinstance(message, dict) else None
    if not isinstance(text, str) or not text:
        raise KeyError(f"No text content in OpenAI response: {body}")
    return text, body
|
|
|
|
def call_gemini_generate_content(
    headline: str,
    model: str,
    api_key: str,
    base_url: str,
    timeout_sec: float,
) -> tuple[str, dict]:
    """Score one headline via a Gemini ``generateContent`` endpoint.

    The API key is URL-quoted and passed as the ``key`` query parameter.

    Returns:
        ``(text, body)``: the non-empty text parts of the first candidate
        joined with newlines, and the full decoded JSON response.

    Raises:
        KeyError: If the response has no candidates or the first candidate
            has no non-empty text parts.
    """
    request_body = {
        "contents": [
            {
                "role": "user",
                "parts": [{"text": PROMPT_TEMPLATE.format(headline=headline)}],
            }
        ],
        "generationConfig": {
            "temperature": 0,
        },
    }

    api_root = base_url.rstrip("/")
    endpoint = api_root + f"/models/{model}:generateContent?key={urllib.parse.quote(api_key)}"
    request = urllib.request.Request(
        url=endpoint,
        data=json.dumps(request_body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )

    with urllib.request.urlopen(request, timeout=timeout_sec) as response:
        raw_bytes = response.read()
    body = json.loads(raw_bytes.decode("utf-8"))

    candidates = body.get("candidates", [])
    if not candidates:
        raise KeyError(f"No Gemini candidates in response: {body}")

    # Collect only the parts that actually carry text.
    fragments = []
    for part in candidates[0].get("content", {}).get("parts", []):
        piece = part.get("text")
        if piece:
            fragments.append(piece)
    text = "\n".join(fragments)
    if not text:
        raise KeyError(f"No text parts in Gemini response: {body}")
    return text, body
|
|
|
|
def main():
    """Query the chosen API for every headline and append results to a JSONL cache.

    Command-line flags select the dataset directory, output file, provider
    (``gemini`` or ``openai``), model, base URL, API-key environment
    variable, an optional headline limit, a per-request sleep, a request
    timeout, and whether to overwrite an existing cache. Model, base URL and
    key variable fall back to provider-specific environment variables and
    then to built-in defaults.

    Headlines whose id is already in the cache are skipped unless
    ``--overwrite`` is given. Each successful prediction is written and
    flushed immediately, so an interrupted run can be resumed. A failure on
    one headline is logged and counted, and the run continues.

    Raises:
        EnvironmentError: If the API-key environment variable is unset or empty.
    """
    # Local import: only needed for the exception class caught in the loop.
    import http.client

    parser = argparse.ArgumentParser()
    parser.add_argument("--data-dir", default="data/raw/AffectiveText.Semeval.2007")
    parser.add_argument("--output", default="data/processed/affective_text_predictions.jsonl")
    parser.add_argument("--provider", choices=["openai", "gemini"], default="gemini")
    parser.add_argument("--model", default=None)
    parser.add_argument("--base-url", default=None)
    parser.add_argument("--api-key-env", default=None)
    parser.add_argument("--limit", type=int, default=None)
    parser.add_argument("--sleep-sec", type=float, default=0.0)
    parser.add_argument("--timeout-sec", type=float, default=60.0)
    parser.add_argument("--overwrite", action="store_true")
    args = parser.parse_args()

    # Resolve provider-specific defaults: explicit flag > env var > built-in.
    if args.model is None:
        if args.provider == "gemini":
            args.model = os.environ.get("GEMINI_MODEL", "gemini-2.0-flash-001")
        else:
            args.model = os.environ.get("OPENAI_MODEL", "gpt-4o-mini-2024-07-18")
    if args.base_url is None:
        if args.provider == "gemini":
            args.base_url = os.environ.get("GEMINI_BASE_URL", "https://generativelanguage.googleapis.com/v1beta")
        else:
            args.base_url = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
    if args.api_key_env is None:
        args.api_key_env = "GEMINI_API_KEY" if args.provider == "gemini" else "OPENAI_API_KEY"

    api_key = os.environ.get(args.api_key_env)
    if not api_key:
        raise EnvironmentError(f"Missing API key in env var {args.api_key_env}")

    data = load_affective_text(args.data_dir)
    ids = data["ids"]
    headlines = data["headlines"]
    if args.limit is not None:
        ids = ids[:args.limit]
        headlines = headlines[:args.limit]

    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    existing = {}
    if out_path.exists() and not args.overwrite:
        existing = load_prediction_cache(out_path)
        log.info(f"Loaded {len(existing)} cached predictions from {out_path}")

    if args.provider == "gemini":
        call_api = call_gemini_generate_content
    else:
        call_api = call_openai_chat_completions

    # Per-headline failures that must not abort the whole run:
    #  - OSError covers URLError/HTTPError plus read timeouts and connection
    #    resets, which urlopen()/read() can raise without wrapping;
    #  - http.client.HTTPException covers truncated or malformed HTTP replies;
    #  - ValueError/LookupError/TypeError cover unparsable or oddly shaped
    #    response payloads.
    recoverable = (OSError, http.client.HTTPException, ValueError, LookupError, TypeError)

    n_done = 0
    n_failed = 0
    # `existing` is only populated when not overwriting, so append exactly
    # when there are cached rows to keep; otherwise start a fresh file.
    with open(out_path, "a" if existing else "w", encoding="utf-8") as f:
        for idx, headline in zip(ids, headlines):
            if idx in existing:
                continue
            try:
                raw_text, raw_json = call_api(
                    headline=headline,
                    model=args.model,
                    api_key=api_key,
                    base_url=args.base_url,
                    timeout_sec=args.timeout_sec,
                )
                scores = parse_scores(raw_text)
            except recoverable as exc:
                n_failed += 1
                log.error(f"Failed on id={idx}: {exc}")
            else:
                row = {
                    "id": idx,
                    "headline": headline,
                    "emotions": EMOTION_NAMES,
                    "scores": scores,
                    "provider": args.provider,
                    "model": args.model,
                    "base_url": args.base_url,
                    "prompt_template": PROMPT_TEMPLATE,
                    "raw_text": raw_text,
                    "raw_response": raw_json,
                }
                f.write(json.dumps(row, ensure_ascii=True) + "\n")
                # Flush per row so an interrupted run loses at most one row.
                f.flush()
                n_done += 1
                if n_done % 50 == 0:
                    log.info(f"Cached {n_done} new predictions")
            # Throttle after every request, including failed ones, so a
            # rate-limited API is not hit again immediately.
            if args.sleep_sec > 0:
                time.sleep(args.sleep_sec)

    if n_failed:
        log.warning(f"{n_failed} headlines failed and were not cached; re-run to retry them")
    log.info(f"Finished. Predictions cached at {out_path}")
|
|
|
|
# Run the caching job only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|