Update app.py
Browse files
app.py
CHANGED
|
@@ -87,63 +87,59 @@ generation_params = {
|
|
| 87 |
def generate_synthetic_data(description, columns):
|
| 88 |
formatted_prompt = format_prompt(description, columns)
|
| 89 |
payload = {"inputs": formatted_prompt, "parameters": generation_params}
|
| 90 |
-
response = requests.post(API_URL, headers={"Authorization": f"Bearer {
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
if 'error' in response_data:
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
return response_data[0]["generated_text"]
|
| 97 |
|
| 98 |
def process_generated_data(csv_data, expected_columns):
|
| 99 |
try:
|
| 100 |
-
# Ensure the data is cleaned and correctly formatted
|
| 101 |
cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')
|
| 102 |
data = StringIO(cleaned_data)
|
| 103 |
-
|
| 104 |
-
# Read the CSV data
|
| 105 |
df = pd.read_csv(data, delimiter=',')
|
| 106 |
|
| 107 |
-
# Check if the DataFrame has the expected columns
|
| 108 |
if set(df.columns) != set(expected_columns):
|
| 109 |
-
|
| 110 |
|
| 111 |
return df
|
| 112 |
except pd.errors.ParserError as e:
|
| 113 |
-
|
| 114 |
|
| 115 |
def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
|
| 116 |
-
csv_data_all =
|
| 117 |
|
| 118 |
for _ in tqdm(range(num_rows // rows_per_generation), desc="Generating Data"):
|
| 119 |
generated_data = generate_synthetic_data(description, columns)
|
| 120 |
-
if "Error" in generated_data:
|
| 121 |
-
return generated_data # Return the error message
|
| 122 |
-
|
| 123 |
df_synthetic = process_generated_data(generated_data, columns)
|
|
|
|
| 124 |
if isinstance(df_synthetic, pd.DataFrame) and not df_synthetic.empty:
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
if csv_data_all:
|
| 130 |
return csv_data_all
|
| 131 |
else:
|
| 132 |
-
|
| 133 |
|
| 134 |
@app.post("/generate/")
|
| 135 |
def generate_data(request: DataGenerationRequest):
|
| 136 |
description = request.description.strip()
|
| 137 |
columns = [col.strip() for col in request.columns]
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
if isinstance(generated_data, str) and "Error" in generated_data:
|
| 141 |
-
return JSONResponse(content={"error": generated_data}, status_code=500)
|
| 142 |
|
| 143 |
-
#
|
| 144 |
-
csv_buffer = StringIO(generated_data)
|
| 145 |
return StreamingResponse(
|
| 146 |
-
|
| 147 |
media_type="text/csv",
|
| 148 |
headers={"Content-Disposition": "attachment; filename=generated_data.csv"}
|
| 149 |
)
|
|
|
|
| 87 |
def generate_synthetic_data(description, columns):
|
| 88 |
formatted_prompt = format_prompt(description, columns)
|
| 89 |
payload = {"inputs": formatted_prompt, "parameters": generation_params}
|
| 90 |
+
response = requests.post(API_URL, headers={"Authorization": f"Bearer {hf_token}"}, json=payload)
|
| 91 |
+
|
| 92 |
+
try:
|
| 93 |
+
response_data = response.json()
|
| 94 |
+
except ValueError:
|
| 95 |
+
raise HTTPException(status_code=500, detail="Failed to parse response from the API.")
|
| 96 |
|
| 97 |
if 'error' in response_data:
|
| 98 |
+
raise HTTPException(status_code=500, detail=f"API Error: {response_data['error']}")
|
| 99 |
+
|
| 100 |
+
if 'generated_text' not in response_data[0]:
|
| 101 |
+
raise HTTPException(status_code=500, detail="Unexpected API response format.")
|
| 102 |
|
| 103 |
return response_data[0]["generated_text"]
|
| 104 |
|
| 105 |
def process_generated_data(csv_data, expected_columns):
|
| 106 |
try:
|
|
|
|
| 107 |
cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')
|
| 108 |
data = StringIO(cleaned_data)
|
|
|
|
|
|
|
| 109 |
df = pd.read_csv(data, delimiter=',')
|
| 110 |
|
|
|
|
| 111 |
if set(df.columns) != set(expected_columns):
|
| 112 |
+
raise ValueError("Unexpected columns in the generated data.")
|
| 113 |
|
| 114 |
return df
|
| 115 |
except pd.errors.ParserError as e:
|
| 116 |
+
raise HTTPException(status_code=500, detail=f"Failed to parse CSV data: {e}")
|
| 117 |
|
| 118 |
def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
|
| 119 |
+
csv_data_all = StringIO()
|
| 120 |
|
| 121 |
for _ in tqdm(range(num_rows // rows_per_generation), desc="Generating Data"):
|
| 122 |
generated_data = generate_synthetic_data(description, columns)
|
|
|
|
|
|
|
|
|
|
| 123 |
df_synthetic = process_generated_data(generated_data, columns)
|
| 124 |
+
|
| 125 |
if isinstance(df_synthetic, pd.DataFrame) and not df_synthetic.empty:
|
| 126 |
+
df_synthetic.to_csv(csv_data_all, index=False, header=False)
|
| 127 |
+
|
| 128 |
+
if csv_data_all.tell() > 0: # Check if there's any data in the buffer
|
| 129 |
+
csv_data_all.seek(0) # Rewind the buffer to the beginning
|
|
|
|
| 130 |
return csv_data_all
|
| 131 |
else:
|
| 132 |
+
raise HTTPException(status_code=500, detail="No valid data frames generated.")
|
| 133 |
|
| 134 |
@app.post("/generate/")
|
| 135 |
def generate_data(request: DataGenerationRequest):
|
| 136 |
description = request.description.strip()
|
| 137 |
columns = [col.strip() for col in request.columns]
|
| 138 |
+
csv_data = generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100)
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
+
# Return the CSV data as a downloadable file
|
|
|
|
| 141 |
return StreamingResponse(
|
| 142 |
+
csv_data,
|
| 143 |
media_type="text/csv",
|
| 144 |
headers={"Content-Disposition": "attachment; filename=generated_data.csv"}
|
| 145 |
)
|