# CiteCraft — Hugging Face Space that drafts literature-review sections with
# retrieval-augmented Claude. (Stray "Spaces: / Build error" banner text
# captured from the hosting page removed.)
# Standard library
import json
import os
import re
from pathlib import Path

# Third-party
import anthropic
import gradio as gr
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer

# Embedding model shared by query and document encoding.
# trust_remote_code is required: nomic-embed ships custom modeling code.
model = SentenceTransformer('nomic-ai/nomic-embed-text-v1.5', trust_remote_code=True)
def load_json_prompts(file_path):
    """Parse the JSON file at *file_path* (UTF-8) and return the result."""
    with open(file_path, 'r', encoding='utf-8') as handle:
        return json.load(handle)
def get_embeddings(chunk):
    """Encode *chunk* with the module-level SentenceTransformer.

    Returns the embedding vector produced by ``model.encode``.
    """
    return model.encode(chunk)
def search_similar(query_embedding, chunk_df, k=6, max_per_citation=2):
    """Rank document chunks by cosine similarity to the query embedding.

    Parameters:
        query_embedding: embedding vector for the query text.
        chunk_df: DataFrame with 'embeddings', 'citation', 'text_chunk'
            and 'chunk_label' columns.
        k: maximum number of chunks to return.
        max_per_citation: cap on chunks drawn from any single source, so
            one document cannot dominate the retrieved context.

    Returns:
        list[dict] with 'citation', 'text', 'chunk_label' keys, ordered
        by descending similarity.
    """
    query = np.asarray(query_embedding, dtype=float)
    query_norm = np.linalg.norm(query)  # hoisted: invariant across rows

    scored = []
    for _, row in chunk_df.iterrows():
        emb = np.asarray(row['embeddings'], dtype=float)
        denom = query_norm * np.linalg.norm(emb)
        # Guard against a zero-norm vector (would yield nan under division).
        sim = float(np.dot(query, emb) / denom) if denom else 0.0
        scored.append((sim, row))

    # Stable sort: ties keep original DataFrame order.
    scored.sort(key=lambda pair: pair[0], reverse=True)

    per_citation = {}
    results = []
    for sim, row in scored:
        citation = row['citation']
        if per_citation.get(citation, 0) >= max_per_citation:
            continue  # this source already contributed its quota
        per_citation[citation] = per_citation.get(citation, 0) + 1
        results.append({
            'citation': citation,
            'text': row['text_chunk'],
            'chunk_label': row['chunk_label'],
        })
        if len(results) == k:
            break
    return results
def naive_search(thesis, context, client):
    """Ask Claude for a quick first-pass draft on *thesis* using *context*.

    *context* is interpolated directly into the prompt (its Python repr),
    and the module-level ``naive_system_prompt`` supplies the system prompt.
    Returns the model's text response.
    """
    user_prompt = (
        "Here are the references you should use:<references>\n"
        f"{context}\n</references>\n\n"
        "Here is the topic you need to write about:\n"
        f"<topic>\n{thesis}\n</topic>\n"
        "Return only the text without preface."
    )
    response = client.messages.create(
        model="claude-3-sonnet-20240229",
        max_tokens=3000,
        temperature=0,
        system=naive_system_prompt,
        messages=[{"role": "user", "content": [{"type": "text", "text": user_prompt}]}],
    )
    return response.content[0].text
def section_drafter(thesis, context, style_notes, client):
    """Ask Claude to draft one literature-review section.

    Parameters:
        thesis: the section topic / prompt text.
        context: reference chunks, interpolated directly (Python repr)
            into the prompt.
        style_notes: optional style guidance; wrapped in a preamble when
            non-empty.
        client: an anthropic.Anthropic client instance.

    Uses the module-level ``section_draft_prompt`` as the system prompt.
    Returns the model's text response (markdown with a subheader).
    """
    if style_notes:
        style_notes = f"Here's some **important** style guidelines:\n\n{style_notes}"
    # Prompt typos fixed: "approriate" -> "appropriate",
    # "enganges" -> "engages", "guidelines::" -> "guidelines:".
    message = client.messages.create(
        model="claude-3-5-sonnet-20241022",
        max_tokens=5000,
        temperature=0,
        system=section_draft_prompt,
        messages=[{
            "role": "user",
            "content": [{
                "type": "text",
                "text": (
                    f"Write a part of this literature review. {style_notes} "
                    "Here are the references you should use:<references>\n"
                    f"{context}\n</references>\n\n"
                    "We are currently working on this section:\n"
                    f"<prompt>\n{thesis}\n</prompt>\n"
                    "Return only the <text> prefaced only with an appropriate "
                    "markdown subheader for the specific section. The text "
                    "should be comprehensive and detailed, being sure to cite "
                    "existing work and the work it engages with."
                ),
            }],
        }],
    )
    return message.content[0].text
def _bibtex_field(name, bibtex_string):
    """Return the brace-delimited value of BibTeX field *name*, or ''.

    Non-greedy match: stops at the first '}', so nested braces inside a
    field value are truncated (same behavior as the original extraction).
    """
    match = re.search(rf"{name}\s*=\s*{{(.*?)}}", bibtex_string, re.DOTALL)
    return match.group(1).strip() if match else ""


