# app.py
# ─────────────────────────────────────────────
# Gradio UI for Data Cleaning Environment
# Beautiful dashboard with Playfair Display
# ─────────────────────────────────────────────
"""Gradio dashboard and FastAPI endpoints for the data cleaning environment.

The Gradio UI (mounted under ``/ui``) drives one ``DataCleaningEnv`` per
difficulty; the REST endpoints (``/reset``, ``/step``, ``/state``,
``/health``) drive a separate environment instance.

NOTE(review): the inline HTML styles in this module are a minimal
reconstruction — confirm them against the intended dashboard design.
"""

import json
import threading
from html import escape

import gradio as gr
import pandas as pd
import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from fastapi.responses import RedirectResponse

from config import TASK_EASY, TASK_MEDIUM, TASK_HARD
from environment import DataCleaningEnv
from models import Action

# ── Global State ──────────────────────────────
envs = {
    TASK_EASY: DataCleaningEnv(TASK_EASY),
    TASK_MEDIUM: DataCleaningEnv(TASK_MEDIUM),
    TASK_HARD: DataCleaningEnv(TASK_HARD),
}
current_difficulty = TASK_EASY
step_logs = {TASK_EASY: [], TASK_MEDIUM: [], TASK_HARD: []}

# Columns whose numeric values are checked for implausible salaries.
SALARY_COLUMNS = ("salary", "salary$")
SALARY_OUTLIER_LIMIT = 500000

# Pill colours per issue keyword as (background, text, border).
# Order matters: the first keyword found in the issue string wins.
ISSUE_COLORS = (
    ("missing", ("#FFF5F5", "#B91C1C", "#FECACA")),
    ("duplicate", ("#FFFBEB", "#92400E", "#FDE68A")),
    ("outlier", ("#FFFBEB", "#92400E", "#FDE68A")),
    ("bad_column", ("#EFF6FF", "#1D4ED8", "#BFDBFE")),
)
DEFAULT_ISSUE_COLORS = ("#F8FAFC", "#475569", "#E2E8F0")

_MUTED_STYLE = "color:#94A3B8;padding:12px;font-size:13px"
_CELL_STYLE = "padding:6px 10px;border-bottom:1px solid #E2E8F0"
_NULL_CELL_STYLE = _CELL_STYLE + ";background:#FFF5F5;color:#B91C1C"
_OUTLIER_CELL_STYLE = _CELL_STYLE + ";background:#FFFBEB;color:#92400E"
_HEADER_CELL_STYLE = (
    "padding:6px 10px;text-align:left;border-bottom:2px solid #E2E8F0"
)


# ── Helpers ───────────────────────────────────
def get_env():
    """Return the environment for the currently selected difficulty."""
    return envs[current_difficulty]


def _is_null(val) -> bool:
    """Return True for ``None`` and float NaN cell values."""
    return val is None or (isinstance(val, float) and pd.isna(val))


def _cell_html(col, val) -> str:
    """Render one table cell, highlighting nulls and salary outliers."""
    if _is_null(val):
        return f"<td style='{_NULL_CELL_STYLE}'>null</td>"

    text = escape(str(val))
    if isinstance(val, (int, float)) and col in SALARY_COLUMNS:
        number = float(val)
        if abs(number) > SALARY_OUTLIER_LIMIT or number < 0:
            return f"<td style='{_OUTLIER_CELL_STYLE}'>{text}</td>"
    return f"<td style='{_CELL_STYLE}'>{text}</td>"


def df_to_html(df: pd.DataFrame, null_counts: dict) -> str:
    """Render DataFrame as styled HTML table.

    Args:
        df: Dataset to display; ``None`` or a frame without rows yields a
            "No data" placeholder.
        null_counts: Accepted for call compatibility; not used when rendering.

    Returns:
        An HTML fragment. Cell values and column names are HTML-escaped.
    """
    if df is None or len(df) == 0:
        return f"<div style='{_MUTED_STYLE}'>No data</div>"

    columns = list(df.columns)
    rows_html = "".join(
        "<tr>" + "".join(_cell_html(col, row[col]) for col in columns) + "</tr>"
        for _, row in df.iterrows()
    )
    headers = "".join(
        f"<th style='{_HEADER_CELL_STYLE}'>{escape(str(col))}</th>"
        for col in columns
    )
    return f"""
<div style='display:flex;justify-content:space-between;padding:8px 0;font-size:13px'>
<span>Live dataset</span> <span>{len(df)} rows · {len(columns)} cols</span>
</div>
<table style='width:100%;border-collapse:collapse;font-size:13px'>
<thead><tr>{headers}</tr></thead><tbody>{rows_html}</tbody>
</table>
"""


