kora-synth / response_generator.py
LeonceNsh's picture
Upload folder using huggingface_hub
029af47 verified
import os
from typing import Any, Dict, Iterable, List, Optional

import openai
import pandas as pd

from similarity_search import load_index, load_embedding_model, find_similar_chunks, beautify_text
try:
from dotenv import load_dotenv
load_dotenv()
except ImportError:
pass
class SyntheticDataGenerator:
    """
    Generate synthetic data from the dataset description found in a research paper.

    Text chunks relevant to the paper's data description are retrieved from a
    FAISS index, packed into a prompt, and sent to an LLM through the
    OpenRouter endpoint using the OpenAI client.
    """

    # Separator placed between retrieved chunks in the assembled context.
    _CHUNK_SEPARATOR = "\n\n"

    def __init__(self,
                 openai_api_key: Optional[str] = None,
                 model_name: str = "google/gemini-2.0-flash-exp:free",
                 embedding_model_name: str = "all-MiniLM-L6-v2",
                 index_path: str = "faiss_index.index",
                 text_chunks_path: str = "text_chunks.pkl",
                 max_context_length: int = 8000,
                 metadata_context: Optional[str] = None):
        """
        Initializes the SyntheticDataGenerator.

        Args:
            openai_api_key: OpenRouter API key. When empty or None, the
                OPENROUTER_API_KEY environment variable is used, then
                OPENAI_API_KEY as a compatibility fallback.
            model_name: The LLM to use for generation.
            embedding_model_name: The model for creating text embeddings.
            index_path: Path to the FAISS index of the paper's text.
            text_chunks_path: Path to the pickled text chunks from the paper.
            max_context_length: Maximum length, in characters, of the context
                returned by get_relevant_context.
            metadata_context: Optional metadata context to guide data generation.

        Raises:
            ValueError: If no API key can be resolved, or if the FAISS index
                or the text chunks fail to load.
        """
        # First non-empty source wins: explicit argument, OpenRouter variable,
        # then the OpenAI variable kept for compatibility.
        self.api_key = (
            openai_api_key
            or os.getenv("OPENROUTER_API_KEY")
            or os.getenv("OPENAI_API_KEY")
        )
        if not self.api_key:
            raise ValueError("OpenRouter API key must be provided either as parameter or OPENROUTER_API_KEY environment variable")

        # The OpenAI client is pointed at the OpenRouter API endpoint.
        self.client = openai.OpenAI(
            api_key=self.api_key,
            base_url="https://openrouter.ai/api/v1"
        )
        self.model_name = model_name
        self.max_context_length = max_context_length
        self.metadata_context = metadata_context

        # Load embedding model and FAISS index
        print("Loading embedding model and FAISS index...")
        self.embedding_model = load_embedding_model(embedding_model_name)
        self.index, self.text_chunks = load_index(index_path, text_chunks_path)
        if self.index is None or self.text_chunks is None:
            raise ValueError("Failed to load FAISS index or text chunks.")
        print(f"Successfully loaded index with {len(self.text_chunks)} text chunks.")

    @staticmethod
    def _format_chunk(position: int, chunk: Any) -> str:
        """Render one retrieved chunk as a labelled context entry."""
        chunk_text = beautify_text(chunk)
        # The chunk is accessed like a mapping; a missing "page" key is shown as "N/A".
        page_number = chunk.get("page", "N/A")
        return f"[Chunk {position}, Page {page_number}]: {chunk_text}"

    @staticmethod
    def _take_within_budget(parts: Iterable[str],
                            limit: int,
                            separator: str = "\n\n") -> List[str]:
        """
        Take leading parts while their joined length stays within ``limit``.

        The cost of each part includes the separator that will precede it in
        the joined string, so ``len(separator.join(result)) <= limit`` always
        holds. Selection stops at the first part that does not fit; later
        parts are not consumed from ``parts``.
        """
        selected: List[str] = []
        used = 0
        for part in parts:
            cost = len(part) + (len(separator) if selected else 0)
            if used + cost > limit:
                break
            selected.append(part)
            used += cost
        return selected

    def _error_result(self, message: str, context_string: str) -> Dict[str, Any]:
        """Build the result dictionary returned when generation fails."""
        return {
            "response": f"Error generating response: {message}",
            "context_used": context_string,
            "model_used": self.model_name,
            "error": message,
        }

    def get_relevant_context(self, query: str, k: int = 10) -> str:
        """
        Retrieves relevant text chunks from the paper to form a context.

        Args:
            query: The query to find relevant text for.
            k: The number of chunks to retrieve.

        Returns:
            The formatted chunks joined by blank lines. The total length,
            separators included, never exceeds ``self.max_context_length``.
        """
        relevant_chunks = find_similar_chunks(
            query, self.embedding_model, self.index, self.text_chunks, k=k
        )
        # A generator keeps formatting lazy: chunks past the budget cut-off
        # are never formatted.
        formatted_chunks = (
            self._format_chunk(position, chunk)
            for position, chunk in enumerate(relevant_chunks, start=1)
        )
        selected = self._take_within_budget(
            formatted_chunks, self.max_context_length, self._CHUNK_SEPARATOR
        )
        return self._CHUNK_SEPARATOR.join(selected)

    def generate_synthetic_data(self,
                                k: int = 10,
                                temperature: float = 0.7,
                                max_tokens: int = 4000) -> Dict[str, Any]:
        """
        Generates synthetic data by prompting the LLM with context from the paper.

        Args:
            k: The number of context chunks to use.
            temperature: The temperature for the LLM generation.
            max_tokens: The maximum number of tokens for the LLM response.

        Returns:
            A dictionary with the keys "response", "context_used" and
            "model_used". When the request fails or the model returns no
            content, "response" holds an error message and an "error" key
            with the failure reason is added.
        """
        # A fixed query to find the data description sections of the paper
        context_query = "Detailed description of dataset, features, and data analysis methodology."
        context_string = self.get_relevant_context(context_query, k)

        # NOTE(review): this prompt asks for CSV-only output and for an
        # explanation of the features at the same time; left unchanged here.
        system_prompt = (
            "\n"
            "You are a helpful AI assistant that extracts the data part of the paper.\n"
            "You are given a paper that uses data to claim findings.\n"
            "Your job is to answer what features were used, what these features mean\n"
            "and generate a sample dataset of 100 records of synthetic data for these features.\n"
            "The output should be only the synthetic data in CSV format. Give an explanation on what features have been used in the paper,\n"
            "what their distribution is.\n"
        )

        # Build user prompt with optional metadata context. The metadata is
        # appended directly after the paper context with no separator, so any
        # leading whitespace has to come from the metadata text itself.
        metadata_section = self.metadata_context or ""
        schema_reminder = (
            "Please follow the expected data schema provided above when generating the synthetic data."
            if self.metadata_context else ""
        )
        user_prompt = (
            "Based on the following context from a research paper, please generate a synthetic dataset of 100 records.\n"
            "Context from the paper:\n"
            f"{context_string}{metadata_section}\n"
            "You have to keep the synthetic data with the similar distribution in all features.\n"
            "The collinearity between the synthetic data should remain similar to what is mentioned in the paper.\n"
            "The distribution of the categorical data should be consistent in the data used regarding if it is imbalanced or not.\n"
            "The distribution of numerical data should be consistent with either uniform distribution or binomial distribution or normal distribution.\n"
            "This should be applied to each and every feature. If there is a skew in the distribution, you should keep it as it is.\n"
            f"{schema_reminder}\n"
            "Give an explanation on what features have been used in the paper, what their distribution is.\n"
            "Generate a sample dataset of 100 records of synthetic data for these features.\n"
        )

        try:
            # Generate response using OpenRouter
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                temperature=temperature,
                max_tokens=max_tokens
            )
            generated_text = response.choices[0].message.content
        except Exception as e:
            # Any failure of the API call or of unpacking the response is
            # reported through the result dictionary instead of being raised.
            return self._error_result(str(e), context_string)

        if generated_text is None:
            # A reply without content is a failure, not a successful empty dataset.
            return self._error_result("The model returned no content.", context_string)

        return {
            "response": generated_text,
            "context_used": context_string,
            "model_used": self.model_name,
        }
