|
|
import os |
|
|
import io |
|
|
import streamlit as st |
|
|
import google.generativeai as genai |
|
|
import pandas as pd |
|
|
from streamlit_extras.colored_header import colored_header |
|
|
from streamlit_extras.add_vertical_space import add_vertical_space |
|
|
|
|
|
|
|
|
st.set_page_config(layout="wide", page_title="AI 학습 데이터 만들기") |
|
|
|
|
|
|
|
|
try: |
|
|
genai.configure(api_key=st.secrets["GEMINI_API_KEY"]) |
|
|
except (KeyError, AttributeError): |
|
|
st.error("🚨 GEMINI_API_KEY를 설정해주세요! (Streamlit secrets에 추가)") |
|
|
st.stop() |
|
|
|
|
|
|
|
|
generation_config = { |
|
|
"temperature": 0.8, |
|
|
"top_p": 0.95, |
|
|
"top_k": 40, |
|
|
"max_output_tokens": 8192, |
|
|
"response_mime_type": "text/plain", |
|
|
} |
|
|
model = genai.GenerativeModel( |
|
|
model_name="gemini-1.5-flash", |
|
|
generation_config=generation_config, |
|
|
) |
|
|
|
|
|
|
|
|
MULTI_VARIABLE_PROMPT = """ |
|
|
당신은 현실적인 다변량 데이터를 시뮬레이션하는 데이터 과학자 AI입니다. |
|
|
당신의 임무는 사용자가 제공한 여러 원인(X) 변수와 결과(Y) 변수들의 현실적인 관계를 고려하여, 머신러닝 회귀 분석 학습에 적합한 데이터를 **Markdown 테이블 형식**으로 생성하는 것입니다. |
|
|
|
|
|
**사용자 입력:** |
|
|
* 원인 (X 변수) 목록: "{x_names_str}" |
|
|
* 결과 (Y 변수) 목록: "{y_names_str}" |
|
|
* 생성할 데이터 개수: {num_rows} |
|
|
|
|
|
**수행할 작업 (매우 중요):** |
|
|
1. **현실적 다변량 관계 모델링:** |
|
|
* 제공된 변수들 간의 **현실적인 상관관계**를 모델링합니다. 일부 X는 Y에 긍정적인 영향을, 다른 X는 부정적인 영향을 줄 수 있습니다. |
|
|
* X 변수들 사이에도 자연스러운 상관관계가 존재할 수 있습니다. (예: '운동 시간'이 늘면 '수면의 질'도 좋아지는 경향) |
|
|
2. **현실적 변동성 추가:** 관계가 완벽한 수학 공식이 아닌, 현실 데이터처럼 보이도록 적절한 무작위 변동성을 추가합니다. |
|
|
3. **현실적 제약 조건 적용:** 각 변수의 의미를 고려하여 상한선(Maximum)과 하한선(Minimum)을 자연스럽게 적용합니다. |
|
|
* 예를 들어, 변수 이름에 '점수', '만족도', '비율'이 포함되면 **결과값이 100을 넘지 않고, 0 미만이 되지 않도록** 데이터를 생성합니다. |
|
|
* 변수 이름에 '시간', '비용', '노력' 등이 포함되면, 이 값이 커질수록 결과값의 상승폭이 점차 둔화되는 **'수확 체감(diminishing returns)' 현상**을 현실적으로 반영합니다. |
|
|
4. **출력 형식 준수:** 결과는 **오직 Markdown 테이블 형식**으로만 출력합니다. |
|
|
|
|
|
**출력 형식 (절대 변경 금지):** |
|
|
* 첫 줄은 헤더 `{header_line}` 입니다. |
|
|
* 두 번째 줄은 구분선 `{separator_line}` 입니다. |
|
|
* 그 이후로는 `| 값1 | 값2 | ... |` 형식의 데이터 행을 {num_rows}개 만큼 생성합니다. |
|
|
* 설명, 코드, ``` 등 다른 어떤 텍스트도 포함하지 마세요. |
|
|
|
|
|
**현실적인 데이터 출력 예시 (규칙을 잘 따르는 예시):** |
|
|
| 공부 시간 | 수면 시간 | 시험 점수 | 컨디션 점수 | |
|
|
|---|---|---|---| |
|
|
| 1.5 | 8.2 | 65.7 | 88.1 | |
|
|
| 4.0 | 6.5 | 88.2 | 72.4 | |
|
|
| 0.5 | 7.5 | 42.1 | 81.0 | |
|
|
| 3.2 | 8.0 | 85.9 | 92.5 | |
|
|
| 5.5 | 5.8 | 91.5 | 65.3 | |
|
|
|
|
|
이제 아래 정보를 바탕으로 현실적인 제약 조건을 따른 Markdown 테이블을 생성해주세요. |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def generate_markdown_data(x_names, y_names, num_rows): |
|
|
"""AI를 호출하여 다변량 Markdown 테이블 형식의 데이터를 받아오는 함수""" |
|
|
all_vars = x_names + y_names |
|
|
header_line = "| " + " | ".join(all_vars) + " |" |
|
|
separator_line = "|---" * len(all_vars) + "|" |
|
|
|
|
|
prompt = MULTI_VARIABLE_PROMPT.format( |
|
|
x_names_str=", ".join(x_names), |
|
|
y_names_str=", ".join(y_names), |
|
|
num_rows=num_rows, |
|
|
header_line=header_line, |
|
|
separator_line=separator_line |
|
|
) |
|
|
try: |
|
|
response = model.generate_content([prompt]) |
|
|
return response.text.strip() |
|
|
except Exception as e: |
|
|
st.error(f"AI 호출 중 오류가 발생했습니다: {e}") |
|
|
return None |
|
|
|
|
|
def parse_markdown_to_df(markdown_text): |
|
|
"""Markdown 테이블 텍스트를 Pandas DataFrame으로 변환하는 함수""" |
|
|
try: |
|
|
|
|
|
|
|
|
md_file = io.StringIO(markdown_text) |
|
|
df = pd.read_csv(md_file, sep='|', skipinitialspace=True) |
|
|
|
|
|
|
|
|
df = df.iloc[:, 1:-1] |
|
|
|
|
|
|
|
|
df.columns = df.columns.str.strip() |
|
|
|
|
|
|
|
|
for col in df.columns: |
|
|
|
|
|
df[col] = df[col].iloc[1:] |
|
|
df[col] = pd.to_numeric(df[col].str.strip(), errors='coerce') |
|
|
|
|
|
df.dropna(inplace=True) |
|
|
df.reset_index(drop=True, inplace=True) |
|
|
return df |
|
|
|
|
|
except Exception as e: |
|
|
st.error(f"생성된 데이터를 테이블로 변환하는 중 오류가 발생했습니다: {e}") |
|
|
st.info("AI가 생성한 원본 데이터:") |
|
|
st.text(markdown_text) |
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
colored_header( |
|
|
label="🤖 나만의 AI 학습 데이터 만들기 (다변량 회귀용)", |
|
|
description="AI가 여러 변수들의 복합적인 관계를 고려하여 '진짜 같은' 학습 데이터를 만들어줘요!", |
|
|
color_name="blue-70", |
|
|
) |
|
|
add_vertical_space(1) |
|
|
|
|
|
|
|
|
if 'generated_df' not in st.session_state: |
|
|
st.session_state.generated_df = None |
|
|
if 'markdown_output' not in st.session_state: |
|
|
st.session_state.markdown_output = "" |
|
|
|
|
|
|
|
|
st.subheader("1. 데이터 규칙 정하기") |
|
|
col1, col2, col3 = st.columns([2, 2, 1]) |
|
|
|
|
|
with col1: |
|
|
x_input_names_str = st.text_area( |
|
|
"**원인(X) 변수 이름은?** (쉼표로 구분)", |
|
|
value="공부 시간, 수면 시간, 주말 복습 횟수", |
|
|
help="데이터의 원인이 되는 값의 이름들을 쉼표(,)로 구분하여 입력하세요." |
|
|
) |
|
|
|
|
|
with col2: |
|
|
y_input_names_str = st.text_area( |
|
|
"**결과(Y) 변수 이름은?** (쉼표로 구분)", |
|
|
value="시험 점수, 과제 점수", |
|
|
help="원인(X)에 따라 변하는 결과 값의 이름들을 쉼표(,)로 구분하여 입력하세요." |
|
|
) |
|
|
|
|
|
with col3: |
|
|
num_data_rows = st.number_input( |
|
|
"**데이터는 몇 개 만들까요?**", |
|
|
min_value=10, |
|
|
max_value=200, |
|
|
value=50, |
|
|
step=10, |
|
|
help="AI를 학습시키려면 데이터가 충분해야 해요. 10개 이상을 추천해요!" |
|
|
) |
|
|
|
|
|
add_vertical_space(1) |
|
|
generate_button = st.button("🚀 데이터 생성 시작!", type="primary", use_container_width=True) |
|
|
add_vertical_space(2) |
|
|
|
|
|
|
|
|
|
|
|
if generate_button: |
|
|
|
|
|
x_names_list = [name.strip() for name in x_input_names_str.split(',') if name.strip()] |
|
|
y_names_list = [name.strip() for name in y_input_names_str.split(',') if name.strip()] |
|
|
|
|
|
if not x_names_list or not y_names_list: |
|
|
st.warning("⚠️ '원인(X) 변수'와 '결과(Y) 변수'를 하나 이상씩 입력해주세요!") |
|
|
else: |
|
|
with st.spinner("똑똑한 AI가 변수들의 복합적인 관계를 생각하며 데이터를 만들고 있어요... 🤖"): |
|
|
markdown_data = generate_markdown_data(x_names_list, y_names_list, num_data_rows) |
|
|
|
|
|
if markdown_data: |
|
|
df = parse_markdown_to_df(markdown_data) |
|
|
|
|
|
if df is not None and not df.empty: |
|
|
st.session_state.generated_df = df |
|
|
st.session_state.markdown_output = markdown_data |
|
|
st.success("🎉 데이터 생성 완료! 아래에서 확인하고 다운로드하세요.") |
|
|
else: |
|
|
st.error("데이터 생성에 실패했어요. AI가 만든 데이터를 분석할 수 없는 것 같아요.") |
|
|
st.session_state.generated_df = None |
|
|
else: |
|
|
st.error("AI가 데이터를 생성하지 못했어요. 잠시 후 다시 시도해주세요.") |
|
|
st.session_state.generated_df = None |
|
|
|
|
|
|
|
|
if st.session_state.generated_df is not None: |
|
|
df_to_show = st.session_state.generated_df |
|
|
|
|
|
st.subheader("2. 생성된 데이터 미리보기") |
|
|
st.dataframe(df_to_show) |
|
|
|
|
|
st.subheader("3. 데이터 통계 요약") |
|
|
st.markdown("생성된 데이터의 평균, 표준편차, 최소/최대값 등 기술 통계를 확인해보세요.") |
|
|
st.dataframe(df_to_show.describe()) |
|
|
|
|
|
st.subheader("4. CSV 파일로 다운로드하기") |
|
|
st.markdown("이 버튼을 눌러 위에 보이는 데이터를 CSV 파일로 컴퓨터에 저장하세요.") |
|
|
|
|
|
csv = df_to_show.to_csv(index=False).encode('utf-8-sig') |
|
|
|
|
|
st.download_button( |
|
|
label="📥 CSV 파일 다운로드", |
|
|
data=csv, |
|
|
file_name="ai_generated_multivariable_data.csv", |
|
|
mime="text/csv", |
|
|
use_container_width=True |
|
|
) |
|
|
|
|
|
with st.expander("👀 AI가 생성한 원본 Markdown 텍스트가 궁금하다면?"): |
|
|
st.text(st.session_state.markdown_output) |