# kora-synth / gradio_app.py
# Uploaded via huggingface_hub by LeonceNsh (commit a6a1b5a, verified).
import os
import tempfile
import pickle
import json
from io import StringIO
from typing import Tuple, Optional
import gradio as gr
import faiss
import pandas as pd
from pdf_parser import pdf_parser
from generate_embedding import (
load_embedding_model as load_embed_model,
generate_embedding,
generate_faiss_index,
)
from response_generator import SyntheticDataGenerator
from synthetic_data_scaling import scale_csv_text
def extract_csv_text(raw_text: str) -> str:
    """Best-effort extraction of CSV content from an LLM response.

    Preference order: the fenced segment whose preceding fence label
    ends in "csv", then the first fenced segment, then the whole
    stripped response when no fence is present.
    """
    stripped = raw_text.strip()
    if "```" not in stripped:
        return stripped
    segments = stripped.split("```")
    # A segment is "labelled" when the text before its fence ends in "csv".
    labelled = next(
        (
            segments[idx + 1]
            for idx in range(len(segments) - 1)
            if segments[idx].lower().strip().endswith("csv")
        ),
        None,
    )
    if labelled is not None:
        return labelled.strip()
    # Fall back to whatever sits inside the first fence.
    if len(segments) >= 2:
        return segments[1].strip()
    return stripped
def split_explanation_and_csv(raw_text: str) -> Tuple[str, str]:
    """Split the LLM response into (explanation_text, csv_text).

    A closed ```csv fenced block is preferred; failing that, the first
    closed generic ``` fence is used. Everything outside the fence
    becomes the explanation. When no complete fence exists, the whole
    response is returned as the CSV part with an empty explanation.
    """
    text = raw_text or ""
    lines = text.splitlines()

    def _locate(is_opener):
        """Return (first content line index, closing fence index) or Nones."""
        body_start = None
        for idx, raw_line in enumerate(lines):
            bare = raw_line.strip()
            if body_start is None:
                if is_opener(bare):
                    body_start = idx + 1
            elif bare == "```":
                return body_start, idx
        return body_start, None

    start, end = _locate(lambda s: s.lower() == "```csv")
    if start is None or end is None:
        # No labelled block closed cleanly; accept any fenced block.
        alt_start, alt_end = _locate(lambda s: s.startswith("```"))
        if alt_start is not None and alt_end is not None:
            start, end = alt_start, alt_end

    if start is not None and end is not None and 0 <= start <= end <= len(lines):
        csv_part = "\n".join(lines[start:end]).strip()
        outside = lines[:start - 1] + lines[end + 1:]
        return "\n".join(outside).strip(), csv_part
    return "", text.strip()
def parse_metadata_file(metadata_file) -> Optional[str]:
    """Parse an uploaded metadata file and format it for the LLM prompt.

    Accepts either a filesystem path (str) — what newer Gradio versions
    pass for ``gr.File`` — or an object with a ``.name`` attribute
    (Gradio's legacy tempfile wrapper). JSON objects are rendered as a
    structured column list; any other content is passed through as
    plain-text schema notes.

    Returns the formatted schema prompt, or None when no file was
    supplied or it cannot be read (metadata is optional, so failures
    are logged and swallowed by design).
    """
    if not metadata_file:
        return None
    # Gradio may hand us a plain path string or a tempfile-like wrapper.
    path = metadata_file if isinstance(metadata_file, str) else getattr(metadata_file, "name", None)
    if not path:
        return None
    try:
        with open(path, "r", encoding="utf-8") as f:
            content = f.read().strip()
        # Try to parse as JSON first.
        try:
            metadata = json.loads(content)
        except json.JSONDecodeError:
            # Not JSON: treat the raw text as free-form schema notes.
            return f"\n\nExpected Data Schema:\n{content}\n"
        if isinstance(metadata, dict):
            # Format as a structured metadata prompt, one line per column.
            parts = ["\n\nExpected Data Schema:"]
            for column, info in metadata.items():
                if isinstance(info, dict):
                    col_type = info.get('type', 'unknown')
                    description = info.get('description', '')
                    entry = f"- {column}: {col_type}"
                    if description:
                        entry += f" - {description}"
                    parts.append(entry)
                else:
                    parts.append(f"- {column}: {info}")
            return "\n".join(parts) + "\n"
        # Valid JSON but not an object (e.g. a list): pass through as text.
        return f"\n\nExpected Data Schema:\n{content}\n"
    except Exception as e:
        # Best-effort: schema metadata is optional, continue without it.
        print(f"Error parsing metadata file: {e}")
        return None