def main():
    """
    Generate synthetic data and save it to synthetic_data.csv.

    Builds a SyntheticDataGenerator, requests one synthetic dataset, writes
    the raw model response to synthetic_data.csv and prints the first rows
    after loading the file with pandas. All failures are reported on standard
    output; nothing is raised to the caller.
    """
    output_path = "synthetic_data.csv"
    try:
        # Initialize the generator
        generator = SyntheticDataGenerator(
            model_name="google/gemini-flash-1.5",
            max_context_length=8000
        )
        print("Generator initialized. Generating synthetic data...")

        # Generate data
        result = generator.generate_synthetic_data(k=10)
        if "error" in result:
            print(f"An error occurred: {result['error']}")
            return

        # Save the synthetic data to a CSV file
        csv_data = result['response']
        try:
            # The response is written verbatim. UTF-8 is set explicitly so the
            # write does not depend on the platform default encoding and
            # matches the encoding used to read the file back.
            with open(output_path, "w", encoding="utf-8") as f:
                f.write(csv_data)
            print("Synthetic data saved to synthetic_data.csv")

            # Optional: Load into pandas to verify and display
            df = pd.read_csv(output_path, encoding="utf-8")
            print("\nFirst 5 rows of the synthetic dataset:")
            print(df.head())
        except Exception as e:
            # The model may return text that is not valid CSV; show it as-is.
            print(f"Could not process the response as CSV: {e}")
            print("Here is the raw response:")
            print(csv_data)
    except Exception as e:
        print(f"Error initializing or running the generator: {e}")
        print("Please ensure your environment is set up correctly.")
# Run the generator only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()