Unknown-Geek
Added gemini integration
a773052
import os
import sys
import logging
# Set up logging first
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
try:
# Check versions before importing other libraries
import numpy as np
import pandas as pd
logger.info(f"NumPy version: {np.__version__}")
logger.info(f"Pandas version: {pd.__version__}")
try:
import torch
logger.info(f"PyTorch version: {torch.__version__}")
except ImportError as e:
logger.warning(f"PyTorch import error: {str(e)}")
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import traceback
import shutil
import tempfile
import uvicorn # Import uvicorn for running the server
try:
from synthetic_data_pipeline import SyntheticDataPipeline
from gemini_generator import GeminiSyntheticGenerator, GenerationConfig, generate_synthetic_data_enhanced
import time
except ImportError as e:
logger.error(f"Failed to import SyntheticDataPipeline: {str(e)}")
raise
except ImportError as e:
logger.error(f"Import error: {str(e)}")
logger.error("Try checking package compatibility or downgrading packages")
sys.exit(1)
from typing import Optional
app = FastAPI(title="Synthetic Data Generator")

# Setup CORS: accept every origin, method and header, with credentials enabled.
# NOTE(review): the FastAPI CORS documentation states that allow_origins=["*"]
# cannot be used for credentialed requests; confirm whether credentials are
# actually required, or replace the wildcard with an explicit origin list.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Configuration: working directories, resolved relative to the process CWD.
# NOTE(review): UPLOAD_FOLDER is not referenced by the endpoints visible here
# (they write uploads via tempfile); confirm whether it is still needed.
UPLOAD_FOLDER = 'temp_uploads'
OUTPUT_FOLDER = 'output'

# Create both working directories at import time; exist_ok makes this a no-op
# when they are already present.
for _folder in (UPLOAD_FOLDER, OUTPUT_FOLDER):
    os.makedirs(_folder, exist_ok=True)
del _folder
def cleanup_output_directory(directory, max_retries=3, retry_delay=1.0):
    """Remove every entry in *directory* except a ``.gitkeep`` placeholder.

    Regular files are unlinked. If deletion raises PermissionError (e.g. a
    locked file), it is retried up to *max_retries* times, sleeping
    *retry_delay* seconds between attempts. Sub-directories are removed with
    ``shutil.rmtree(..., ignore_errors=True)``.

    Failures are logged and never raised, so a cleanup problem does not abort
    the caller.

    Args:
        directory: Path of the directory to empty. A directory that does not
            exist is treated as already clean.
        max_retries: Deletion attempts per file before giving up (at least one
            attempt is always made).
        retry_delay: Seconds to sleep between failed attempts. This uses
            time.sleep, so it blocks the calling thread.
    """
    try:
        entries = os.listdir(directory)
    except FileNotFoundError:
        return  # nothing to clean

    attempts = max(1, max_retries)
    for filename in entries:
        if filename == '.gitkeep':
            continue
        file_path = os.path.join(directory, filename)
        try:
            if os.path.isfile(file_path):
                for attempt in range(1, attempts + 1):
                    try:
                        os.unlink(file_path)
                        break
                    except PermissionError:
                        if attempt == attempts:
                            # Out of retries: let the handler below log it
                            # instead of silently leaving the file behind.
                            raise
                        time.sleep(retry_delay)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path, ignore_errors=True)
        except Exception as e:
            logger.error(f'Error deleting {file_path}: {e}')
@app.get("/")
def root():
    """Identify the service and report the NumPy version it is running with."""
    payload = {"message": "Synthetic Data Generator API"}
    payload["numpy_version"] = np.__version__
    return payload
@app.get("/health")
def health_check():
    """Health probe.

    Returns a "healthy" payload with the NumPy version. If reading the
    version raises, logs the error and returns a 500 JSONResponse.
    """
    try:
        version = np.__version__
    except Exception as exc:
        logger.error(f"Health check failed: {str(exc)}")
        return JSONResponse(
            status_code=500,
            content={"status": "unhealthy", "error": str(exc)},
        )
    return {"status": "healthy", "numpy_version": version}