def issues_to_html(issues: list) -> str:
    """Render issues as styled pills.

    Each issue is coloured by the first keyword from ``ISSUE_COLORS`` that it
    contains; issues matching no keyword get a neutral colour.
    """
    if not issues:
        return f"<div style='{_MUTED_STYLE}'>No issues detected!</div>"

    pills = []
    for issue in issues:
        background, text, border = next(
            (colors for keyword, colors in ISSUE_COLORS if keyword in issue),
            DEFAULT_ISSUE_COLORS,
        )
        pills.append(
            f"<div style='background:{background};color:{text};"
            f"border:1px solid {border};border-radius:8px;"
            f"padding:4px 10px;margin:4px 0;font-size:12px'>"
            f"{escape(str(issue))}</div>"
        )
    return "".join(pills)


def log_to_html(logs: list) -> str:
    """Render step log as terminal-style HTML.

    Only the ten most recent entries are shown. Each entry is a dict with the
    keys ``step``, ``action``, ``reward`` and ``reason``; the reason is
    truncated to 50 characters.
    """
    if not logs:
        return f"<div style='{_MUTED_STYLE}'>No steps yet...</div>"

    lines = []
    for entry in logs[-10:]:
        reward_color = "#34D399" if entry["reward"] > 0 else "#F87171"
        action = escape(str(entry["action"]))
        reason = escape(str(entry["reason"])[:50])
        lines.append(f"""
<div style='color:#E2E8F0'>[step {entry['step']}] {action}</div>
<div style='color:{reward_color}'>{entry['reward']:+.3f} · {reason}</div>
""")
    return (
        "<div style='background:#0F172A;border-radius:8px;padding:12px;"
        "font-family:monospace;font-size:12px'>"
        f"{''.join(lines)}</div>"
    )


def _stat_html(value, color=None) -> str:
    """Render a single headline statistic, optionally coloured."""
    style = "font-size:22px;font-weight:700;text-align:center"
    if color:
        style += f";color:{color}"
    return f"<div style='{style}'>{value}</div>"


def _build_parameters(action_type, column, strategy, target_type,
                      old_name, new_name, mapping_json) -> dict:
    """Translate the form fields into the parameter dict for an action.

    Actions that take no parameters (``drop_duplicates``, ``submit``) and
    unrecognised action types produce an empty dict.
    """
    if action_type == "fill_missing":
        return {"column": column, "strategy": strategy or "mean"}
    if action_type == "fix_dtype":
        return {"column": column, "target_type": target_type or "float"}
    if action_type == "rename_column":
        return {"old_name": old_name, "new_name": new_name}
    if action_type == "remove_outliers":
        return {"column": column, "method": strategy or "iqr"}
    if action_type == "standardize_values":
        try:
            mapping = json.loads(mapping_json)
        except (ValueError, TypeError):
            # Empty or malformed JSON: fall back to a no-op mapping.
            mapping = {}
        if not isinstance(mapping, dict):
            mapping = {}
        return {"column": column, "mapping": mapping}
    return {}


# ── Actions ───────────────────────────────────
def reset_task(difficulty):
    """Reset the environment for ``difficulty`` and return fresh UI content.

    Returns:
        A 7-tuple of HTML strings: table, issues, step log, score, step
        count, total null count and last reward.
    """
    global current_difficulty
    current_difficulty = difficulty

    env = envs[difficulty]
    obs = env.reset()
    step_logs[difficulty] = []

    df = pd.DataFrame(obs.dataframe)
    return (
        df_to_html(df, obs.null_counts),
        issues_to_html(obs.issues),
        log_to_html([]),
        _stat_html("0.000"),
        _stat_html("0"),
        _stat_html(sum(obs.null_counts.values())),
        _stat_html("0"),
    )


def run_action(difficulty, action_type, column, strategy,
               target_type, old_name, new_name, mapping_json):
    """Apply one action to the environment and return updated UI content.

    The environment is reset first if it has not been initialised yet. The
    step is appended to the per-difficulty step log.

    Returns:
        A 7-tuple of HTML strings: table, issues, step log, score, step
        count, total null count and last reward.
    """
    global current_difficulty
    current_difficulty = difficulty

    env = envs[difficulty]
    if not env._initialized:
        env.reset()

    params = _build_parameters(
        action_type, column, strategy, target_type,
        old_name, new_name, mapping_json,
    )
    action = Action(action_type=action_type, parameters=params)
    result = env.step(action)
    obs = result.observation

    # Log step
    step_logs[difficulty].append({
        "step": env.step_count,
        "action": action_type,
        "reward": result.reward.value,
        "reason": result.reward.reason,
    })

    df = pd.DataFrame(obs.dataframe)
    score, _ = env.task.grade()
    # NOTE(review): the colour for positive rewards is assumed — confirm.
    reward_color = "#10B981" if result.reward.value > 0 else "#EF4444"
    return (
        df_to_html(df, obs.null_counts),
        issues_to_html(obs.issues),
        log_to_html(step_logs[difficulty]),
        _stat_html(f"{score:.3f}"),
        _stat_html(env.step_count),
        _stat_html(sum(obs.null_counts.values())),
        _stat_html(f"{result.reward.value:+.3f}", color=reward_color),
    )


