# explainability-tool-for-aa / utils / llm_feat_utils.py
# Author: Anisha Bhatnagar
# Last change: fixed bug in region computation for all task authors and
# reduced logging (commit 3ad08b5)
import json
import os
import hashlib
import time
from json import JSONDecodeError
import traceback
# On-disk cache directory holding one JSON file of feature spans per role.
CACHE_DIR = "datasets/feature_spans_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
import pandas as pd
# Read and create the Gram2Vec feature set once, at import time; used by
# split_features() to separate Gram2Vec features from LLM-derived features.
# NOTE(review): this runs file I/O on import — assumes the CSV exists
# relative to the working directory; confirm for all entry points.
_g2v_df = pd.read_csv("datasets/gram2vec_feats.csv")
GRAM2VEC_SET = set(_g2v_df['gram2vec_feats'].unique())
MAX_ATTEMPTS = 3
WAIT_SECONDS = 2
# Bump this whenever there is a change prompt, feature space, etc...
CACHE_VERSION = 2
def _feat_hash(feature: str, text: str) -> str:
blob = json.dumps({
"version": CACHE_VERSION,
"text": text,
"features": feature
}, sort_keys=True).encode()
return hashlib.md5(blob).hexdigest()
def generate_feature_spans(client, text: str, features: list[str]) -> str:
    """
    Call to OpenAI to extract spans. Returns a JSON string.

    Args:
        client: an OpenAI-compatible client exposing chat.completions.create.
        text: the writing sample to analyze.
        features: descriptive style features to locate in the text.

    Returns:
        The raw model response content — expected (but not guaranteed) to be
        a JSON object mapping each feature to a list of exact text spans.
    """
    # BUG FIX: the docstring previously appeared after this print statement,
    # so it was an unreachable string expression, not a docstring.
    print("Calling OpenAI to extract spans")
    prompt = f"""You are a linguistic specialist. Given a writing sample and a list of descriptive features, identify the exact text spans that demonstrate each feature.
Important:
- The headers like "Document 1:" etc are NOT part of the original text — ignore them.
- For each feature, even if there is no match, return an empty list.
- Only return exact phrases from the text.
Respond in JSON format like:
{{
"feature1": ["span1", "span2"],
"feature2": [],
}}
Text:
\"\"\"{text}\"\"\"
Style Features:
{features}
"""
    # Low temperature to keep span extraction as deterministic as possible.
    response = client.chat.completions.create(
        model="gpt-4",
        messages=[{"role":"user","content":prompt}],
        temperature=0.3,
    )
    return response.choices[0].message.content
def generate_feature_spans_with_retries(client, text: str, features: list[str]) -> dict:
    """
    Calls `generate_feature_spans` with retries on failure.
    Returns the parsed JSON dict mapping feature->list[spans].

    Retries up to MAX_ATTEMPTS times with exponential backoff
    (WAIT_SECONDS, 2*WAIT_SECONDS, ...) when the model reply is not
    valid JSON; raises RuntimeError once all attempts are exhausted.
    """
    attempt = 0
    while attempt < MAX_ATTEMPTS:
        try:
            raw = generate_feature_spans(client, text, features)
            print(raw)
            return json.loads(raw)
        except (JSONDecodeError, ValueError) as err:
            print(f"Attempt {attempt+1} failed: {err}")
            traceback.print_exc()
            attempt += 1
            # No sleep after the final failed attempt.
            if attempt < MAX_ATTEMPTS:
                wait_sec = WAIT_SECONDS * (2 ** (attempt - 1))
                print(f"Retrying after {wait_sec} seconds...")
                time.sleep(wait_sec)
    raise RuntimeError("All retry attempts failed for OpenAI call.")
def generate_feature_spans_cached(client, text: str, features: list[str], role: str = "mystery") -> dict:
    """
    Computes a cache key from text + feature list,
    then either loads or calls the API and saves to disk.
    Returns the parsed JSON dict mapping feature->list[spans].

    Args:
        client: OpenAI-compatible client, forwarded to the span extractor.
        text: the writing sample the spans come from.
        features: feature descriptions to resolve; the literal string "None"
            is mapped to an empty span list without an API call.
        role: label for the per-role cache file (sanitized for filenames).
    """
    print(f"Generating spans for ({role})")
    # Sanitize the role so it is safe to use as a filename.
    role = role.replace(" ", "_").replace("/", "_").replace("-", "_")
    print(f"Cache dir: {CACHE_DIR}")
    os.makedirs(CACHE_DIR, exist_ok=True)
    cache_path = os.path.join(CACHE_DIR, f"{role}.json")

    # Load the existing per-role cache, keyed by _feat_hash(feature, text).
    if os.path.exists(cache_path):
        print("Cache hit....")
        with open(cache_path) as f:
            cache: dict[str, dict] = json.load(f)
    else:
        cache = {}

    result: dict[str, list[str]] = {}
    missing_feats: list[str] = []
    found_feats_count = 0
    for feat in features:
        if feat == "None":
            result[feat] = []
            continue
        h = _feat_hash(feat, text)
        if h in cache:
            found_feats_count += 1
            result[feat] = cache[h]["spans"]
        else:
            missing_feats.append(feat)
    print(f"Found {found_feats_count} features in cache, {len(missing_feats)} missing")

    if missing_feats:
        # Single API call for all missing features at once.
        mapping = generate_feature_spans_with_retries(client, text, missing_feats)
        for feat in missing_feats:
            h = _feat_hash(feat, text)
            # BUG FIX: default to [] when the model omitted a feature from its
            # reply — .get(feat) returned None, which was then written into the
            # cache and returned, breaking the list[spans] contract on both the
            # current call and every later cache hit.
            spans = mapping.get(feat, [])
            cache[h] = {
                "feature": feat,
                "spans": spans
            }
            result[feat] = spans
        # Persist the combined cache for future runs.
        with open(cache_path, "w") as f:
            json.dump(cache, f, indent=2)
    return result
def split_features(all_feats):
    """
    Given a list of mixed features, returns two lists:
    - llm_feats: those NOT in the Gram2Vec CSV
    - g2v_feats: those present in the CSV
    """
    llm_feats, g2v_feats = [], []
    # Single pass: route each feature to the matching bucket, preserving order.
    for feat in all_feats:
        bucket = g2v_feats if feat in GRAM2VEC_SET else llm_feats
        bucket.append(feat)
    return llm_feats, g2v_feats