# Research-PDFs-VQA / backend.py
# (Hugging Face Space file header: uploaded by sugiv, commit 452c0e2,
#  "Adding initial set of files" — converted to comments so the module parses.)
import requests
import numpy as np
from openai import OpenAI
import json
from pinecone import Pinecone
import os
import requests
import numpy as np
from PIL import Image as PILImage
from docarray import BaseDoc
from docarray import DocList
from docarray.typing import ImageTensor, NdArray
from typing import List, Dict, Optional
import requests
import base64
from pdf_classes import RichPDFDocument
import io
# Access environment variables individually and pass them as separate arguments
# Module-level Pinecone client, created at import time. Raises KeyError if
# PINECONE_API_KEY or PINECONE_ENVIRONMENT is not set in the environment.
pc = Pinecone(
api_key=os.environ["PINECONE_API_KEY"],
environment=os.environ["PINECONE_ENVIRONMENT"]
)
print("Connected to Pinecone")
# Handle to the pre-built index of late-chunked PDF embeddings; shared by
# get_answer() further down in this module.
index = pc.Index("rich-pdf-late-chunks")
def create_question_embedding(question: str, api_token: str, timeout: float = 30.0) -> np.ndarray:
    """Embed a question with the Jina AI ``jina-clip-v1`` model.

    Args:
        question: Natural-language question to embed.
        api_token: Jina AI API bearer token.
        timeout: Seconds to wait for the HTTP response. The original call had
            no timeout, so a stalled connection would hang the caller forever.

    Returns:
        1-D numpy array containing the embedding vector.

    Raises:
        Exception: If the API returns a non-200 status (message includes the
            response body).
        requests.RequestException: On network failure or timeout.
    """
    url = 'https://api.jina.ai/v1/embeddings'
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {api_token}'
    }
    data = {
        "model": "jina-clip-v1",
        "input": [{"text": question}]
    }
    response = requests.post(url, headers=headers, json=data, timeout=timeout)
    if response.status_code == 200:
        result = response.json()
        return np.array(result['data'][0]['embedding'])
    raise Exception(f"Error creating embedding: {response.text}")
from openai import OpenAI
def create_few_shot_prompt(question: str, rich_pdf: RichPDFDocument, pinecone_index, api_token: str, top_k: int = 3):
    """Build a few-shot VQA prompt plus image attachments for GPT-4o.

    Embeds the question, retrieves the ``top_k`` best-matching chunks from
    Pinecone, renders them as numbered text excerpts, and attaches at most
    one full-page screenshot (base64 PNG) for visual context.

    Returns:
        (prompt, image_data): the assembled prompt string and a list of
        OpenAI ``image_url`` content dicts (zero or one entries).
    """
    parts = [f"Question: {question}\n\n",
             "Here are relevant excerpts from the document:\n\n"]
    attachments = []
    screenshot_pages = set()

    q_vector = create_question_embedding(question, api_token)
    hits = pinecone_index.query(vector=q_vector.tolist(), top_k=top_k, include_metadata=True)

    for rank, hit in enumerate(hits['matches'], 1):
        meta = hit['metadata']
        types = meta['segment_types'].split(',')
        pages = [int(p) for p in meta['page_numbers'].split(',')]
        # Metadata may carry either a JSON-encoded list of contents or a
        # bare string; fall back to wrapping the raw value on decode failure.
        try:
            chunk_contents = json.loads(meta['contents'])
        except json.JSONDecodeError:
            chunk_contents = [meta['contents']]

        parts.append(f"Excerpt {rank}:\n")
        parts.append(f"Pages: {', '.join(map(str, pages))}\n")
        parts.append(f"Types: {', '.join(types)}\n")
        for idx, piece in enumerate(chunk_contents, 1):
            if isinstance(piece, str) and '[Image' in piece:
                parts.append(f"Image content {idx}: {piece}\n")
            else:
                # Cap text excerpts at 200 characters to keep the prompt small.
                parts.append(f"Text content {idx}: {piece[:200]}...\n")
        parts.append("\n")

        # Attach exactly one full-page screenshot: the first page of the
        # first hit that lists any pages.
        if not screenshot_pages and pages:
            page_no = pages[0]
            parts.append(f"\nFull-page context for Page {page_no + 1}: [Full-page screenshot]\n")
            png_buf = io.BytesIO()
            PILImage.fromarray(rich_pdf.pages[page_no].screenshot).save(png_buf, format="PNG")
            encoded = base64.b64encode(png_buf.getvalue()).decode('utf-8')
            attachments.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/png;base64,{encoded}",
                    "detail": "low"
                }
            })
            screenshot_pages.add(page_no)

    parts.append("\nInstructions for answering the question:\n")
    parts.append("1. Carefully review all provided excerpts.\n")
    parts.append("2. Use the full-page screenshot to understand the overall context.\n")
    parts.append("3. Refer to specific excerpts in your answer when applicable.\n")
    parts.append("4. If the question asks for specific information, provide a clear and concise answer.\n")
    parts.append("5. If the answer isn't directly stated, use the context to infer the most likely answer.\n\n")
    parts.append(f"Now, please answer the following question based on the provided information:\n{question}\n")

    return ''.join(parts), attachments
