Alon Albalak
initial commit
71a764a
raw
history blame
4.38 kB
"""Data loading and saving functionality"""
import json
import os
import random
import datetime
class DataManager:
    """Manages loading and saving of prompts and results data (JSONL files)."""

    def __init__(self):
        # In-memory cache of prompt records; populated by load_prompts_data().
        self.prompts_data = []

    def load_prompts_data(self, filepath="data/prompts.jsonl"):
        """Load prompts data from a JSONL file into self.prompts_data.

        Blank lines are skipped, consistent with load_results_data().

        Args:
            filepath: Path to the prompts JSONL file.

        Raises:
            FileNotFoundError: If the file does not exist.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        with open(filepath, "r", encoding="utf-8") as f:
            self.prompts_data = [json.loads(line) for line in f if line.strip()]

    def get_random_prompt(self):
        """Return a random prompt record from the loaded data.

        Raises:
            RuntimeError: If no prompts have been loaded yet.
        """
        if not self.prompts_data:
            raise RuntimeError("No prompts data loaded. Call load_prompts_data() first.")
        return random.choice(self.prompts_data)

    def load_results_data(self, filepath="data/results.jsonl"):
        """Load all result records from the results JSONL file.

        Args:
            filepath: Path to the results JSONL file.

        Returns:
            List of result dicts; empty list if the file does not exist.
        """
        results = []
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                for line in f:
                    if line.strip():  # tolerate blank lines between records
                        results.append(json.loads(line))
        except FileNotFoundError:
            pass  # No results yet — treat as empty history.
        return results

    def save_interaction(self, prompt_data, user_continuation, generated_response,
                         cosine_distance, session_id, num_user_tokens, filepath="data/results.jsonl"):
        """Append one user interaction as a single JSON line to the results file.

        Args:
            prompt_data: Prompt record; must contain "prompt", "model",
                "llm_partial_response", "llm_full_response_original".
            user_continuation: Text the user wrote.
            generated_response: Full response assembled from the user's input.
            cosine_distance: Creativity score for this interaction.
            session_id: Stored under the "continuation_source" key.
            num_user_tokens: Token count of the user's continuation.
            filepath: Destination JSONL file (appended to).
        """
        interaction = {
            "prompt": prompt_data["prompt"],
            "model": prompt_data["model"],
            "llm_partial_response": prompt_data["llm_partial_response"],
            "llm_full_response_original": prompt_data["llm_full_response_original"],
            "user_continuation": user_continuation,
            "full_response_from_user": generated_response,
            "cosine_distance": cosine_distance,
            # Timezone-aware UTC timestamp; isoformat() keeps the +00:00 offset
            # so readers can parse it unambiguously.
            "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(),
            "continuation_source": session_id,
            "num_user_tokens": num_user_tokens,
        }
        # Only create the directory when the path actually has one:
        # os.makedirs("") raises FileNotFoundError for bare filenames.
        dirname = os.path.dirname(filepath)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        with open(filepath, "a", encoding="utf-8") as f:
            f.write(json.dumps(interaction) + "\n")

    def filter_results_by_partial_response(self, results, prompt, partial_response):
        """Return only the entries matching the given prompt and partial response."""
        return [r for r in results
                if r["prompt"] == prompt and r["llm_partial_response"] == partial_response]

    def filter_results_by_session(self, results, session_id):
        """Return only the entries recorded under the given session id."""
        return [r for r in results if r.get("continuation_source") == session_id]

    def get_gallery_responses(self, min_score=0.3, limit=20):
        """Return up to `limit` results with cosine_distance >= min_score, best first.

        Args:
            min_score: Minimum creativity score to include.
            limit: Maximum number of results to return.
        """
        all_results = self.load_results_data()
        # .get() guards against legacy records missing the score field
        # (mirrors the defensive .get() used in filter_results_by_session).
        filtered_results = [r for r in all_results
                            if r.get("cosine_distance", 0.0) >= min_score]
        filtered_results.sort(key=lambda r: r.get("cosine_distance", 0.0), reverse=True)
        return filtered_results[:limit]

    def get_inspire_me_examples(self, prompt, partial_response, limit=5):
        """Return up to `limit` inspiring examples for the current prompt.

        Keeps entries for the given prompt/partial response with a creativity
        score >= 0.2, then mixes a few top scorers with a random pick from the
        remainder so the selection stays diverse rather than always showing
        the same top entries.
        """
        all_results = self.load_results_data()
        examples = [r for r in all_results
                    if r["prompt"] == prompt
                    and r["llm_partial_response"] == partial_response
                    and r.get("cosine_distance", 0.0) >= 0.2]
        examples.sort(key=lambda r: r.get("cosine_distance", 0.0), reverse=True)
        if len(examples) > limit:
            # Keep up to three of the best, fill the rest randomly from the remainder.
            top_examples = examples[:3]
            remaining = examples[3:]
            if remaining:
                additional = random.sample(
                    remaining, min(limit - len(top_examples), len(remaining)))
                examples = top_examples + additional
            else:
                examples = top_examples
        # Shuffle the final selection so display order doesn't always favor score.
        return random.sample(examples, min(limit, len(examples)))