# rags_api / generate_testset.py
# (Hugging Face page residue retained as comments: uploaded by Skier8402,
# "Upload generate_testset.py", commit aaf767a, verified)
"""
Testset Generation module for the RAG system using Ragas.
This script generates question-answer pairs from documents to be used in evaluation.
how to run:
python generate_testset.py <pdf_path1> <pdf_path2> ...
"""
# pylint: disable=import-error,no-name-in-module,invalid-name,broad-except,missing-function-docstring,missing-class-docstring,too-many-return-statements,ungrouped-imports,line-too-long,logging-fstring-interpolation,duplicate-code,too-few-public-methods
import os
import sys
import logging
from typing import List, Any
from PyPDF2 import PdfReader
try:
# Newer langchain versions expose Document in langchain.schema
from langchain.schema import Document
except Exception:
try:
# Older versions used langchain.docstore.document
from langchain.docstore.document import Document
except Exception:
# Minimal fallback Document for environments without langchain
from dataclasses import dataclass
@dataclass
class Document:
page_content: str
metadata: dict | None = None
from ragas.testset.synthesizers.generate import TestsetGenerator
try:
from langchain.chat_models import ChatOpenAI
except Exception:
from langchain_openai import ChatOpenAI
try:
from langchain_huggingface import HuggingFaceEmbeddings
except Exception:
from langchain_community.embeddings import HuggingFaceEmbeddings
# Chat message classes for role-based prompting; bare dataclass
# stand-ins are used when langchain.schema is unavailable.
try:
    from langchain.schema import SystemMessage, HumanMessage
except Exception:
    # Minimal stand-ins if langchain.schema isn't available.
    from dataclasses import dataclass

    @dataclass
    class SystemMessage:
        # Plain-text content of the system prompt.
        content: str

    @dataclass
    class HumanMessage:
        # Plain-text content of the user prompt.
        content: str
def _extract_chat_response(resp) -> str:
"""Robust extraction of text from various ChatOpenAI response shapes."""
try:
# langchain newer: AIMessage with .content
if hasattr(resp, "content"):
return resp.content
# langchain older/other: ChatResult with .generations
if hasattr(resp, "generations"):
gens = resp.generations
# gens may be list[list[Generation]] or list[Generation]
try:
return gens[0][0].text
except Exception:
try:
return gens[0].text
except Exception:
pass
# fallback dict/list shapes
if isinstance(resp, list) and resp:
first = resp[0]
if hasattr(first, "content"):
return first.content
if isinstance(first, dict) and "content" in first:
return first["content"]
if isinstance(resp, dict):
for k in ("content", "text"):
if k in resp:
return resp[k]
except Exception:
pass
return str(resp)
def summarize_documents(docs, llm, max_summary_chars: int = 2000) -> List[Document]:
    """Summarize each Document using the provided LLM into shorter Documents.

    This is optional and controlled by the `USE_CHUNK_SUMMARIZATION` env var.
    Empty chunks are skipped; on any LLM failure the original text is
    truncated to ``max_summary_chars`` characters as a fallback.
    """
    condensed: List[Document] = []
    for idx, doc in enumerate(docs):
        body = (doc.page_content or "").strip()
        if not body:
            continue
        # Construct a concise summarization prompt.
        prompt = (
            f"Summarize the following text into a concise summary (preserve key facts, numbers, and named entities). "
            f"Aim for no more than {max_summary_chars} characters. Return only the summary, no commentary.\n\nText:\n"
            + body
        )
        try:
            # Preferred path: chat-style invocation with role messages.
            chat = [
                SystemMessage(content="You are a concise summarizer."),
                HumanMessage(content=prompt),
            ]
            summary = _extract_chat_response(llm(chat))
        except Exception:
            try:
                # Some LLM wrappers accept a bare string prompt instead.
                summary = _extract_chat_response(llm(prompt))
            except Exception as exc:
                logging.debug(f"Summarization failed for chunk {idx}: {exc}")
                # Fallback: truncate the original text.
                summary = body[:max_summary_chars]
        summary = (summary or "").strip()
        if not summary:
            # Guard against an empty LLM reply.
            summary = body[:max_summary_chars]
        meta = dict(doc.metadata) if getattr(doc, "metadata", None) else {}
        meta.update({"chunk": idx})
        condensed.append(Document(page_content=summary, metadata=meta))
    return condensed
# Text splitting to avoid sending huge prompts to the LLM.
try:
    from langchain.text_splitter import RecursiveCharacterTextSplitter