def submit_task(difficulty, column, strategy, target_type,
                old_name, new_name, mapping_json):
    """Run the ``submit`` action using the current form values."""
    return run_action(difficulty, "submit", column, strategy,
                      target_type, old_name, new_name, mapping_json)


# ── UI ────────────────────────────────────────
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Playfair+Display:wght@700;900&family=DM+Sans:wght@300;400;500&display=swap');
body, .gradio-container {
    background: #F8FAFC !important;
    font-family: 'DM Sans', sans-serif !important;
}
.logo { font-family: 'Playfair Display', serif !important; }
h1, h2, h3 { font-family: 'Playfair Display', serif !important; }
.gr-button-primary { background: #2563EB !important; border: none !important; }
.gr-button { border-radius: 8px !important; font-family: 'DM Sans', sans-serif !important; }
"""

with gr.Blocks(css=CSS, title="DataClean — OpenEnv") as demo:
    gr.HTML("""
<div class='logo' style='font-size:28px;font-weight:900'>DataClean</div>
<div style='color:#64748B;font-size:13px'>OpenEnv v1.0 · Yashwanth34567</div>
""")

    with gr.Row():
        # ── Left Sidebar ──
        with gr.Column(scale=1):
            gr.HTML("<h3>Select Task</h3>")
            difficulty = gr.Radio(
                choices=[TASK_EASY, TASK_MEDIUM, TASK_HARD],
                value=TASK_EASY,
                label="",
            )
            reset_btn = gr.Button("Reset Task", variant="primary")
            gr.HTML("<h3>Issues</h3>")
            issues_display = gr.HTML()

        # ── Center ──
        with gr.Column(scale=3):
            with gr.Row():
                score_display = gr.HTML(_stat_html(""))
                step_display = gr.HTML(_stat_html(""))
                null_display = gr.HTML(_stat_html(""))
                reward_display = gr.HTML(_stat_html(""))
            gr.HTML("<h3>Live Dataset</h3>")
            table_display = gr.HTML()
            gr.HTML("<h3>Step Log</h3>")
            log_display = gr.HTML()

        # ── Right Panel ──
        with gr.Column(scale=1):
            gr.HTML("<h3>Take Action</h3>")
            action_type = gr.Dropdown(
                choices=[
                    "fill_missing",
                    "drop_duplicates",
                    "fix_dtype",
                    "rename_column",
                    "remove_outliers",
                    "standardize_values",
                    "submit",
                ],
                value="fill_missing",
                label="Action",
            )
            column = gr.Textbox(label="Column", placeholder="e.g. age")
            strategy = gr.Textbox(label="Strategy / Method",
                                  placeholder="mean | median | mode | iqr")
            target_type = gr.Textbox(label="Target Type",
                                     placeholder="int | float | str")
            old_name = gr.Textbox(label="Old Column Name",
                                  placeholder="Full Name ")
            new_name = gr.Textbox(label="New Column Name",
                                  placeholder="full_name")
            mapping = gr.Textbox(label='Mapping (JSON)',
                                 placeholder='{"USA": "United States"}')
            run_btn = gr.Button("Run Action", variant="primary")
            submit_btn = gr.Button("Submit Task", variant="secondary")

    # ── Events ───────────────────────────────
    outputs = [
        table_display,
        issues_display,
        log_display,
        score_display,
        step_display,
        null_display,
        reward_display,
    ]

    reset_btn.click(
        fn=reset_task,
        inputs=[difficulty],
        outputs=outputs,
    )
    run_btn.click(
        fn=run_action,
        inputs=[difficulty, action_type, column, strategy,
                target_type, old_name, new_name, mapping],
        outputs=outputs,
    )
    submit_btn.click(
        fn=submit_task,
        inputs=[difficulty, column, strategy,
                target_type, old_name, new_name, mapping],
        outputs=outputs,
    )

# ── FastAPI ───────────────────────────────────
api = FastAPI()
api_env = DataCleaningEnv(TASK_EASY)


@api.post("/reset")
def api_reset():
    """Reset the API environment and return the initial observation."""
    obs = api_env.reset()
    return JSONResponse(content=json.loads(obs.model_dump_json()))


@api.post("/step")
def api_step(action: Action):
    """Apply ``action`` to the API environment and return the step result."""
    result = api_env.step(action)
    return JSONResponse(content=json.loads(result.model_dump_json()))


@api.get("/state")
def api_state():
    """Return the API environment's state, initialising it if necessary."""
    try:
        state = api_env.state()
    except Exception:
        # Force initialize if not started
        api_env.reset()
        state = api_env.state()
    # default=str stringifies any value json cannot encode natively.
    return JSONResponse(content=json.loads(json.dumps(state, default=str)))


@api.get("/health")
def health():
    """Liveness probe."""
    return {"status": "ok"}


@api.get("/")
def root():
    """Redirect the site root to the Gradio UI."""
    return RedirectResponse(url="/ui")


if __name__ == "__main__":
    app_with_api = gr.mount_gradio_app(api, demo, path="/ui")
    uvicorn.run(app_with_api, host="0.0.0.0", port=7860)