class SyntheticDataApp:
    """Holds pipeline state shared between the Gradio event callbacks.

    The sample generated in step 2 is kept on the instance so the
    optional DBTwin scaling in step 3 can reuse it without re-parsing
    the PDF.
    """

    def __init__(self):
        self.sample_df = None          # DataFrame produced by the LLM step
        self.sample_csv_bytes = None   # sample CSV, utf-8 encoded bytes
        self.explanation_text = None   # model-provided feature explanation
        self.scaled_csv_bytes = None   # DBTwin-scaled CSV bytes

    @staticmethod
    def _csv_to_tempfile(csv_data, prefix: str) -> str:
        """Persist CSV bytes (or text) to a temp file and return its path.

        gr.File outputs expect a filesystem path, not raw bytes, so
        downloadable artifacts must be materialized on disk first. The
        file is created with delete=False so it outlives this callback.
        """
        payload = csv_data if isinstance(csv_data, bytes) else str(csv_data).encode("utf-8")
        tmp = tempfile.NamedTemporaryFile(suffix=".csv", prefix=prefix, delete=False)
        with tmp:
            tmp.write(payload)
        return tmp.name

    def process_pdf_and_generate_sample(
        self,
        pdf_file,
        metadata_file,
        llama_key: str,
        openrouter_key: str,
        model_name: str = "google/gemini-flash-1.5",
        k_chunks: int = 10,
        progress=gr.Progress()
    ):
        """Parse the PDF, embed it, and generate sample CSV data via the LLM.

        Returns a 4-tuple: (status message, sample DataFrame,
        explanation markdown, path to a downloadable CSV file). The last
        three elements are None on failure.
        """
        if not pdf_file:
            return "❌ Please upload a PDF file first.", None, None, None
        if not llama_key or not openrouter_key:
            return "❌ Please provide both LLAMA_CLOUD_API_KEY and OPENROUTER_API_KEY.", None, None, None
        try:
            # Downstream libraries read these keys from the environment.
            os.environ["LLAMA_CLOUD_API_KEY"] = llama_key
            os.environ["OPENROUTER_API_KEY"] = openrouter_key
            progress(0.1, desc="Parsing PDF...")
            pdf_content = pdf_parser(pdf_file)
            progress(0.3, desc="Generating embeddings...")
            embed_model = load_embed_model("all-MiniLM-L6-v2")
            embeddings = generate_embedding(pdf_content, embed_model)
            index = generate_faiss_index(embeddings)
            progress(0.6, desc="Generating synthetic data...")
            # Optional user-supplied schema, appended to the LLM prompt.
            metadata_prompt = parse_metadata_file(metadata_file)
            # The generator reads the index/chunks from disk, so stage
            # them in a temp dir that lives for the duration of the call.
            with tempfile.TemporaryDirectory() as tmpdir:
                index_path = os.path.join(tmpdir, "faiss_index.index")
                chunks_path = os.path.join(tmpdir, "text_chunks.pkl")
                faiss.write_index(index, index_path)
                with open(chunks_path, "wb") as f:
                    pickle.dump(pdf_content, f)
                generator = SyntheticDataGenerator(
                    openai_api_key=openrouter_key,
                    model_name=model_name,
                    index_path=index_path,
                    text_chunks_path=chunks_path,
                    max_context_length=8000,
                    metadata_context=metadata_prompt
                )
                result = generator.generate_synthetic_data(k=int(k_chunks))
            raw_response = result.get("response", "")
            explanation_text, csv_text = split_explanation_and_csv(raw_response)
            # Strip stray fence/label lines the model sometimes leaves in
            # the CSV body before handing it to pandas.
            csv_text_clean = "\n".join([
                ln for ln in csv_text.splitlines()
                if ln.strip() not in ('csv', 'CSV', '```', '```csv', '```CSV')
            ])
            df_sample = pd.read_csv(StringIO(csv_text_clean), comment='`')
            df_sample.rename(columns=lambda c: str(c).strip(), inplace=True)
            # Drop a spurious "csv" column created when the fence label
            # ends up fused onto the header line.
            if 'csv' in df_sample.columns or 'CSV' in df_sample.columns:
                df_sample = df_sample.drop(columns=[
                    c for c in ('csv', 'CSV') if c in df_sample.columns
                ])
            sample_csv_bytes = df_sample.to_csv(index=False).encode()
            # Store results for the optional scaling step.
            self.sample_df = df_sample
            self.sample_csv_bytes = sample_csv_bytes
            self.explanation_text = explanation_text
            # gr.File needs a path on disk, not raw bytes.
            sample_path = self._csv_to_tempfile(sample_csv_bytes, "sample_")
            progress(1.0, desc="Complete!")
            return (
                f"✅ Successfully generated sample data with {len(df_sample)} rows and {len(df_sample.columns)} columns.",
                df_sample,
                explanation_text or "No explanation provided by the model.",
                sample_path
            )
        except Exception as e:
            return f"❌ Error: {str(e)}", None, None, None

    def scale_data(
        self,
        dbtwin_key: str,
        scale_rows: int = 1000,
        progress=gr.Progress()
    ):
        """Scale the stored sample data to `scale_rows` rows via DBTwin.

        Returns a 3-tuple: (status message with quality metrics, scaled
        DataFrame, path to a downloadable CSV file). The last two are
        None on failure or when no sample has been generated yet.
        """
        if self.sample_df is None:
            return "❌ Please generate sample data first.", None, None
        if not dbtwin_key:
            return "❌ Please provide DBTWIN_API_KEY.", None, None
        try:
            progress(0.2, desc="Scaling data via DBTwin API...")
            os.environ["DBTWIN_API_KEY"] = dbtwin_key
            scaled_df, scaled_bytes, headers = scale_csv_text(
                self.sample_df.to_csv(index=False),
                dbtwin_key,
                rows=int(scale_rows),
                algo="flagship",
            )
            self.scaled_csv_bytes = scaled_bytes
            # DBTwin reports quality metrics in the response headers.
            dist_err = headers.get("distribution-similarity-error") if headers else "N/A"
            assoc_sim = headers.get("association-similarity") if headers else "N/A"
            progress(1.0, desc="Scaling complete!")
            metrics_info = f"📊 **Quality Metrics:**\n- Distribution Similarity Error: {dist_err}\n- Association Similarity: {assoc_sim}"
            # gr.File needs a path on disk, not raw bytes.
            scaled_path = self._csv_to_tempfile(scaled_bytes, "scaled_")
            return (
                f"✅ Successfully scaled data to {len(scaled_df)} rows.\n\n{metrics_info}",
                scaled_df,
                scaled_path
            )
        except Exception as e:
            return f"❌ Scaling failed: {str(e)}", None, None
