"""
Testset Generation module for the RAG system using Ragas.
This script generates question-answer pairs from documents to be used in evaluation.
how to run:
python generate_testset.py <pdf_path1> <pdf_path2> ...
"""
# pylint: disable=import-error,no-name-in-module,invalid-name,broad-except,missing-function-docstring,missing-class-docstring,too-many-return-statements,ungrouped-imports,line-too-long,logging-fstring-interpolation,duplicate-code,too-few-public-methods
import os
import sys
import logging
from typing import List, Any

from PyPDF2 import PdfReader

try:
    # Newer langchain versions expose Document in langchain.schema
    from langchain.schema import Document
except Exception:
    try:
        # Older versions used langchain.docstore.document
        from langchain.docstore.document import Document
    except Exception:
        # Minimal fallback Document for environments without langchain
        from dataclasses import dataclass

        @dataclass
        class Document:
            # Raw text content of the document/chunk.
            page_content: str
            # Optional provenance info (e.g. {"source": filename}).
            metadata: dict | None = None

from ragas.testset.synthesizers.generate import TestsetGenerator

# ChatOpenAI moved out of langchain core into langchain_openai in newer releases.
try:
    from langchain.chat_models import ChatOpenAI
except Exception:
    from langchain_openai import ChatOpenAI

# HuggingFace embeddings likewise moved between packages across versions.
try:
    from langchain_huggingface import HuggingFaceEmbeddings
except Exception:
    from langchain_community.embeddings import HuggingFaceEmbeddings

try:
    from langchain.schema import SystemMessage, HumanMessage
except Exception:
    # Minimal stand-ins if langchain.schema isn't available
    from dataclasses import dataclass

    @dataclass
    class SystemMessage:
        content: str

    @dataclass
    class HumanMessage:
        content: str
def _extract_chat_response(resp) -> str:
"""Robust extraction of text from various ChatOpenAI response shapes."""
try:
# langchain newer: AIMessage with .content
if hasattr(resp, "content"):
return resp.content
# langchain older/other: ChatResult with .generations
if hasattr(resp, "generations"):
gens = resp.generations
# gens may be list[list[Generation]] or list[Generation]
try:
return gens[0][0].text
except Exception:
try:
return gens[0].text
except Exception:
pass
# fallback dict/list shapes
if isinstance(resp, list) and resp:
first = resp[0]
if hasattr(first, "content"):
return first.content
if isinstance(first, dict) and "content" in first:
return first["content"]
if isinstance(resp, dict):
for k in ("content", "text"):
if k in resp:
return resp[k]
except Exception:
pass
return str(resp)
def summarize_documents(docs, llm, max_summary_chars: int = 2000) -> List[Document]:
    """Summarize each Document using the provided LLM into shorter Documents.
    This is optional and controlled by the `USE_CHUNK_SUMMARIZATION` env var.
    """
    results: List[Document] = []
    for i, src in enumerate(docs):
        body = (src.page_content or "").strip()
        if not body:
            # Skip blank chunks entirely; nothing to summarize.
            continue

        # Concise summarization instruction followed by the chunk text.
        prompt = (
            f"Summarize the following text into a concise summary (preserve key facts, numbers, and named entities). "
            f"Aim for no more than {max_summary_chars} characters. Return only the summary, no commentary.\n\nText:\n"
            + body
        )

        # Preferred path: chat-style invocation with system + human messages.
        try:
            reply = llm(
                [
                    SystemMessage(content="You are a concise summarizer."),
                    HumanMessage(content=prompt),
                ]
            )
            summary = _extract_chat_response(reply)
        except Exception:
            # Some LLM wrappers only accept a bare string prompt.
            try:
                reply = llm(prompt)
                summary = _extract_chat_response(reply)
            except Exception as e:
                logging.debug(f"Summarization failed for chunk {i}: {e}")
                # Fallback: truncate
                summary = body[:max_summary_chars]

        summary = (summary or "").strip()
        if not summary:
            # Guard against an empty LLM reply: truncate the source instead.
            summary = body[:max_summary_chars]

        # Carry over source metadata (if any) and record the chunk index.
        meta = dict(src.metadata) if getattr(src, "metadata", None) else {}
        meta.update({"chunk": i})
        results.append(Document(page_content=summary, metadata=meta))
    return results
# Text splitting to avoid sending huge prompts to the LLM
try:
    from langchain.text_splitter import RecursiveCharacterTextSplitter