except Exception:
    # Minimal fallback splitter if langchain isn't available.
    class RecursiveCharacterTextSplitter:
        """Fixed-size character windows with overlap, mimicking the
        langchain splitter's interface (constructor + split_documents)."""

        def __init__(self, chunk_size: int = 8000, chunk_overlap: int = 500):
            # Window length in characters.
            self.chunk_size = chunk_size
            # Characters shared between consecutive windows.
            self.chunk_overlap = chunk_overlap

        def split_documents(self, docs):
            """Split each Document into overlapping character chunks.

            Fix: each chunk receives its OWN copy of the source metadata.
            The previous version aliased the same dict across all chunks,
            so a later per-chunk mutation (e.g. summarize_documents adding
            a "chunk" key) would silently leak into sibling chunks.
            """
            out = []
            # Stride between window starts; clamp to >= 1 so a degenerate
            # overlap >= chunk_size cannot loop forever.
            step = max(1, self.chunk_size - self.chunk_overlap)
            for doc in docs:
                text = doc.page_content or ""
                for start in range(0, len(text), step):
                    # Copy metadata per chunk; preserve None as-is.
                    meta = dict(doc.metadata) if doc.metadata else doc.metadata
                    out.append(
                        Document(
                            page_content=text[start : start + self.chunk_size],
                            metadata=meta,
                        )
                    )
            return out
def get_documents_from_pdfs(pdf_paths: List[str]) -> List[Document]:
    """
    Load PDFs and convert them to LangChain Document objects.

    Files that cannot be read are logged and skipped rather than
    aborting the whole batch.

    Parameters
    ----------
    pdf_paths : List[str]
        List of paths to PDF files.

    Returns
    -------
    List[Document]
        One Document per successfully-read PDF, with the file's
        basename recorded under metadata["source"].
    """
    documents = []
    for pdf_path in pdf_paths:
        try:
            reader = PdfReader(pdf_path)
            # Concatenate the extractable text of every page; pages with
            # no extractable text are skipped.
            full_text = "".join(
                extracted
                for extracted in (page.extract_text() for page in reader.pages)
                if extracted
            )
            documents.append(
                Document(
                    page_content=full_text,
                    metadata={"source": os.path.basename(pdf_path)},
                )
            )
        except Exception as e:
            logging.error(f"Error reading {pdf_path}: {e}")
    return documents
def generate_testset(
    pdf_paths: List[str], test_size: int = 10, output_path: str = "testset.csv"
) -> Any:
    """
    Generate a test set from the given PDFs.

    Parameters
    ----------
    pdf_paths : List[str]
        List of paths to PDF files.
    test_size : int, optional
        Number of QA pairs to generate.
    output_path : str, optional
        Path to save the generated test set (CSV).

    Returns
    -------
    Any
        The generated Ragas testset, or None if no documents could be
        loaded from the given paths.
    """
    documents = get_documents_from_pdfs(pdf_paths)
    if not documents:
        logging.error("No documents found to generate testset from.")
        return None

    # Configure LLM and Embeddings consistent with the app.
    # API keys and base URL come from the environment (standard OPENAI_*
    # variables, or OPENROUTER_API_KEY for OpenRouter routing).
    # Allow overriding the LLM model via env var.
    model_name = os.getenv("TESTSET_LLM_MODEL", "openai/gpt-4o-mini")
    logging.info(f"Using LLM model: {model_name}")

    # Prefer OpenRouter when available so generated LLM clients use it by default.
    _openrouter_key = os.getenv("OPENROUTER_API_KEY")
    if _openrouter_key:
        # Fix: OpenRouter's OpenAI-compatible endpoint is
        # https://openrouter.ai/api/v1 — the previous value
        # "https://api.openrouter.ai/v1" is not the documented base URL,
        # so every OpenRouter-routed request failed.
        os.environ["OPENAI_API_BASE"] = "https://openrouter.ai/api/v1"
        os.environ["OPENAI_API_KEY"] = _openrouter_key
        logging.info(
            "OpenRouter detected; routing OpenAI calls via %s",
            os.environ["OPENAI_API_BASE"],
        )
    logging.info(
        "OPENAI_API_KEY loaded=%s",
        bool(os.environ.get("OPENAI_API_KEY")),
    )

    # Create LLM clients (will read credentials from environment).
    generator_llm = ChatOpenAI(model=model_name)
    # Note: critic_llm would be used for test evaluation if needed in future
    # critic_llm = ChatOpenAI(model=model_name)
    embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")

    # Initialize generator (provide the generator LLM and the embeddings).
    generator = TestsetGenerator.from_langchain(generator_llm, embeddings)

    # Split large documents into smaller chunks to avoid exceeding model
    # context limits.
    splitter = RecursiveCharacterTextSplitter(chunk_size=8000, chunk_overlap=500)
    split_docs = splitter.split_documents(documents)

    # Generate testset (use default query distribution).
    logging.info(
        f"Generating testset of size {test_size} from {len(split_docs)} chunks..."
    )
    testset = generator.generate_with_langchain_docs(split_docs, testset_size=test_size)

    # Export to CSV.
    test_df = testset.to_pandas()
    test_df.to_csv(output_path, index=False)
    logging.info(f"Testset saved to {output_path}")
    return testset
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
if len(sys.argv) < 2:
print("Usage: python generate_testset.py <pdf_path1> <pdf_path2> ...")
else:
pdf_files = sys.argv[1:]
generate_testset(pdf_files)