# Hugging Face Spaces page artifact ("Spaces: Running on CPU Upgrade") — not part of the program.
| import os | |
| import tempfile | |
| import pickle | |
| import json | |
| from io import StringIO | |
| from typing import Tuple, Optional | |
| import gradio as gr | |
| import faiss | |
| import pandas as pd | |
| from pdf_parser import pdf_parser | |
| from generate_embedding import ( | |
| load_embedding_model as load_embed_model, | |
| generate_embedding, | |
| generate_faiss_index, | |
| ) | |
| from response_generator import SyntheticDataGenerator | |
| from synthetic_data_scaling import scale_csv_text | |
def extract_csv_text(raw_text: str) -> str:
    """Best-effort extraction of CSV content from an LLM response.

    Preference order:
      1. a fenced block whose opening fence carries a ``csv`` language tag
         (standard markdown form, i.e. the tag is the first line *inside*
         the split segment);
      2. a fenced block immediately preceded by prose ending in "csv"
         (legacy heuristic kept for backward compatibility);
      3. the first fenced block of any kind;
      4. the stripped input itself, when there are no fences at all.
    """
    text = raw_text.strip()
    if "```" not in text:
        return text
    parts = text.split("```")
    for i in range(len(parts) - 1):
        segment = parts[i + 1]
        # Standard markdown puts the language tag at the START of the fenced
        # segment ("```csv\na,b\n..."). The previous implementation only
        # inspected the END of the preceding segment, so tagged blocks fell
        # through and were returned with the "csv" tag line still attached.
        first_line, _, rest = segment.partition("\n")
        if first_line.strip().lower() == "csv":
            return rest.strip()
        # Legacy heuristic: prose right before the fence ends with "csv".
        if parts[i].lower().strip().endswith("csv"):
            return segment.strip()
    # Fallback: first fenced segment (split on "```" always yields >= 2
    # parts when a fence is present).
    if len(parts) >= 2:
        return parts[1].strip()
    return text
def split_explanation_and_csv(raw_text: str) -> Tuple[str, str]:
    """Split the LLM response into (explanation_text, csv_text).

    First tries a strict ```csv ... ``` fenced block; failing that, falls
    back to the first fenced block of any kind. When a block is found, the
    fence lines are dropped, the block interior becomes the CSV text, and
    everything around it becomes the explanation. With no usable fences the
    whole (stripped) text is treated as CSV.
    """
    text = raw_text or ""
    lines = text.splitlines()

    def _locate(is_opener):
        """Return (first interior line index, closing-fence index) or Nones."""
        begin = finish = None
        for idx, line in enumerate(lines):
            stripped = line.strip()
            if begin is None:
                if is_opener(stripped):
                    begin = idx + 1
            elif stripped == "```":
                finish = idx
                break
        return begin, finish

    # Strict pass: opening fence must be exactly "```csv" (case-insensitive).
    csv_start, csv_end = _locate(lambda s: s.lower() == "```csv")
    if csv_start is None or csv_end is None:
        # Relaxed pass: any opening fence, closed by a bare "```".
        alt_start, alt_end = _locate(lambda s: s.startswith("```"))
        if alt_start is not None and alt_end is not None:
            csv_start, csv_end = alt_start, alt_end

    if csv_start is not None and csv_end is not None and 0 <= csv_start <= csv_end <= len(lines):
        csv_text = "\n".join(lines[csv_start:csv_end]).strip()
        # Drop both fence lines; keep everything else as explanation.
        surrounding = lines[:csv_start - 1] + lines[csv_end + 1:]
        return "\n".join(surrounding).strip(), csv_text
    return "", text.strip()
