# writer-app / app.py
# Hugging Face Space by NealCaren ("Update app.py", commit 22bc47a).
import gradio as gr
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
import anthropic
import json
import re
import os
from pathlib import Path
# Initialize the SentenceTransformer model used for all query/draft embeddings.
# trust_remote_code=True allows the hub repo's custom model code to run —
# presumably required by this nomic-embed checkpoint; confirm before upgrading.
model = SentenceTransformer('nomic-ai/nomic-embed-text-v1.5', trust_remote_code=True)
def load_json_prompts(file_path):
    """Read a UTF-8 JSON file and return the decoded object."""
    path = Path(file_path)
    with path.open(encoding='utf-8') as handle:
        return json.load(handle)
def get_embeddings(chunk):
    """Encode *chunk* with the module-level SentenceTransformer and return the vector."""
    return model.encode(chunk)
def search_similar(query_embedding, chunk_df, k=6, max_per_citation=2):
    """Rank chunks by cosine similarity to the query, capping hits per citation.

    Args:
        query_embedding: 1-D numpy vector for the query text.
        chunk_df: DataFrame with 'embeddings', 'citation', 'text_chunk',
            and 'chunk_label' columns.
        k: maximum number of chunks to return.
        max_per_citation: cap on chunks drawn from any single source.

    Returns:
        List of up to ``k`` dicts with 'citation', 'text', and 'chunk_label',
        ordered by descending similarity.
    """
    # Hoisted out of the loop: the query norm is loop-invariant.
    query_norm = np.linalg.norm(query_embedding)
    similarities = []
    for _, row in chunk_df.iterrows():
        emb = row['embeddings']
        sim = np.dot(query_embedding, emb) / (query_norm * np.linalg.norm(emb))
        similarities.append({
            'citation': row['citation'],
            'text': row['text_chunk'],
            'chunk_label': row['chunk_label'],
            'similarity': sim
        })
    similarities.sort(key=lambda x: x['similarity'], reverse=True)

    # Greedily keep the most similar chunks, at most max_per_citation per source.
    citations_count = {}
    filtered_results = []
    for result in similarities:
        citation = result['citation']
        if citations_count.get(citation, 0) < max_per_citation:
            citations_count[citation] = citations_count.get(citation, 0) + 1
            filtered_results.append({
                'citation': result['citation'],
                'text': result['text'],
                'chunk_label': result['chunk_label']
            })
            if len(filtered_results) == k:
                break
    return filtered_results
def naive_search(thesis, context, client):
    """Draft a first-pass review of *thesis* from *context* via the Anthropic API.

    Relies on the module-level ``naive_system_prompt``; *context* (a list of
    chunk dicts) is interpolated directly into the user message.
    """
    user_text = (
        f"Here are the references you should use:<references>\n{context}\n</references>\n\n"
        f"Here is the topic you need to write about:\n<topic>\n{thesis}\n</topic>\n"
        "Return only the text without preface."
    )
    response = client.messages.create(
        model="claude-3-sonnet-20240229",
        max_tokens=3000,
        temperature=0,
        system=naive_system_prompt,
        messages=[{"role": "user", "content": [{"type": "text", "text": user_text}]}],
    )
    return response.content[0].text
def section_drafter(thesis, context, style_notes, client):
    """Draft one literature-review section via the Anthropic API.

    Args:
        thesis: the section prompt/topic.
        context: list of reference chunk dicts, interpolated into the prompt.
        style_notes: optional style guidance; wrapped in a preamble when non-empty.
        client: an initialized ``anthropic.Anthropic`` client.

    Returns:
        The drafted section text (markdown, starting with a subheader).
    """
    if len(style_notes) > 0:
        # Fixed: doubled colon ("::") in the original preamble.
        style_notes = f'''Here's some **important** style guidelines:\n\n{style_notes}'''
    message = client.messages.create(
        model="claude-3-5-sonnet-20241022",
        max_tokens=5000,
        temperature=0,
        system=section_draft_prompt,
        messages=[{
            "role": "user",
            "content": [{
                "type": "text",
                # Prompt typos fixed: "approriate" -> "appropriate",
                # "enganges" -> "engages".
                "text": f"Write a part of this literature review. {style_notes} Here are the references you should use:<references>\n{context}\n</references>\n\nWe are currently working on this section:\n<prompt>\n{thesis}\n</prompt>\nReturn only the <text> prefaced only with an appropriate markdown subheader for the specific section. The text should be comprehensive and detailed, being sure to cite existing work and the work it engages with."
            }]
        }]
    )
    return message.content[0].text
def format_asa_citation(bibtex_string):
    """Convert a BibTeX entry string to an ASA-style citation string.

    Fields that cannot be found render as empty strings, so incomplete
    entries still yield a citation (with gaps) rather than raising.
    """
    def _field(name):
        # Non-greedy: stop at the first closing brace of the field value.
        match = re.search(rf"{name}\s*=\s*{{(.*?)}}", bibtex_string, re.DOTALL)
        return match.group(1).strip() if match else ""

    author = _field("author")
    title = _field("title")
    year = _field("year")
    journal = _field("journal")
    volume = _field("volume")
    number = _field("number")
    pages = _field("pages")
    return f"{author}. {year}. {title}. {journal} {volume}({number}): {pages}."
