Update app.py
Browse files
app.py
CHANGED
|
@@ -4,7 +4,7 @@ from pydantic import BaseModel
|
|
| 4 |
import pandas as pd
|
| 5 |
import os
|
| 6 |
import requests
|
| 7 |
-
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer
|
| 8 |
from io import StringIO
|
| 9 |
from fastapi.middleware.cors import CORSMiddleware
|
| 10 |
from huggingface_hub import HfFolder
|
|
@@ -95,51 +95,36 @@ def generate_synthetic_data(description, columns):
|
|
| 95 |
|
| 96 |
return response_data[0]["generated_text"]
|
| 97 |
|
| 98 |
-
def extract_valid_csv(csv_data, expected_columns):
    """Return the CSV portion of *csv_data* that starts at the expected header.

    Everything before the first line whose comma-separated cells match
    ``expected_columns`` (as a set, ignoring surrounding whitespace) is
    discarded.  The header and all following non-blank lines are kept and
    joined with ``'\\n'``.

    Args:
        csv_data: Raw text that may contain a CSV table preceded by other text.
        expected_columns: Iterable of column names the header must consist of.

    Returns:
        The extracted CSV text, or an empty string when no matching header
        line is found (or when *csv_data* is empty).
    """
    if not csv_data:
        return ''

    # Normalise CRLF / CR line endings so a trailing '\r' on the header line
    # cannot prevent the header from being recognised.
    normalised = csv_data.replace('\r\n', '\n').replace('\r', '\n')
    expected = {str(column).strip() for column in expected_columns}

    kept_lines = []
    header_found = False
    for line in normalised.split('\n'):
        if header_found:
            # After the header, keep every non-blank line untouched.
            if line.strip():
                kept_lines.append(line)
            continue

        cells = [cell.strip() for cell in line.split(',')]
        if set(cells) == expected:
            header_found = True
            # Emit the header with padding removed so the column names match
            # the expected ones exactly when the text is parsed later.
            kept_lines.append(','.join(cells))

    return '\n'.join(kept_lines)
|
| 114 |
-
|
| 115 |
def process_generated_data(csv_data, expected_columns):
|
| 116 |
try:
|
| 117 |
-
|
| 118 |
-
|
|
|
|
| 119 |
|
|
|
|
| 120 |
df = pd.read_csv(data, delimiter=',')
|
| 121 |
|
|
|
|
| 122 |
if set(df.columns) != set(expected_columns):
|
| 123 |
return f"Unexpected columns in the generated data: {df.columns}"
|
| 124 |
|
| 125 |
return df
|
| 126 |
except pd.errors.ParserError as e:
|
| 127 |
-
logging.error(f"Failed to parse CSV data: {e}")
|
| 128 |
return f"Failed to parse CSV data: {e}"
|
| 129 |
|
| 130 |
def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
|
| 131 |
csv_data_all = ""
|
| 132 |
|
| 133 |
-
for _ in range(num_rows // rows_per_generation):
|
| 134 |
generated_data = generate_synthetic_data(description, columns)
|
| 135 |
if "Error" in generated_data:
|
| 136 |
-
return generated_data
|
| 137 |
|
| 138 |
df_synthetic = process_generated_data(generated_data, columns)
|
| 139 |
if isinstance(df_synthetic, pd.DataFrame) and not df_synthetic.empty:
|
| 140 |
csv_data_all += df_synthetic.to_csv(index=False, header=False)
|
| 141 |
else:
|
| 142 |
-
|
| 143 |
|
| 144 |
if csv_data_all:
|
| 145 |
return csv_data_all
|
|
@@ -155,6 +140,7 @@ def generate_data(request: DataGenerationRequest):
|
|
| 155 |
if isinstance(generated_data, str) and "Error" in generated_data:
|
| 156 |
return JSONResponse(content={"error": generated_data}, status_code=500)
|
| 157 |
|
|
|
|
| 158 |
csv_buffer = StringIO(generated_data)
|
| 159 |
return StreamingResponse(
|
| 160 |
csv_buffer,
|
|
@@ -162,7 +148,6 @@ def generate_data(request: DataGenerationRequest):
|
|
| 162 |
headers={"Content-Disposition": "attachment; filename=generated_data.csv"}
|
| 163 |
)
|
| 164 |
|
| 165 |
-
|
| 166 |
@app.get("/")
def greet_json():
    """Handle GET requests to the root path with a fixed greeting payload."""
    greeting = {"Hello": "World!"}
    return greeting
|
|
|
|
| 4 |
import pandas as pd
|
| 5 |
import os
|
| 6 |
import requests
|
| 7 |
+
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer
|
| 8 |
from io import StringIO
|
| 9 |
from fastapi.middleware.cors import CORSMiddleware
|
| 10 |
from huggingface_hub import HfFolder
|
|
|
|
| 95 |
|
| 96 |
return response_data[0]["generated_text"]
|
| 97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
def process_generated_data(csv_data, expected_columns):
    """Parse generated CSV text into a DataFrame and validate its columns.

    Args:
        csv_data: CSV text (comma-delimited) with a header row.
        expected_columns: Iterable of column names the parsed data must have;
            compared as a set, so column order does not matter.

    Returns:
        The parsed ``pandas.DataFrame`` when parsing succeeds and the column
        set matches ``expected_columns``; otherwise a ``str`` describing the
        failure (parse failure or unexpected columns).
    """
    # Normalise CRLF / CR line endings before handing the text to pandas.
    cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')

    try:
        df = pd.read_csv(StringIO(cleaned_data), delimiter=',')
    except (pd.errors.ParserError, pd.errors.EmptyDataError) as e:
        # EmptyDataError is not a ParserError subclass; an empty generation
        # must be reported like any other unparseable output rather than
        # propagating out of the caller's generation loop.
        return f"Failed to parse CSV data: {e}"

    # Check if the DataFrame has the expected columns.
    if set(df.columns) != set(expected_columns):
        return f"Unexpected columns in the generated data: {df.columns}"

    return df
|
| 114 |
|
| 115 |
def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
|
| 116 |
csv_data_all = ""
|
| 117 |
|
| 118 |
+
for _ in tqdm(range(num_rows // rows_per_generation), desc="Generating Data"):
|
| 119 |
generated_data = generate_synthetic_data(description, columns)
|
| 120 |
if "Error" in generated_data:
|
| 121 |
+
return generated_data # Return the error message
|
| 122 |
|
| 123 |
df_synthetic = process_generated_data(generated_data, columns)
|
| 124 |
if isinstance(df_synthetic, pd.DataFrame) and not df_synthetic.empty:
|
| 125 |
csv_data_all += df_synthetic.to_csv(index=False, header=False)
|
| 126 |
else:
|
| 127 |
+
print("Skipping invalid generation.")
|
| 128 |
|
| 129 |
if csv_data_all:
|
| 130 |
return csv_data_all
|
|
|
|
| 140 |
if isinstance(generated_data, str) and "Error" in generated_data:
|
| 141 |
return JSONResponse(content={"error": generated_data}, status_code=500)
|
| 142 |
|
| 143 |
+
# Create a streaming response to return the CSV data
|
| 144 |
csv_buffer = StringIO(generated_data)
|
| 145 |
return StreamingResponse(
|
| 146 |
csv_buffer,
|
|
|
|
| 148 |
headers={"Content-Disposition": "attachment; filename=generated_data.csv"}
|
| 149 |
)
|
| 150 |
|
|
|
|
| 151 |
@app.get("/")
def greet_json():
    """Handle GET requests to the root path with a fixed greeting payload."""
    greeting = {"Hello": "World!"}
    return greeting
|