def parse_metadata_file(metadata_file) -> Optional[str]:
    """Parse an uploaded metadata file and format it for the LLM prompt.

    Accepts either an object exposing a ``.name`` attribute holding a path
    (as Gradio's ``gr.File`` provides) or a plain filesystem path string.

    JSON objects of the form ``{"col": {"type": ..., "description": ...}}``
    (or ``{"col": "type"}``) are rendered as a structured schema listing;
    any other readable content is passed through verbatim under the same
    "Expected Data Schema" header.

    Returns the formatted prompt fragment, or None when no file was supplied
    or it could not be read/parsed (best-effort: errors are printed, never
    raised, so a bad metadata file cannot break the main generation flow).
    """
    if not metadata_file:
        return None
    # gr.File hands us an object carrying the temp-file path in .name;
    # also accept a bare path string for easier reuse and testing.
    path = getattr(metadata_file, "name", metadata_file)
    try:
        # Explicit encoding: without it the decode depends on the locale.
        with open(path, "r", encoding="utf-8") as f:
            content = f.read().strip()

        # Try structured JSON first; fall back to free-form text below.
        try:
            metadata = json.loads(content)
        except json.JSONDecodeError:
            metadata = None

        if isinstance(metadata, dict):
            entries = ["\n\nExpected Data Schema:"]
            for column, info in metadata.items():
                if isinstance(info, dict):
                    entry = f"- {column}: {info.get('type', 'unknown')}"
                    description = info.get('description', '')
                    if description:
                        entry += f" - {description}"
                    entries.append(entry)
                else:
                    # Shorthand form: the value itself is the type/label.
                    entries.append(f"- {column}: {info}")
            return "\n".join(entries) + "\n"

        # Not a JSON object (plain text, or JSON array/scalar): pass through.
        return f"\n\nExpected Data Schema:\n{content}\n"
    except Exception as e:
        print(f"Error parsing metadata file: {e}")
        return None