except Exception:
    # Minimal fallback splitter if langchain isn't available
    class RecursiveCharacterTextSplitter:
        """Naive fixed-window splitter mimicking langchain's interface."""

        def __init__(self, chunk_size: int = 8000, chunk_overlap: int = 500):
            self.chunk_size = chunk_size
            self.chunk_overlap = chunk_overlap

        def split_documents(self, docs):
            """Split each doc into overlapping windows of chunk_size chars."""
            # Windows advance by chunk_size - chunk_overlap (at least 1).
            stride = max(1, self.chunk_size - self.chunk_overlap)
            chunks = []
            for src in docs:
                body = src.page_content or ""
                chunks.extend(
                    Document(
                        page_content=body[pos : pos + self.chunk_size],
                        metadata=src.metadata,
                    )
                    for pos in range(0, len(body), stride)
                )
            return chunks
def get_documents_from_pdfs(pdf_paths: List[str]) -> List[Document]:
    """
    Load PDFs and convert them to LangChain Document objects.

    Pages are joined with a newline so text at page boundaries does not run
    together. Files that cannot be read are logged and skipped rather than
    aborting the whole batch.

    Parameters
    ----------
    pdf_paths : List[str]
        List of paths to PDF files.

    Returns
    -------
    List[Document]
        One Document per successfully read PDF, with metadata
        {"source": <file basename>}.
    """
    documents = []
    for path in pdf_paths:
        try:
            reader = PdfReader(path)
            # extract_text() may return None (e.g. image-only pages); drop those.
            pages = (page.extract_text() for page in reader.pages)
            # Join with a newline so words don't fuse across page breaks;
            # join is also linear, unlike repeated string +=.
            text = "\n".join(p for p in pages if p)
            source = os.path.basename(path)
            documents.append(Document(page_content=text, metadata={"source": source}))
        except Exception as e:
            # Best-effort loading: report the bad file and continue.
            logging.error(f"Error reading {path}: {e}")
    return documents
def generate_testset(
    pdf_paths: List[str], test_size: int = 10, output_path: str = "testset.csv"
) -> Any:
    """
    Generate a test set from the given PDFs.

    Parameters
    ----------
    pdf_paths : List[str]
        List of paths to PDF files.
    test_size : int, optional
        Number of QA pairs to generate.
    output_path : str, optional
        Path to save the generated test set (CSV).

    Returns
    -------
    Any
        The generated ragas testset, or None if no documents could be loaded.
    """
    documents = get_documents_from_pdfs(pdf_paths)
    if not documents:
        logging.error("No documents found to generate testset from.")
        return None

    # Configure LLM and Embeddings consistent with the app.
    # Use environment variables for API keys and Base URL (e.g. standard OPENAI_*).
    # Allow overriding the LLM model via env var.
    model_name = os.getenv("TESTSET_LLM_MODEL", "openai/gpt-4o-mini")
    logging.info(f"Using LLM model: {model_name}")

    # Prefer OpenRouter when available so generated LLM clients use it by default.
    _openrouter_key = os.getenv("OPENROUTER_API_KEY")
    if _openrouter_key:
        # FIX: OpenRouter's OpenAI-compatible endpoint is openrouter.ai/api/v1;
        # the previous "https://api.openrouter.ai/v1" host is not valid.
        os.environ["OPENAI_API_BASE"] = "https://openrouter.ai/api/v1"
        os.environ["OPENAI_API_KEY"] = _openrouter_key
        logging.info(
            "OpenRouter detected; routing OpenAI calls via %s",
            os.environ["OPENAI_API_BASE"],
        )
    logging.info(
        "OPENAI_API_KEY loaded=%s",
        bool(os.environ.get("OPENAI_API_KEY")),
    )

    # Create LLM clients (will read credentials from environment).
    generator_llm = ChatOpenAI(model=model_name)
    # Note: critic_llm would be used for test evaluation if needed in future
    # critic_llm = ChatOpenAI(model=model_name)
    embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")

    # Initialize generator (provide the generator LLM and the embeddings).
    generator = TestsetGenerator.from_langchain(generator_llm, embeddings)

    # Split large documents into smaller chunks to avoid exceeding model context limits.
    splitter = RecursiveCharacterTextSplitter(chunk_size=8000, chunk_overlap=500)
    split_docs = splitter.split_documents(documents)

    # Generate testset (use default query distribution).
    logging.info(
        f"Generating testset of size {test_size} from {len(split_docs)} chunks..."
    )
    testset = generator.generate_with_langchain_docs(split_docs, testset_size=test_size)

    # Export to CSV so evaluation runs can load it later.
    test_df = testset.to_pandas()
    test_df.to_csv(output_path, index=False)
    logging.info(f"Testset saved to {output_path}")
    return testset
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    # Everything after the script name is treated as a PDF path.
    pdf_args = sys.argv[1:]
    if not pdf_args:
        print("Usage: python generate_testset.py <pdf_path1> <pdf_path2> ...")
    else:
        generate_testset(pdf_args)