"""FastAPI service exposing synthetic data generation endpoints."""

import logging
import os
import shutil
import sys
import tempfile
import time
import traceback
from typing import Optional

# Logging is configured before the third-party imports on purpose: the version
# banners and any import failure reported below must reach a configured handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

try:
    import numpy as np
    import pandas as pd
    logger.info(f"NumPy version: {np.__version__}")
    logger.info(f"Pandas version: {pd.__version__}")

    # A missing PyTorch is only logged at this point; it does not abort startup here.
    try:
        import torch
        logger.info(f"PyTorch version: {torch.__version__}")
    except ImportError as e:
        logger.warning(f"PyTorch import error: {str(e)}")

    from fastapi import FastAPI, UploadFile, File, Form
    from fastapi.responses import FileResponse, JSONResponse
    from fastapi.middleware.cors import CORSMiddleware
    import uvicorn

    # Project-local modules. The failure is logged with both module names and
    # then re-raised so the outer handler terminates the process.
    try:
        from synthetic_data_pipeline import SyntheticDataPipeline
        from gemini_generator import GeminiSyntheticGenerator, GenerationConfig, generate_synthetic_data_enhanced
    except ImportError as e:
        logger.error(
            f"Failed to import project modules (synthetic_data_pipeline / gemini_generator): {str(e)}"
        )
        raise
except ImportError as e:
    logger.error(f"Import error: {str(e)}")
    logger.error("Try checking package compatibility or downgrading packages")
    sys.exit(1)
|
|
# Working directories, relative to the process's current working directory.
# Both are created at import time so the request handlers can assume they exist.
UPLOAD_FOLDER = 'temp_uploads'
OUTPUT_FOLDER = 'output'

os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# Application object that the route decorators below register against.
app = FastAPI(title="Synthetic Data Generator")

# CORS: every origin, method and header is accepted, with credentials enabled.
# NOTE(review): FastAPI's CORS documentation advises against combining wildcard
# origins with allow_credentials=True -- confirm this policy is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
def cleanup_output_directory(directory, max_retries=3, retry_delay=1.0):
    """Remove every entry in *directory* except the ``.gitkeep`` placeholder.

    Regular files are unlinked, retrying when the OS reports a
    ``PermissionError``. Sub-directories are removed recursively with errors
    ignored. A failure on one entry is logged and does not stop the sweep.

    Args:
        directory: Path of the directory to empty. A missing directory is
            treated as already clean.
        max_retries: Number of unlink attempts per file before giving up.
        retry_delay: Seconds to wait between unlink attempts.
    """
    if not os.path.isdir(directory):
        # Nothing to clean; avoids a FileNotFoundError from os.listdir().
        return

    for filename in os.listdir(directory):
        if filename == '.gitkeep':
            continue

        file_path = os.path.join(directory, filename)
        try:
            if os.path.isfile(file_path):
                for attempt in range(1, max_retries + 1):
                    try:
                        os.unlink(file_path)
                        break
                    except PermissionError:
                        # Only wait when another attempt will follow.
                        if attempt < max_retries:
                            time.sleep(retry_delay)
                else:
                    # Every attempt failed: report it instead of giving up silently.
                    logger.warning(
                        f'Could not delete {file_path} after {max_retries} attempts'
                    )
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path, ignore_errors=True)
        except Exception as e:
            logger.error(f'Error deleting {file_path}: {e}')
|
|
@app.get("/")
def root():
    """Return the API banner together with the loaded NumPy version."""
    banner = {"message": "Synthetic Data Generator API"}
    banner["numpy_version"] = np.__version__
    return banner
|
|
@app.get("/health")
def health_check():
    """Report service health.

    Returns a ``healthy`` payload carrying the NumPy version. If building that
    payload raises, the error is logged and a 500 ``unhealthy`` JSON response
    is returned instead.
    """
    try:
        report = {"status": "healthy", "numpy_version": np.__version__}
    except Exception as e:
        logger.error(f"Health check failed: {str(e)}")
        failure = {"status": "unhealthy", "error": str(e)}
        return JSONResponse(status_code=500, content=failure)
    return report
