Anisha Bhatnagar
committed on
Commit
·
bd7d9f9
1
Parent(s):
8db24a7
fixed caching issues in LLM feature identification
Browse files- utils/interp_space_utils.py +43 -3
- utils/llm_feat_utils.py +1 -1
utils/interp_space_utils.py
CHANGED
|
@@ -22,7 +22,9 @@ import numpy as np
|
|
| 22 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 23 |
|
| 24 |
CACHE_DIR = "datasets/embeddings_cache"
|
|
|
|
| 25 |
os.makedirs(CACHE_DIR, exist_ok=True)
|
|
|
|
| 26 |
# Bump this whenever there is a change etc...
|
| 27 |
CACHE_VERSION = 1
|
| 28 |
|
|
@@ -418,7 +420,34 @@ def compute_clusters_style_representation_2(
|
|
| 418 |
|
| 419 |
return parsed_response
|
| 420 |
|
| 421 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 422 |
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
| 423 |
prompt = f"""Identify {max_num_feats} writing style features that are commonly found across the following texts. Do not extract spans. Just return the feature names as a list.
|
| 424 |
Author Texts:
|
|
@@ -442,7 +471,18 @@ def identify_style_features(author_texts: list[str], max_num_feats: int = 5) ->
|
|
| 442 |
)
|
| 443 |
return json.loads(response.choices[0].message.content)
|
| 444 |
|
| 445 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 446 |
|
| 447 |
def retry_call(call_fn, schema_class, max_attempts=3, wait_sec=2):
|
| 448 |
for attempt in range(max_attempts):
|
|
@@ -494,7 +534,7 @@ def compute_clusters_style_representation_3(
|
|
| 494 |
author_names = background_corpus_df_feat_id[cluster_label_clm_name].tolist()[:max_num_authors]
|
| 495 |
print(f"Number of authors: {len(background_corpus_df_feat_id)}")
|
| 496 |
print(author_names)
|
| 497 |
-
features = identify_style_features(author_texts, max_num_feats=max_num_feats)
|
| 498 |
|
| 499 |
# STEP 2: Prepare author pool for span extraction
|
| 500 |
span_df = background_corpus_df.iloc[:4]
|
|
|
|
| 22 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 23 |
|
| 24 |
CACHE_DIR = "datasets/embeddings_cache"
|
| 25 |
+
ZOOM_CACHE = "datasets/zoom_cache/features_cache.json"
|
| 26 |
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 27 |
+
os.makedirs(os.path.dirname(ZOOM_CACHE), exist_ok=True)
|
| 28 |
# Bump this whenever there is a change etc...
|
| 29 |
CACHE_VERSION = 1
|
| 30 |
|
|
|
|
| 420 |
|
| 421 |
return parsed_response
|
| 422 |
|
| 423 |
+
def generate_cache_key(author_names: List[str], max_num_feats: int) -> str:
|
| 424 |
+
"""Generate a unique cache key based on author names and max features"""
|
| 425 |
+
# Sort author names to ensure consistent key regardless of order
|
| 426 |
+
sorted_authors = sorted(author_names)
|
| 427 |
+
key_data = {
|
| 428 |
+
"authors": sorted_authors,
|
| 429 |
+
"max_num_feats": max_num_feats
|
| 430 |
+
}
|
| 431 |
+
key_string = json.dumps(key_data, sort_keys=True)
|
| 432 |
+
return hashlib.md5(key_string.encode()).hexdigest()
|
| 433 |
+
|
| 434 |
+
def identify_style_features(author_texts: list[str], author_names: list[str], max_num_feats: int = 5) -> list[str]:
|
| 435 |
+
cache_key = None
|
| 436 |
+
if author_names:
|
| 437 |
+
cache_key = generate_cache_key(author_names, max_num_feats)
|
| 438 |
+
|
| 439 |
+
if os.path.exists(ZOOM_CACHE):
|
| 440 |
+
with open(ZOOM_CACHE, 'r') as f:
|
| 441 |
+
cache = json.load(f)
|
| 442 |
+
else:
|
| 443 |
+
cache = {}
|
| 444 |
+
|
| 445 |
+
if cache_key in cache:
|
| 446 |
+
print(f"\nCache hit! Using cached features for authors: {author_names}")
|
| 447 |
+
return cache[cache_key]["features"]
|
| 448 |
+
else:
|
| 449 |
+
print(f"Cache miss. Computing features for authors: {author_names}")
|
| 450 |
+
|
| 451 |
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
| 452 |
prompt = f"""Identify {max_num_feats} writing style features that are commonly found across the following texts. Do not extract spans. Just return the feature names as a list.
|
| 453 |
Author Texts:
|
|
|
|
| 471 |
)
|
| 472 |
return json.loads(response.choices[0].message.content)
|
| 473 |
|
| 474 |
+
features = retry_call(_make_call, FeatureIdentificationSchema).features
|
| 475 |
+
|
| 476 |
+
print(f"Adding to zoom cache")
|
| 477 |
+
if cache_key and author_names:
|
| 478 |
+
cache[cache_key] = {
|
| 479 |
+
"features": features
|
| 480 |
+
}
|
| 481 |
+
# save_cache(cache)
|
| 482 |
+
with open(ZOOM_CACHE, 'w') as f:
|
| 483 |
+
json.dump(cache, f, indent=2)
|
| 484 |
+
|
| 485 |
+
print(f"Cached features for authors: {author_names}")
|
| 486 |
|
| 487 |
def retry_call(call_fn, schema_class, max_attempts=3, wait_sec=2):
|
| 488 |
for attempt in range(max_attempts):
|
|
|
|
| 534 |
author_names = background_corpus_df_feat_id[cluster_label_clm_name].tolist()[:max_num_authors]
|
| 535 |
print(f"Number of authors: {len(background_corpus_df_feat_id)}")
|
| 536 |
print(author_names)
|
| 537 |
+
features = identify_style_features(author_texts, author_names, max_num_feats=max_num_feats)
|
| 538 |
|
| 539 |
# STEP 2: Prepare author pool for span extraction
|
| 540 |
span_df = background_corpus_df.iloc[:4]
|
utils/llm_feat_utils.py
CHANGED
|
@@ -125,7 +125,7 @@ def generate_feature_spans_cached(client, text: str, features: list[str], role:
|
|
| 125 |
result[feat] = spans
|
| 126 |
|
| 127 |
# 5) write back the combined cache
|
| 128 |
-
with open(cache_path, "
|
| 129 |
json.dump(cache, f, indent=2)
|
| 130 |
return result
|
| 131 |
|
|
|
|
| 125 |
result[feat] = spans
|
| 126 |
|
| 127 |
# 5) write back the combined cache
|
| 128 |
+
with open(cache_path, "w") as f:
|
| 129 |
json.dump(cache, f, indent=2)
|
| 130 |
return result
|
| 131 |
|