def extract_cites(context):
    """Format the unique citations in *context* as a markdown bullet list.

    Deduplicates with ``dict.fromkeys`` to preserve first-seen order —
    the original ``list(set(...))`` made the ordering vary between runs.
    """
    cites = [format_asa_citation(item['citation']) for item in context]
    unique_cites = list(dict.fromkeys(cites))
    return '\n'.join(f'* {cite}' for cite in unique_cites)
def generate_literature_review(thesis, style_notes, api_key, progress=gr.Progress()):
    """Two-pass literature-review generator wired to the Gradio UI.

    Pass 1 retrieves chunks similar to the thesis and drafts a naive review;
    pass 2 retrieves chunks similar to that draft, merges the contexts, and
    produces the final section. Yields markdown for the output pane.

    Args:
        thesis: research question / section topic from the UI.
        style_notes: optional writing-style guidance.
        api_key: Anthropic key, or the literal "quote" to fall back to the
            ANTHROPIC_API_KEY environment variable.
        progress: Gradio progress tracker (injected by Gradio).

    Raises:
        ValueError: if "quote" is used but ANTHROPIC_API_KEY is unset.
    """
    # Clear the output pane before starting a new run.
    yield gr.update(value="")
    output = []

    # "quote" is a shared password that uses the server's own key instead.
    if api_key == "quote":
        api_key = os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError("Environment variable ANTHROPIC_API_KEY not found")

    client = anthropic.Anthropic(api_key=api_key)

    # Chunks were already embedded and loaded at module startup (chunk_df).
    progress(0.1, desc="Loading document chunks...")

    # Pass 1: retrieve chunks similar to the thesis and draft a naive review.
    # (Removed an unused `first_cites = extract_cites(context)` computation.)
    progress(0.2, desc="Finding initial references...")
    context = search_similar(get_embeddings(thesis), chunk_df, k=8)
    progress(0.4, desc="Generating first draft...")
    naive_results = naive_search(thesis, context, client)

    # Pass 2: retrieve chunks similar to the draft itself, then merge,
    # skipping any chunk text already present from pass 1.
    progress(0.6, desc="Finding additional references...")
    context2 = search_similar(get_embeddings(naive_results), chunk_df, k=16)
    text_chunks = [c['text'] for c in context]
    combo_context = context + [c for c in context2 if c['text'] not in text_chunks]

    final_cites = extract_cites(combo_context)
    ref_count = len(final_cites.split('\n'))
    progress(0.8, desc=f"Generating final draft from {ref_count} sources...")
    draft = section_drafter(thesis, combo_context, style_notes, client)

    output.append(draft)
    output.append("\n## Sources\n" + final_cites)
    progress(1.0, desc="Complete!")
    yield "\n\n".join(output)
def create_interface():
    """Build and return the Gradio Blocks UI for the review generator."""
    soft_theme = gr.themes.Soft(
        primary_hue="slate",
    )
    with gr.Blocks(theme=soft_theme) as app:
        gr.Markdown("# CiteCraft")
        gr.Markdown(f"Using {chunk_count} pages from {source_count} sources.")
        with gr.Row():
            # Left column: inputs and the trigger button.
            with gr.Column(scale=1):
                thesis_input = gr.Textbox(
                    label="Section Topic",
                    placeholder="Enter your research question or section theme here",
                    lines=4,
                )
                style_notes_input = gr.Textbox(
                    label="Style notes (Optional)",
                    placeholder="Enter any writing style modifications (optional)",
                    lines=4,
                )
                api_key = gr.Textbox(
                    label="Anthropic API Key",
                    placeholder="Enter your Anthropic API key",
                    type="password",
                )
                generate_button = gr.Button("Generate Review", variant="primary")
            # Right column: rendered markdown output.
            with gr.Column(scale=2):
                output_text = gr.Markdown(label="Generated Review")
        generate_button.click(
            generate_literature_review,
            inputs=[thesis_input, style_notes_input, api_key],
            outputs=output_text,
        )
    return app
if __name__ == "__main__":
    # Startup: load prompt templates, then the pre-embedded document chunks.
    try:
        json_prompts = load_json_prompts('prompts.json')
        print("Successfully loaded prompts.json")
        # System prompts and the name of the chunk file to load.
        section_draft_prompt = json_prompts['section_draft_prompt']
        naive_system_prompt = json_prompts['naive_system_prompt']
        protest_file_name = json_prompts['protest_file_name']
        print(f"Will load document chunks from: {protest_file_name}")
        chunk_df = pd.read_json(protest_file_name)
        chunk_count = len(chunk_df)
        # Distinct sources represented in the chunk table.
        source_count = chunk_df['citation'].nunique()
    except Exception as e:
        # Fixed message: the failure can be prompts.json OR the chunk file,
        # so don't blame prompts.json unconditionally.
        print(f"Error during startup (prompts.json or chunk file): {e}")
        raise
    # Launch the app
    app = create_interface()
    app.launch(share=True)  # share=True creates a public URL