def format_asa_citation(bibtex_string):
    """Convert a BibTeX entry string to an ASA-style citation string.

    Missing fields render as empty strings in the fixed template:
    "Author. Year. Title. Journal Volume(Number): Pages."
    """
    author = _bibtex_field("author", bibtex_string)
    title = _bibtex_field("title", bibtex_string)
    year = _bibtex_field("year", bibtex_string)
    journal = _bibtex_field("journal", bibtex_string)
    volume = _bibtex_field("volume", bibtex_string)
    number = _bibtex_field("number", bibtex_string)
    pages = _bibtex_field("pages", bibtex_string)
    return f"{author}. {year}. {title}. {journal} {volume}({number}): {pages}."
def extract_cites(context):
    """Format the unique citations in *context* as a markdown bullet list.

    Deduplicates while preserving first-appearance (relevance) order via
    dict.fromkeys. The previous set()-based version produced a
    nondeterministic Sources ordering across runs (string hash randomization).
    """
    unique_cites = dict.fromkeys(
        format_asa_citation(item['citation']) for item in context
    )
    return '\n'.join(f'* {cite}' for cite in unique_cites)
def generate_literature_review(thesis, style_notes, api_key, progress=gr.Progress()):
    """Two-pass RAG pipeline: retrieve, draft, re-retrieve, redraft.

    Generator so Gradio can clear the output box first, then stream the
    final markdown. Reads the module-level ``chunk_df`` and calls the
    retrieval/drafting helpers defined in this file.

    Parameters:
        thesis: section topic / research question.
        style_notes: optional style guidance passed to the drafter.
        api_key: Anthropic API key, or the literal "quote" to use the
            server-side key from the environment.
        progress: Gradio progress tracker (injected by Gradio).
    """
    yield gr.update(value="")  # clear any previous output immediately
    output = []

    # "quote" is a deliberate password-style backdoor that swaps in the
    # server-side key. NOTE(review): consider removing in public deployments.
    if api_key == "quote":
        api_key = os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError("Environment variable ANTHROPIC_API_KEY not found")

    client = anthropic.Anthropic(api_key=api_key)

    progress(0.1, desc="Loading document chunks...")

    # Pass 1: retrieve on the raw topic, then produce a quick naive draft.
    progress(0.2, desc="Finding initial references...")
    context = search_similar(get_embeddings(thesis), chunk_df, k=8)
    progress(0.4, desc="Generating first draft...")
    naive_results = naive_search(thesis, context, client)

    # Pass 2: retrieve again using the draft itself as the query, then merge
    # (deduplicated by chunk text) so the final draft sees a wider context.
    progress(0.6, desc="Finding additional references...")
    context2 = search_similar(get_embeddings(naive_results), chunk_df, k=16)
    seen_chunks = {c['text'] for c in context}
    combo_context = context + [c for c in context2 if c['text'] not in seen_chunks]
    final_cites = extract_cites(combo_context)
    # ''.split('\n') yields [''], which would report one phantom source.
    ref_count = len(final_cites.split('\n')) if final_cites else 0

    progress(0.8, desc=f"Generating final draft from {ref_count} sources...")
    draft = section_drafter(thesis, combo_context, style_notes, client)
    output.append(draft)
    output.append("\n## Sources\n" + final_cites)
    progress(1.0, desc="Complete!")
    yield "\n\n".join(output)
def create_interface():
    """Assemble and return the Gradio Blocks UI for CiteCraft.

    Reads the module-level ``chunk_count`` and ``source_count`` for the
    intro line and wires the button to ``generate_literature_review``.
    """
    soft_theme = gr.themes.Soft(primary_hue="slate")
    with gr.Blocks(theme=soft_theme) as app:
        gr.Markdown("# CiteCraft")
        gr.Markdown(f"Using {chunk_count} pages from {source_count} sources.")

        with gr.Row():
            # Left column: inputs plus the trigger button.
            with gr.Column(scale=1):
                topic_box = gr.Textbox(
                    label="Section Topic",
                    placeholder="Enter your research question or section theme here",
                    lines=4,
                )
                style_box = gr.Textbox(
                    label="Style notes (Optional)",
                    placeholder="Enter any writing style modifications (optional)",
                    lines=4,
                )
                key_box = gr.Textbox(
                    label="Anthropic API Key",
                    placeholder="Enter your Anthropic API key",
                    type="password",
                )
                run_button = gr.Button("Generate Review", variant="primary")
            # Right column: the generated markdown.
            with gr.Column(scale=2):
                review_output = gr.Markdown(label="Generated Review")

        run_button.click(
            generate_literature_review,
            inputs=[topic_box, style_box, key_box],
            outputs=review_output,
        )
    return app
if __name__ == "__main__":
    # Startup: load prompt templates and the pre-embedded document chunks.
    # These names are read as module-level globals by the functions above.
    try:
        json_prompts = load_json_prompts('prompts.json')
        print("Successfully loaded prompts.json")

        section_draft_prompt = json_prompts['section_draft_prompt']
        naive_system_prompt = json_prompts['naive_system_prompt']
        protest_file_name = json_prompts['protest_file_name']

        print(f"Will load document chunks from: {protest_file_name}")
        chunk_df = pd.read_json(protest_file_name)
        chunk_count = len(chunk_df)
        source_count = chunk_df['citation'].nunique()
    except Exception as e:
        # The old message blamed prompts.json even when the chunk file failed.
        print(f"Error loading startup data (prompts.json or chunk file): {e}")
        raise

    # Launch the app; share=True creates a public URL.
    app = create_interface()
    app.launch(share=True)