@app.post("/generate")
async def generate_synthetic_data(
    file: UploadFile = File(...),
    categorical_columns: str = Form(...),
    num_samples: int = Form(1000)
):
    """Train the SyntheticDataPipeline on an uploaded CSV and return the result.

    Form fields:
        file: The source CSV upload.
        categorical_columns: Comma-separated column names. Single and double
            quotes are stripped, and blank entries are ignored.
        num_samples: Number of synthetic rows to request (must be >= 1).

    Returns:
        One of:
        - A FileResponse with the lexicographically last ``synthetic_data_*``
          file the pipeline wrote to its output directory, downloaded as
          ``synthetic_data.csv``.
        - A 400 JSONResponse for invalid form input.
        - A 500 JSONResponse carrying the error text for any other failure.
    """
    try:
        # Validate form input first so a rejected request neither wipes
        # OUTPUT_FOLDER nor writes a temp file.
        categorical_columns_list = [
            col.strip()
            for col in categorical_columns.replace('"', '').replace("'", '').split(',')
            if col.strip()
        ]
        if not categorical_columns_list:
            return JSONResponse(
                status_code=400,
                content={"error": "No valid categorical columns provided"}
            )
        if num_samples < 1:
            return JSONResponse(
                status_code=400,
                content={"error": "num_samples must be a positive integer"}
            )

        # Clean up output directory before starting
        cleanup_output_directory(OUTPUT_FOLDER)
        logger.info("Cleaned up output directory")

        # Bind before anything that can fail, so the finally below never
        # references an unbound name. That would raise UnboundLocalError and
        # hide the real error.
        filepath = None
        try:
            # delete=False: the pipeline reopens the file by path, so it must
            # outlive this handle (the finally removes it). The `with` closes
            # the handle even when the read or write fails.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.csv') as temp_file:
                filepath = temp_file.name
                temp_file.write(await file.read())
            logger.info(f"File saved to {filepath}")

            logger.info(f"Processing with categorical columns: {categorical_columns_list}")
            logger.info(f"Number of samples requested: {num_samples}")

            # Run pipeline with kwargs.
            # NOTE(review): run_pipeline is called synchronously inside this
            # coroutine, so the event loop is blocked while it trains. This
            # also serializes requests, which the shared-OUTPUT_FOLDER cleanup
            # above currently relies on.
            pipeline = SyntheticDataPipeline(
                input_file=filepath,
                categorical_columns=categorical_columns_list,
                output_dir=os.path.abspath(OUTPUT_FOLDER)
            )
            pipeline.run_pipeline(
                num_samples=num_samples,
                epochs=100,
                chunk_size=10000
            )

            # Pick the newest output. This assumes the name suffix sorts
            # chronologically (e.g. a zero-padded timestamp) — TODO confirm.
            output_dir = pipeline.output_dir
            files = [f for f in os.listdir(output_dir) if f.startswith("synthetic_data_")]
            if not files:
                raise Exception("No output file generated")
            latest_file = sorted(files)[-1]
            output_path = os.path.join(output_dir, latest_file)
            logger.info(f"Sending file: {output_path}")
            return FileResponse(
                path=output_path,
                filename="synthetic_data.csv",
                media_type="text/csv"
            )
        finally:
            # Clean up temp file
            if filepath is not None and os.path.exists(filepath):
                os.unlink(filepath)
    except Exception as e:
        logger.error(f"Error in generate_synthetic_data: {str(e)}")
        logger.error(traceback.format_exc())
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )
@app.post("/generate/enhanced")
async def generate_enhanced_synthetic_data(
    file: UploadFile = File(...),
    categorical_columns: str = Form(...),
    num_samples: int = Form(1000),
    use_gemini: bool = Form(True)
):
    """
    Generate synthetic data using hybrid CTGAN + Gemini approach.

    This endpoint provides higher quality synthetic data by combining:
    - CTGAN for statistical structure and column correlations
    - Gemini API for realistic, contextually appropriate values

    If Gemini API key is not configured, falls back to CTGAN-only. That
    fallback is implemented in generate_synthetic_data_enhanced, not here.

    Form fields:
        file: The source CSV upload.
        categorical_columns: Comma-separated column names. Single and double
            quotes are stripped, and blank entries are ignored.
        num_samples: Number of synthetic rows to request (must be >= 1).
        use_gemini: Forwarded unchanged to generate_synthetic_data_enhanced.

    Returns:
        One of:
        - A FileResponse for the CSV saved to OUTPUT_FOLDER as
          ``synthetic_enhanced_<YYYYmmdd_HHMMSS>.csv``, downloaded as
          ``synthetic_data_enhanced.csv``.
        - A 400 JSONResponse for invalid form input.
        - A 500 JSONResponse carrying the error text for any other failure.
    """
    try:
        # Validate form input first so a rejected request neither wipes
        # OUTPUT_FOLDER nor writes a temp file.
        categorical_columns_list = [
            col.strip()
            for col in categorical_columns.replace('"', '').replace("'", '').split(',')
            if col.strip()
        ]
        if not categorical_columns_list:
            return JSONResponse(
                status_code=400,
                content={"error": "No valid categorical columns provided"}
            )
        if num_samples < 1:
            return JSONResponse(
                status_code=400,
                content={"error": "num_samples must be a positive integer"}
            )

        cleanup_output_directory(OUTPUT_FOLDER)
        logger.info("Starting enhanced synthetic data generation")

        # Bind before anything that can fail, so the finally below never
        # references an unbound name. That would raise UnboundLocalError and
        # hide the real error.
        filepath = None
        try:
            # delete=False: the generator reopens the file by path; the
            # finally removes it. The `with` closes the handle even on failure.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.csv') as temp_file:
                filepath = temp_file.name
                temp_file.write(await file.read())

            logger.info(f"Enhanced generation with columns: {categorical_columns_list}")
            logger.info(f"Samples requested: {num_samples}, Use Gemini: {use_gemini}")

            # Generate using hybrid approach
            synthetic_data = generate_synthetic_data_enhanced(
                input_file=filepath,
                categorical_columns=categorical_columns_list,
                num_samples=num_samples,
                use_gemini=use_gemini,
                epochs=100
            )

            # Save output; the returned object is used via .to_csv, presumably
            # a pandas DataFrame — TODO confirm.
            from datetime import datetime
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_path = os.path.join(OUTPUT_FOLDER, f"synthetic_enhanced_{timestamp}.csv")
            synthetic_data.to_csv(output_path, index=False)
            logger.info(f"Enhanced synthetic data saved to: {output_path}")

            return FileResponse(
                path=output_path,
                filename="synthetic_data_enhanced.csv",
                media_type="text/csv"
            )
        finally:
            if filepath is not None and os.path.exists(filepath):
                os.unlink(filepath)
    except Exception as e:
        logger.error(f"Error in enhanced generation: {str(e)}")
        logger.error(traceback.format_exc())
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )
@app.get("/gemini/status")
async def gemini_status():
    """Report whether a freshly constructed GeminiSyntheticGenerator is usable.

    The response carries the generator's ``is_available`` flag plus a
    human-readable message.
    """
    gen = GeminiSyntheticGenerator()
    status = {"available": gen.is_available}
    status["message"] = (
        "Gemini API ready" if gen.is_available else "Gemini API key not configured"
    )
    return status
# Script entry point: start the server when this file is executed directly.
if __name__ == "__main__":
    # Port comes from $PORT, defaulting to 7860.
    port = int(os.environ.get("PORT", 7860))
    logger.info(f"Starting FastAPI server on port {port}")

    if not os.path.exists("start_server.py"):
        # No launcher script in the CWD: fall back to plain uvicorn with
        # auto-reload (legacy mode). Assumes this module is importable as "app".
        logger.info("Using direct uvicorn server (legacy mode)")
        import uvicorn
        uvicorn.run("app:app", host="0.0.0.0", port=port, reload=True)
    else:
        # Delegate startup to start_server.py by executing it as a module.
        # Note: `port` is not passed to it here.
        logger.info("Using start_server.py for optimal server configuration")
        # Direct runs default to development mode unless ENVIRONMENT is set.
        os.environ.setdefault("ENVIRONMENT", "development")
        import importlib.util
        launcher_spec = importlib.util.spec_from_file_location("start_server", "start_server.py")
        launcher = importlib.util.module_from_spec(launcher_spec)
        launcher_spec.loader.exec_module(launcher)