def query_gpt4o(question: str, rich_pdf, pinecone_index, api_token: str, gpt4_api_key: str):
    """Answer a question about ``rich_pdf`` using GPT-4o with retrieved context.

    Builds a few-shot prompt (text plus at most one page screenshot) via
    create_few_shot_prompt and sends it to the OpenAI Chat Completions API.

    Returns:
        The model's answer string, or None if the API call fails (the error
        is printed, not raised — best-effort behavior).
    """
    client = OpenAI(api_key=gpt4_api_key)
    prompt, image_data = create_few_shot_prompt(question, rich_pdf, pinecone_index, api_token)

    # The user turn carries the text prompt followed by the image payload(s).
    user_content = [{"type": "text", "text": prompt}, *image_data]
    system_message = "You are an advanced AI assistant capable of analyzing various types of documents, including but not limited to research papers, financial reports, and general texts. Your task is to provide accurate and relevant answers to questions by carefully examining both textual and visual information provided from the document. When appropriate, cite specific excerpts, images, or page numbers in your responses. Explain your reasoning clearly, especially when making inferences or connections between different parts of the document."

    try:
        completion = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_content},
            ],
            max_tokens=500,
        )
        # Kept inside the try so a malformed response also yields None,
        # matching the original behavior.
        return completion.choices[0].message.content
    except Exception as exc:
        print(f"Failed to execute GPT-4V query: {exc}")
        return None
import re
def format_answer(answer):
    """Normalize a model answer into Markdown/MathJax-friendly text.

    - Converts LaTeX ``$$...$$`` to ``\\[...\\]`` and ``$...$`` to ``\\(...\\)``.
    - Puts blank lines around ``###`` headers and full-line bold text.
    - Puts newlines around display-math blocks.

    Args:
        answer: Raw answer text from the model.

    Returns:
        The reformatted answer string.
    """
    # Convert LaTeX-style math delimiters to MathJax-style ones.
    answer = re.sub(r'\$\$(.*?)\$\$', r'\\[\1\\]', answer)
    answer = re.sub(r'\$(.*?)\$', r'\\(\1\\)', answer)

    # Give headers and full-line bold text breathing room.
    formatted_lines = []
    for line in answer.split('\n'):
        if line.startswith('###') or (line.startswith('**') and line.endswith('**')):
            formatted_lines.append(f"\n{line}\n")
        else:
            formatted_lines.append(line)
    formatted_answer = '\n'.join(formatted_lines)

    # Add spacing around display-math blocks. Bug fix: the original pattern
    # r'(\\\\.*?\\\\)' required TWO literal backslashes, so it never matched
    # the single-backslash \[...\] delimiters produced above and the
    # substitution was dead code.
    formatted_answer = re.sub(r'(\\\[.*?\\\])', r'\n\1\n', formatted_answer)
    return formatted_answer
# Example usage function
def get_answer(question: str, enriched_pdf:RichPDFDocument, jina_api_token: str, gpt4_api_key: str):
answer_generated = query_gpt4o(question, enriched_pdf, index, jina_api_token, gpt4_api_key)
#print(answer_generated)
if answer_generated:
return format_answer(answer_generated)
else:
return None