|
|
@app.post("/generate")
async def generate_synthetic_data(
    file: UploadFile = File(...),
    categorical_columns: str = Form(...),
    num_samples: int = Form(1000)
):
    """Run ``SyntheticDataPipeline`` on an uploaded CSV and return the result.

    The upload is spooled to a temporary ``.csv`` file, the pipeline is run
    with the requested number of samples, and the lexicographically last
    ``synthetic_data_*`` file in the pipeline's output directory is sent back.

    Args:
        file: Uploaded source data.
        categorical_columns: Comma-separated column names. Surrounding
            whitespace and single/double quotes are stripped.
        num_samples: Number of samples passed to the pipeline.

    Returns:
        A ``FileResponse`` with the generated CSV, a 400 ``JSONResponse`` when
        no usable column name was supplied, or a 500 ``JSONResponse`` carrying
        the error text when anything else fails.
    """
    # Bound up front so the cleanup in ``finally`` never hits an unbound name
    # when reading or spooling the upload fails.
    filepath = None
    try:
        # NOTE(review): this empties the shared output folder on every request,
        # so overlapping requests can delete each other's results -- confirm.
        cleanup_output_directory(OUTPUT_FOLDER)
        logger.info("Cleaned up output directory")

        try:
            # The ``with`` block closes the handle even if read/write raises;
            # delete=False keeps the file on disk for the pipeline to read.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.csv') as temp_file:
                filepath = temp_file.name
                temp_file.write(await file.read())

            logger.info(f"File saved to {filepath}")

            categorical_columns_list = [
                col.strip()
                for col in categorical_columns.replace('"', '').replace("'", '').split(',')
                if col.strip()
            ]

            if not categorical_columns_list:
                return JSONResponse(
                    status_code=400,
                    content={"error": "No valid categorical columns provided"}
                )

            logger.info(f"Processing with categorical columns: {categorical_columns_list}")
            logger.info(f"Number of samples requested: {num_samples}")

            pipeline = SyntheticDataPipeline(
                input_file=filepath,
                categorical_columns=categorical_columns_list,
                output_dir=os.path.abspath(OUTPUT_FOLDER)
            )

            pipeline.run_pipeline(
                num_samples=num_samples,
                epochs=100,
                chunk_size=10000
            )

            output_dir = pipeline.output_dir
            files = [f for f in os.listdir(output_dir) if f.startswith("synthetic_data_")]

            if not files:
                raise RuntimeError("No output file generated")

            # Lexicographically greatest name, i.e. the same file the previous
            # ``sorted(files)[-1]`` selected.
            latest_file = max(files)
            output_path = os.path.join(output_dir, latest_file)

            logger.info(f"Sending file: {output_path}")

            return FileResponse(
                path=output_path,
                filename="synthetic_data.csv",
                media_type="text/csv"
            )
        finally:
            # The spooled upload is never needed after the pipeline has run.
            if filepath is not None and os.path.exists(filepath):
                os.unlink(filepath)

    except Exception as e:
        logger.error(f"Error in generate_synthetic_data: {str(e)}")
        logger.error(traceback.format_exc())
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )
|
|
@app.post("/generate/enhanced")
async def generate_enhanced_synthetic_data(
    file: UploadFile = File(...),
    categorical_columns: str = Form(...),
    num_samples: int = Form(1000),
    use_gemini: bool = Form(True)
):
    """
    Generate synthetic data using hybrid CTGAN + Gemini approach.

    This endpoint provides higher quality synthetic data by combining:
    - CTGAN for statistical structure and column correlations
    - Gemini API for realistic, contextually appropriate values

    If Gemini API key is not configured, falls back to CTGAN-only.
    (The fallback lives in ``generate_synthetic_data_enhanced``, which is
    defined outside this module -- not verifiable from here.)

    Args:
        file: Uploaded source data.
        categorical_columns: Comma-separated column names. Surrounding
            whitespace and single/double quotes are stripped.
        num_samples: Number of samples requested from the generator.
        use_gemini: Forwarded unchanged to ``generate_synthetic_data_enhanced``.

    Returns:
        A ``FileResponse`` with the generated CSV, a 400 ``JSONResponse`` when
        no usable column name was supplied, or a 500 ``JSONResponse`` carrying
        the error text when anything else fails.
    """
    # Bound up front so the cleanup in ``finally`` never hits an unbound name
    # when reading or spooling the upload fails.
    filepath = None
    try:
        cleanup_output_directory(OUTPUT_FOLDER)
        logger.info("Starting enhanced synthetic data generation")

        try:
            # The ``with`` block closes the handle even if read/write raises;
            # delete=False keeps the file on disk for the generator to read.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.csv') as temp_file:
                filepath = temp_file.name
                temp_file.write(await file.read())

            categorical_columns_list = [
                col.strip()
                for col in categorical_columns.replace('"', '').replace("'", '').split(',')
                if col.strip()
            ]

            if not categorical_columns_list:
                return JSONResponse(
                    status_code=400,
                    content={"error": "No valid categorical columns provided"}
                )

            logger.info(f"Enhanced generation with columns: {categorical_columns_list}")
            logger.info(f"Samples requested: {num_samples}, Use Gemini: {use_gemini}")

            synthetic_data = generate_synthetic_data_enhanced(
                input_file=filepath,
                categorical_columns=categorical_columns_list,
                num_samples=num_samples,
                use_gemini=use_gemini,
                epochs=100
            )

            # Second-resolution timestamp keeps successive outputs distinguishable.
            from datetime import datetime
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_path = os.path.join(OUTPUT_FOLDER, f"synthetic_enhanced_{timestamp}.csv")
            synthetic_data.to_csv(output_path, index=False)

            logger.info(f"Enhanced synthetic data saved to: {output_path}")

            return FileResponse(
                path=output_path,
                filename="synthetic_data_enhanced.csv",
                media_type="text/csv"
            )
        finally:
            # The spooled upload is never needed after generation has run.
            if filepath is not None and os.path.exists(filepath):
                os.unlink(filepath)

    except Exception as e:
        logger.error(f"Error in enhanced generation: {str(e)}")
        logger.error(traceback.format_exc())
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )
|
|
@app.get("/gemini/status")
async def gemini_status():
    """Report whether a freshly built Gemini generator says it is available."""
    gemini = GeminiSyntheticGenerator()
    status = {"available": gemini.is_available}
    if gemini.is_available:
        status["message"] = "Gemini API ready"
    else:
        status["message"] = "Gemini API key not configured"
    return status
|
|
| |
if __name__ == "__main__":
    # Script entry point: the port comes from $PORT, defaulting to 7860.
    port = int(os.environ.get("PORT", 7860))
    logger.info(f"Starting FastAPI server on port {port}")

    if not os.path.exists("start_server.py"):
        # No launcher script in the working directory: serve directly.
        logger.info("Using direct uvicorn server (legacy mode)")
        import uvicorn
        uvicorn.run("app:app", host="0.0.0.0", port=port, reload=True)
    else:
        logger.info("Using start_server.py for optimal server configuration")
        # Keep an ENVIRONMENT value that is already set; otherwise use development.
        os.environ.setdefault("ENVIRONMENT", "development")
        # Load and execute the launcher script as a module named "start_server".
        import importlib.util
        spec = importlib.util.spec_from_file_location("start_server", "start_server.py")
        start_server = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(start_server)
|
|