class SyntheticDataApp:
    """Holds pipeline state shared between the two Gradio callbacks.

    Step 2 (process_pdf_and_generate_sample) stores the generated sample
    DataFrame on the instance so that Step 3 (scale_data) can read it later
    from the same app object captured in create_interface's closure.
    """

    def __init__(self):
        # Outputs of the sample-generation step (Step 2).
        self.sample_df: Optional[pd.DataFrame] = None
        self.sample_csv_bytes: Optional[bytes] = None
        self.explanation_text: Optional[str] = None
        # Output of the scaling step (Step 3); produced by scale_csv_text —
        # presumably CSV bytes, matching how it is fed to a download widget.
        self.scaled_csv_bytes = None

    def process_pdf_and_generate_sample(
        self,
        pdf_file,
        metadata_file,
        llama_key: str,
        openrouter_key: str,
        model_name: str = "google/gemini-flash-1.5",
        k_chunks: int = 10,
        # NOTE(review): default evaluated once at def time; Gradio treats a
        # gr.Progress() default as a progress-tracker marker, so this is the
        # documented pattern rather than a mutable-default bug.
        progress=gr.Progress()
    ):
        """Process PDF and generate sample CSV data.

        Returns a 4-tuple matching the Gradio outputs:
        (status message, sample DataFrame, explanation text, CSV bytes).
        On any failure the status starts with "❌" and the rest are None.
        """
        if not pdf_file:
            return "❌ Please upload a PDF file first.", None, None, None
        if not llama_key or not openrouter_key:
            return "❌ Please provide both LLAMA_CLOUD_API_KEY and OPENROUTER_API_KEY.", None, None, None
        try:
            # Export keys via the environment — presumably pdf_parser and the
            # embedding/generation helpers read them there (TODO confirm).
            os.environ["LLAMA_CLOUD_API_KEY"] = llama_key
            os.environ["OPENROUTER_API_KEY"] = openrouter_key
            progress(0.1, desc="Parsing PDF...")
            pdf_content = pdf_parser(pdf_file)
            progress(0.3, desc="Generating embeddings...")
            embed_model = load_embed_model("all-MiniLM-L6-v2")
            embeddings = generate_embedding(pdf_content, embed_model)
            index = generate_faiss_index(embeddings)
            progress(0.6, desc="Generating synthetic data...")
            # Parse metadata if provided (None is fine for the generator).
            metadata_prompt = parse_metadata_file(metadata_file)
            # The generator consumes the index and chunk pickle from disk, so
            # write them to a temp dir that lives only for the duration of
            # the generation call.
            with tempfile.TemporaryDirectory() as tmpdir:
                index_path = os.path.join(tmpdir, "faiss_index.index")
                chunks_path = os.path.join(tmpdir, "text_chunks.pkl")
                faiss.write_index(index, index_path)
                with open(chunks_path, "wb") as f:
                    pickle.dump(pdf_content, f)
                generator = SyntheticDataGenerator(
                    openai_api_key=openrouter_key,
                    model_name=model_name,
                    index_path=index_path,
                    text_chunks_path=chunks_path,
                    max_context_length=8000,
                    metadata_context=metadata_prompt
                )
                # k = number of retrieved chunks used as LLM context.
                result = generator.generate_synthetic_data(k=int(k_chunks))
            raw_response = result.get("response", "")
            explanation_text, csv_text = split_explanation_and_csv(raw_response)
            # Clean and parse CSV: drop any leftover fence/tag lines the
            # splitter missed; comment='`' makes pandas skip stray-backtick
            # lines as well.
            csv_text_clean = "\n".join([
                ln for ln in csv_text.splitlines()
                if ln.strip() not in ('csv', 'CSV', '```', '```csv', '```CSV')
            ])
            df_sample = pd.read_csv(StringIO(csv_text_clean), comment='`')
            # Normalize header whitespace, then drop a phantom "csv" column
            # left behind when the model's tag line merged into the header.
            df_sample.rename(columns=lambda c: str(c).strip(), inplace=True)
            if 'csv' in df_sample.columns or 'CSV' in df_sample.columns:
                df_sample = df_sample.drop(columns=[
                    c for c in ('csv', 'CSV') if c in df_sample.columns
                ])
            sample_csv_bytes = df_sample.to_csv(index=False).encode()
            # Store results for the later scaling step.
            self.sample_df = df_sample
            self.sample_csv_bytes = sample_csv_bytes
            self.explanation_text = explanation_text
            progress(1.0, desc="Complete!")
            # NOTE(review): raw CSV bytes are routed to a gr.File output —
            # confirm the installed Gradio version accepts bytes (newer
            # releases expect a filepath).
            return (
                f"✅ Successfully generated sample data with {len(df_sample)} rows and {len(df_sample.columns)} columns.",
                df_sample,
                explanation_text or "No explanation provided by the model.",
                sample_csv_bytes
            )
        except Exception as e:
            # Broad catch is deliberate: surface any pipeline failure as a
            # status message instead of crashing the UI callback.
            return f"❌ Error: {str(e)}", None, None, None

    def scale_data(
        self,
        dbtwin_key: str,
        scale_rows: int = 1000,
        progress=gr.Progress()
    ):
        """Scale the generated sample data using DBTwin API.

        Requires a prior successful process_pdf_and_generate_sample call.
        Returns (status message, scaled DataFrame, scaled CSV bytes); the
        latter two are None on failure.
        """
        if self.sample_df is None:
            return "❌ Please generate sample data first.", None, None
        if not dbtwin_key:
            return "❌ Please provide DBTWIN_API_KEY.", None, None
        try:
            progress(0.2, desc="Scaling data via DBTwin API...")
            os.environ["DBTWIN_API_KEY"] = dbtwin_key
            scaled_df, scaled_bytes, headers = scale_csv_text(
                self.sample_df.to_csv(index=False),
                dbtwin_key,
                rows=int(scale_rows),
                algo="flagship",
            )
            self.scaled_csv_bytes = scaled_bytes
            # Quality metrics reported by DBTwin via response headers —
            # presumably an HTTP-header mapping; "N/A" when absent.
            dist_err = headers.get("distribution-similarity-error") if headers else "N/A"
            assoc_sim = headers.get("association-similarity") if headers else "N/A"
            progress(1.0, desc="Scaling complete!")
            metrics_info = f"📊 **Quality Metrics:**\n- Distribution Similarity Error: {dist_err}\n- Association Similarity: {assoc_sim}"
            return (
                f"✅ Successfully scaled data to {len(scaled_df)} rows.\n\n{metrics_info}",
                scaled_df,
                scaled_bytes
            )
        except Exception as e:
            # Same deliberate broad catch as above: report, don't crash.
            return f"❌ Scaling failed: {str(e)}", None, None