def create_interface():
    """Create the Gradio interface.

    Builds a three-step Blocks UI (upload/configure → generate sample →
    optionally scale) and wires the buttons to a shared
    SyntheticDataApp instance that carries state between steps.
    """
    # Single shared state holder; the click handlers below close over it.
    app = SyntheticDataApp()
    with gr.Blocks(
        title="Healthcare Paper → Synthetic Data Generator",
        theme=gr.themes.Soft(),
        css="""
        .main-header {
            text-align: center;
            background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
            -webkit-background-clip: text;
            -webkit-text-fill-color: transparent;
            font-size: 2.5em;
            font-weight: bold;
            margin-bottom: 1em;
        }
        .step-header {
            background-color: #f8f9fa;
            padding: 1em;
            border-radius: 10px;
            border-left: 4px solid #667eea;
            margin: 1em 0;
        }
        .info-box {
            background-color: #e3f2fd;
            padding: 1em;
            border-radius: 8px;
            border: 1px solid #2196f3;
            margin: 0.5em 0;
        }
        """
    ) as demo:
        # Page header and workflow overview.
        gr.HTML("""
        <div class="main-header">
            🔬 Healthcare Paper → Synthetic Data Generator
        </div>
        <div class="info-box">
            <h3>Enterprise-Ready Synthetic Data Pipeline</h3>
            <p>Transform research papers into high-quality synthetic datasets in three simple steps:</p>
            <ul>
                <li><b>Step 1:</b> Upload your research paper (PDF format)</li>
                <li><b>Step 2:</b> Generate sample synthetic data using AI</li>
                <li><b>Step 3:</b> Scale to production-size datasets with DBTwin</li>
            </ul>
        </div>
        """)
        # Step 1: Configuration and PDF Upload
        with gr.Group():
            gr.HTML('<div class="step-header"><h2>📋 Step 1: Configuration & Upload</h2></div>')
            with gr.Row():
                with gr.Column(scale=2):
                    # Required input: the research paper to mine for data.
                    pdf_file = gr.File(
                        label="📄 Upload Research Paper (PDF)",
                        file_types=[".pdf"]
                    )
                    # Optional schema hints passed into the LLM prompt.
                    metadata_file = gr.File(
                        label="📋 Upload Data Schema/Metadata (Optional)",
                        file_types=[".json", ".txt", ".md"]
                    )
                    gr.HTML("""
                    <div style="background-color: #e8f4fd; padding: 0.8em; border-radius: 6px; border: 1px solid #b3d9ff; margin-top: 0.5em;">
                        <h5>📋 Metadata Format Examples:</h5>
                        <p><b>JSON format:</b></p>
                        <pre style="font-size: 0.8em; background-color: #f8f9fa; padding: 0.5em; border-radius: 4px;">
{
  "age": {"type": "integer", "description": "Patient age in years"},
  "gender": {"type": "categorical", "description": "Male/Female"},
  "blood_pressure": {"type": "float", "description": "Systolic BP in mmHg"}
}</pre>
                        <p><b>Text format:</b> Simply describe your expected columns and their types.</p>
                    </div>
                    """)
                with gr.Column(scale=1):
                    gr.HTML("""
                    <div style="background-color: #fff3cd; padding: 1em; border-radius: 8px; border: 1px solid #ffeaa7;">
                        <h4>💡 Tips for Best Results:</h4>
                        <ul>
                            <li>Upload healthcare/medical research papers</li>
                            <li>Ensure PDF contains tables or data descriptions</li>
                            <li>Clear text (not scanned images) works best</li>
                            <li>Upload metadata to specify expected column types</li>
                        </ul>
                    </div>
                    """)
            # API keys and model settings, collapsed by default.
            with gr.Accordion("⚙️ API Configuration", open=False):
                with gr.Row():
                    llama_key = gr.Textbox(
                        label="🔑 LLAMA_CLOUD_API_KEY",
                        type="password",
                        placeholder="Enter your LlamaParse API key",
                        info="Required for PDF parsing"
                    )
                    openrouter_key = gr.Textbox(
                        label="🔑 OPENROUTER_API_KEY",
                        type="password",
                        placeholder="Enter your OpenRouter API key",
                        info="Required for LLM data generation"
                    )
                with gr.Row():
                    model_name = gr.Textbox(
                        label="🤖 LLM Model",
                        value="google/gemini-flash-1.5",
                        info="OpenRouter model name"
                    )
                    # How many retrieved chunks feed the generation prompt.
                    k_chunks = gr.Slider(
                        label="📊 Context Chunks",
                        minimum=3,
                        maximum=30,
                        value=10,
                        step=1,
                        info="Number of relevant document chunks to use"
                    )
        # Step 2: Generate Sample Data
        with gr.Group():
            gr.HTML('<div class="step-header"><h2>🎯 Step 2: Generate Sample Data</h2></div>')
            generate_btn = gr.Button(
                "Generate Sample Synthetic Data",
                variant="primary",
                size="lg",
                scale=2
            )
            generation_status = gr.Textbox(
                label="Status",
                interactive=False,
                max_lines=3
            )
            with gr.Row():
                with gr.Column():
                    sample_data_preview = gr.Dataframe(
                        label="Sample Data Preview",
                        wrap=True
                    )
                    # Hidden until the generation callback returns a file.
                    sample_download = gr.File(
                        label="Download Sample CSV",
                        visible=False
                    )
                with gr.Column():
                    explanation_output = gr.Markdown(
                        label="Feature Explanation",
                        value="*Feature descriptions will appear here after generation*"
                    )
        # Step 3: Scale Data
        with gr.Group():
            gr.HTML('<div class="step-header"><h2>Step 3: Scale Data (Optional)</h2></div>')
            with gr.Row():
                with gr.Column():
                    dbtwin_key = gr.Textbox(
                        label="🔑 DBTWIN_API_KEY",
                        type="password",
                        placeholder="Enter your DBTwin API key (optional)",
                        info="Required for data scaling"
                    )
                    scale_rows = gr.Slider(
                        label="Target Dataset Size",
                        minimum=100,
                        maximum=200000,
                        value=1000,
                        step=100,
                        info="Number of rows in scaled dataset"
                    )
                    scale_btn = gr.Button(
                        "Scale Dataset with DBTwin",
                        variant="secondary",
                        size="lg"
                    )
                with gr.Column():
                    gr.HTML("""
                    <div style="background-color: #d1ecf1; padding: 1em; border-radius: 8px; border: 1px solid #bee5eb;">
                        <h4>About Data Scaling:</h4>
                        <ul>
                            <li><b>DBTwin:</b> Professional synthetic data generation</li>
                            <li><b>Maintains:</b> Statistical properties and correlations</li>
                            <li><b>Quality:</b> Enterprise-grade data quality metrics</li>
                            <li><b>Scale:</b> Up to 200,000 rows</li>
                        </ul>
                    </div>
                    """)
            scaling_status = gr.Textbox(
                label="Scaling Status",
                interactive=False,
                max_lines=4
            )
            with gr.Row():
                scaled_data_preview = gr.Dataframe(
                    label="Scaled Data Preview",
                    wrap=True
                )
                # Hidden until the scaling callback returns a file.
                scaled_download = gr.File(
                    label="Download Scaled CSV",
                    visible=False
                )
        # Footer
        gr.HTML("""
        <div style="text-align: center; margin-top: 2em; padding: 1em; background-color: #f8f9fa; border-radius: 10px;">
            <p><strong>Enterprise Synthetic Data Pipeline</strong> |
            Built for Healthcare Research & Data Science Teams</p>
            <p>Need help? Contact your data engineering team or check the documentation.</p>
        </div>
        """)
        # Event handlers
        # Run generation, then reveal the download widget only when the
        # callback's last output (the file value) is truthy.
        generate_btn.click(
            fn=app.process_pdf_and_generate_sample,
            inputs=[pdf_file, metadata_file, llama_key, openrouter_key, model_name, k_chunks],
            outputs=[generation_status, sample_data_preview, explanation_output, sample_download],
            show_progress=True
        ).then(
            fn=lambda x: gr.File(value=x, visible=True) if x else gr.File(visible=False),
            inputs=[sample_download],
            outputs=[sample_download]
        )
        # Same pattern for the optional DBTwin scaling step.
        scale_btn.click(
            fn=app.scale_data,
            inputs=[dbtwin_key, scale_rows],
            outputs=[scaling_status, scaled_data_preview, scaled_download],
            show_progress=True
        ).then(
            fn=lambda x: gr.File(value=x, visible=True) if x else gr.File(visible=False),
            inputs=[scaled_download],
            outputs=[scaled_download]
        )
    return demo
if __name__ == "__main__":
    # Build the UI and serve it.
    demo = create_interface()
    launch_options = {
        "server_name": "0.0.0.0",  # listen on all interfaces
        "server_port": 7860,
        "share": True,             # also expose a public Gradio link
        "show_error": True,        # surface callback exceptions in the UI
    }
    demo.launch(**launch_options)