Spaces:
Sleeping
Sleeping
File size: 2,583 Bytes
e42e305 9d450de e42e305 9d450de e42e305 9d450de e42e305 9d450de 8ae98a8 e42e305 b6882bd e42e305 b6882bd e42e305 b6882bd e42e305 b6882bd e42e305 b6882bd 8ae98a8 e42e305 8ae98a8 e42e305 9d450de e42e305 8ae98a8 e42e305 9d450de e42e305 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
from datasets import load_dataset
from google import genai
from dotenv import load_dotenv
from retry_with_backoff import retry_with_backoff
from prompts import update_prompt
from evaluate import select_round
import logfire
# Load environment variables (including API keys) from a .env file
load_dotenv()
# Configure Logfire, then instrument the Google Gen AI SDK.
# This wraps Google Gen AI client calls
# to capture prompts, responses, and metadata
logfire.configure()
logfire.instrument_google_genai()
# Initialize the Gemini LLM client (shared by the functions in this module).
# NOTE(review): no credentials are passed here; presumably the API key comes
# from the environment populated by load_dotenv() above - confirm.
client = genai.Client()
def _ground_truth_label(judge_noteworthy, feedback):
    """
    Return the ground-truth label text implied by the judge verdict and the human feedback.

    The label is "noteworthy=True" when the human agreed with a noteworthy
    verdict or disagreed with a not-noteworthy verdict. Every other
    combination (including feedback values other than "agree"/"disagree")
    yields "noteworthy=False".
    """
    if judge_noteworthy and feedback == "agree":
        return "noteworthy=True"
    if not judge_noteworthy and feedback == "disagree":
        return "noteworthy=True"
    return "noteworthy=False"


def _format_feedback_row(row):
    """Build the training text for one feedback row: judge reasoning plus human feedback."""
    ground_truth = _ground_truth_label(row["judge_noteworthy"], row["feedback"])
    judge = f"AI Judge: {row['judge_reasoning']}"
    human = f"Human feedback: {row['feedback']} ({ground_truth})."
    return f"{judge} {human}"


@logfire.instrument("Update alignment")
def update_alignment(round=None):
    """
    Update the alignment prompt using feedback collected from the production app.

    Reads the previous round's alignment text from
    production/alignment_{round - 1}.txt, asks Gemini to revise it using the
    feedback examples selected for this round, and writes the model's answer
    to production/alignment_{round}.txt.

    Args:
        round: alignment round, starting with 2 (None uses most recent available round)

    Raises:
        ValueError: if the model response contains no text. The check happens
            before the output file is opened, so no empty or truncated
            alignment file is left behind.
    """
    # Load feedback dataset
    dataset = load_dataset("jedick/noteworthy-differences-feedback", split="train")
    # Convert to DataFrame
    df = dataset.to_pandas()

    # Get examples for this round
    # This also gets the number of the most recent round if the argument is None
    row_positions, round = select_round(dataset, "train", round)
    examples = df.iloc[row_positions]

    # One paragraph of training text per feedback row, separated by blank lines
    feedback_data = "\n\n".join(
        _format_feedback_row(row) for _, row in examples.iterrows()
    )

    # Read the existing alignment (explicit UTF-8 so the result does not
    # depend on the platform's locale encoding)
    with open(
        f"production/alignment_{round - 1}.txt", "r", encoding="utf-8"
    ) as file:
        alignment_text = file.read()

    # Write prompt to update alignment
    prompt = update_prompt.replace("{{alignment_text}}", alignment_text).replace(
        "{{feedback_data}}", feedback_data
    )

    # Function to generate response (retried by the retry_with_backoff decorator)
    @retry_with_backoff()
    def get_response():
        return client.models.generate_content(
            model="gemini-2.5-flash",
            contents=prompt,
        )

    # Get the response
    response = get_response()

    # Validate before touching the output file: opening it in "w" mode would
    # truncate or create it even if the subsequent write failed.
    new_alignment = response.text
    if new_alignment is None:
        raise ValueError(
            f"Model returned no text; alignment for round {round} was not written"
        )

    # Save to new alignment text file
    with open(f"production/alignment_{round}.txt", "w", encoding="utf-8") as file:
        file.write(new_alignment)
if __name__ == "__main__":
    # Script entry point: run one update, letting the function pick the round
    # (round=None is its default).
    update_alignment(round=None)
|