def create_interface():
    """Create the Gradio interface.

    Builds a three-step Blocks UI (upload/config → generate sample → scale)
    around a single SyntheticDataApp instance that carries state between the
    two button callbacks. Returns the (unlaunched) gr.Blocks object.
    """
    # One shared app instance: its attributes bridge Step 2 and Step 3.
    app = SyntheticDataApp()
    with gr.Blocks(
        title="Healthcare Paper → Synthetic Data Generator",
        theme=gr.themes.Soft(),
        css="""
        .main-header {
            text-align: center;
            background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
            -webkit-background-clip: text;
            -webkit-text-fill-color: transparent;
            font-size: 2.5em;
            font-weight: bold;
            margin-bottom: 1em;
        }
        .step-header {
            background-color: #f8f9fa;
            padding: 1em;
            border-radius: 10px;
            border-left: 4px solid #667eea;
            margin: 1em 0;
        }
        .info-box {
            background-color: #e3f2fd;
            padding: 1em;
            border-radius: 8px;
            border: 1px solid #2196f3;
            margin: 0.5em 0;
        }
        """
    ) as demo:
        # Page header and the three-step overview.
        gr.HTML("""
        <div class="main-header">
            🔬 Healthcare Paper → Synthetic Data Generator
        </div>
        <div class="info-box">
            <h3>Enterprise-Ready Synthetic Data Pipeline</h3>
            <p>Transform research papers into high-quality synthetic datasets in three simple steps:</p>
            <ul>
                <li><b>Step 1:</b> Upload your research paper (PDF format)</li>
                <li><b>Step 2:</b> Generate sample synthetic data using AI</li>
                <li><b>Step 3:</b> Scale to production-size datasets with DBTwin</li>
            </ul>
        </div>
        """)
        # Step 1: Configuration and PDF Upload
        with gr.Group():
            gr.HTML('<div class="step-header"><h2>📋 Step 1: Configuration & Upload</h2></div>')
            with gr.Row():
                with gr.Column(scale=2):
                    pdf_file = gr.File(
                        label="📄 Upload Research Paper (PDF)",
                        file_types=[".pdf"]
                    )
                    metadata_file = gr.File(
                        label="📋 Upload Data Schema/Metadata (Optional)",
                        file_types=[".json", ".txt", ".md"]
                    )
                    gr.HTML("""
                    <div style="background-color: #e8f4fd; padding: 0.8em; border-radius: 6px; border: 1px solid #b3d9ff; margin-top: 0.5em;">
                        <h5>📋 Metadata Format Examples:</h5>
                        <p><b>JSON format:</b></p>
                        <pre style="font-size: 0.8em; background-color: #f8f9fa; padding: 0.5em; border-radius: 4px;">
{
  "age": {"type": "integer", "description": "Patient age in years"},
  "gender": {"type": "categorical", "description": "Male/Female"},
  "blood_pressure": {"type": "float", "description": "Systolic BP in mmHg"}
}</pre>
                        <p><b>Text format:</b> Simply describe your expected columns and their types.</p>
                    </div>
                    """)
                with gr.Column(scale=1):
                    gr.HTML("""
                    <div style="background-color: #fff3cd; padding: 1em; border-radius: 8px; border: 1px solid #ffeaa7;">
                        <h4>💡 Tips for Best Results:</h4>
                        <ul>
                            <li>Upload healthcare/medical research papers</li>
                            <li>Ensure PDF contains tables or data descriptions</li>
                            <li>Clear text (not scanned images) works best</li>
                            <li>Upload metadata to specify expected column types</li>
                        </ul>
                    </div>
                    """)
            # Collapsed accordion holds API keys and model tuning knobs.
            with gr.Accordion("⚙️ API Configuration", open=False):
                with gr.Row():
                    llama_key = gr.Textbox(
                        label="🔑 LLAMA_CLOUD_API_KEY",
                        type="password",
                        placeholder="Enter your LlamaParse API key",
                        info="Required for PDF parsing"
                    )
                    openrouter_key = gr.Textbox(
                        label="🔑 OPENROUTER_API_KEY",
                        type="password",
                        placeholder="Enter your OpenRouter API key",
                        info="Required for LLM data generation"
                    )
                with gr.Row():
                    model_name = gr.Textbox(
                        label="🤖 LLM Model",
                        value="google/gemini-flash-1.5",
                        info="OpenRouter model name"
                    )
                    k_chunks = gr.Slider(
                        label="📊 Context Chunks",
                        minimum=3,
                        maximum=30,
                        value=10,
                        step=1,
                        info="Number of relevant document chunks to use"
                    )
        # Step 2: Generate Sample Data
        with gr.Group():
            gr.HTML('<div class="step-header"><h2>🎯 Step 2: Generate Sample Data</h2></div>')
            generate_btn = gr.Button(
                "Generate Sample Synthetic Data",
                variant="primary",
                size="lg",
                scale=2
            )
            generation_status = gr.Textbox(
                label="Status",
                interactive=False,
                max_lines=3
            )
            with gr.Row():
                with gr.Column():
                    sample_data_preview = gr.Dataframe(
                        label="Sample Data Preview",
                        wrap=True
                    )
                    # Hidden until a generation succeeds (toggled in .then below).
                    sample_download = gr.File(
                        label="Download Sample CSV",
                        visible=False
                    )
                with gr.Column():
                    explanation_output = gr.Markdown(
                        label="Feature Explanation",
                        value="*Feature descriptions will appear here after generation*"
                    )
        # Step 3: Scale Data
        with gr.Group():
            gr.HTML('<div class="step-header"><h2>Step 3: Scale Data (Optional)</h2></div>')
            with gr.Row():
                with gr.Column():
                    dbtwin_key = gr.Textbox(
                        label="🔑 DBTWIN_API_KEY",
                        type="password",
                        placeholder="Enter your DBTwin API key (optional)",
                        info="Required for data scaling"
                    )
                    scale_rows = gr.Slider(
                        label="Target Dataset Size",
                        minimum=100,
                        maximum=200000,
                        value=1000,
                        step=100,
                        info="Number of rows in scaled dataset"
                    )
                    scale_btn = gr.Button(
                        "Scale Dataset with DBTwin",
                        variant="secondary",
                        size="lg"
                    )
                with gr.Column():
                    gr.HTML("""
                    <div style="background-color: #d1ecf1; padding: 1em; border-radius: 8px; border: 1px solid #bee5eb;">
                        <h4>About Data Scaling:</h4>
                        <ul>
                            <li><b>DBTwin:</b> Professional synthetic data generation</li>
                            <li><b>Maintains:</b> Statistical properties and correlations</li>
                            <li><b>Quality:</b> Enterprise-grade data quality metrics</li>
                            <li><b>Scale:</b> Up to 200,000 rows</li>
                        </ul>
                    </div>
                    """)
            scaling_status = gr.Textbox(
                label="Scaling Status",
                interactive=False,
                max_lines=4
            )
            with gr.Row():
                scaled_data_preview = gr.Dataframe(
                    label="Scaled Data Preview",
                    wrap=True
                )
                scaled_download = gr.File(
                    label="Download Scaled CSV",
                    visible=False
                )
        # Footer
        gr.HTML("""
        <div style="text-align: center; margin-top: 2em; padding: 1em; background-color: #f8f9fa; border-radius: 10px;">
            <p><strong>Enterprise Synthetic Data Pipeline</strong> |
            Built for Healthcare Research & Data Science Teams</p>
            <p>Need help? Contact your data engineering team or check the documentation.</p>
        </div>
        """)
        # Event handlers: each click runs the callback, then a chained step
        # toggles the download widget's visibility based on whether bytes
        # were produced.
        # NOTE(review): the lambdas pass the callback's raw bytes straight
        # into gr.File(value=...) — confirm the installed Gradio version
        # accepts bytes rather than a filepath.
        generate_btn.click(
            fn=app.process_pdf_and_generate_sample,
            inputs=[pdf_file, metadata_file, llama_key, openrouter_key, model_name, k_chunks],
            outputs=[generation_status, sample_data_preview, explanation_output, sample_download],
            show_progress=True
        ).then(
            fn=lambda x: gr.File(value=x, visible=True) if x else gr.File(visible=False),
            inputs=[sample_download],
            outputs=[sample_download]
        )
        scale_btn.click(
            fn=app.scale_data,
            inputs=[dbtwin_key, scale_rows],
            outputs=[scaling_status, scaled_data_preview, scaled_download],
            show_progress=True
        ).then(
            fn=lambda x: gr.File(value=x, visible=True) if x else gr.File(visible=False),
            inputs=[scaled_download],
            outputs=[scaled_download]
        )
    return demo
| if __name__ == "__main__": | |
| demo = create_interface() | |
| demo.launch( | |
| server_name="0.0.0.0", | |
| server_port=7860, | |
| share=True, | |
| show_error=True | |
| ) |