Spaces:
Running
Running
| import { useState } from "react"; | |
| import ReactMarkdown from "react-markdown"; | |
| import { InputField, Tabs, Dropdown, Button, Radio, Card } from "@elvis/ui"; | |
| import { SUPPORTED_MODES, SUPPORTED_ALGORITHMS, type SettingsUi, type TrajectoryValues } from "./types.ts"; | |
| import usageMarkdown from "./usage.md?raw"; | |
/**
 * Hyperparameter values applied when the user picks an algorithm.
 *
 * Each entry is spread into the settings object on algorithm change, so only the
 * fields listed here are reset; any other hyperparameter keeps its previous value.
 * Values are strings because they are edited through text inputs.
 */
const DEFAULT_HYPERPARAMETERS = {
  // Momentum-based first-order methods.
  "Gradient Descent": { learningRate: "0.1", momentum: "0.0" },
  Nesterov: { learningRate: "0.1", momentum: "0.0" },
  // Adaptive-learning-rate methods.
  Adam: { learningRate: "0.1", beta1: "0.9", beta2: "0.999", epsilon: "1e-8" },
  Adagrad: { learningRate: "0.1", epsilon: "1e-8" },
  RMSProp: { learningRate: "0.1", beta: "0.9", epsilon: "1e-8" },
  // NOTE(review): no learningRate default here, yet the sidebar shows a Learning Rate
  // field for Adadelta, so the previous algorithm's value carries over — confirm intended.
  Adadelta: { beta: "0.9", epsilon: "1e-3" },
  // Newton's method exposes no tunable hyperparameters in the UI.
  Newton: {},
};
/**
 * Preset functions of one variable, keyed by the label shown in the "Function" dropdown.
 *
 * Key order is the dropdown order. "--Custom--" holds the starting expression used when
 * the user chooses to type their own function.
 */
const UNIVARIATE_FUNCTION_OPTIONS = {
  "--Custom--": "x^2",
  Quadratic: "x^2",
  "Local Minima": "x^2 - 0.1cos(20x)",
  Plateau: "x^4",
};
/**
 * Preset functions of two variables, keyed by the label shown in the "Function" dropdown.
 *
 * Key order is the dropdown order. "--Custom--" holds the starting expression used when
 * the user chooses to type their own function.
 */
const BIVARIATE_FUNCTION_OPTIONS = {
  "--Custom--": "x^2 + 3y^2",
  "Quadratic": "x^2 + 3y^2",
  // Same bowl as "Quadratic" with both axes scaled by 0.05, i.e. a much flatter surface.
  "Small": "(0.05x)^2 + 3(0.05y)^2",
  "Ackley": "-20exp(-0.2 sqrt(0.5(x^2 + y^2))) - exp(0.5(cos(2 pi x) + cos(2 pi y))) + e + 20",
  // Label fixed from the misspelling "Rasteringin".
  "Rastrigin": "20 + (x^2 - 10cos(2 pi x)) + (y^2 - 10cos(2 pi y))",
  "Rosenbrock": "(1 - x)^2 + 100(y - x^2)^2",
};
/** Props accepted by the Sidebar component. */
interface SidebarProps {
  /** Current settings rendered and edited by the Settings tab. */
  settings: SettingsUi;
  /** Receives a complete replacement settings object whenever any field is edited. */
  setSettings: (settings: SettingsUi) => void;
  /** Invoked by the "Random Initial Point" button. */
  onRandomInitialPoint: () => void;
  /** Recorded trajectory; the Optimize tab displays the last entry of each series. */
  trajectoryValues?: TrajectoryValues | null;
  /** Invoked by the "Reset" button on the Optimize tab. */
  onReset?: () => void;
  /** Invoked by the "Next Step" button on the Optimize tab. */
  onNextStep?: () => void;
  /** Invoked by the "Previous Step" button on the Optimize tab. */
  onPrevStep?: () => void;
  /** Reserved for a Play control; the Sidebar does not currently use it. */
  onPlay?: () => void;
  /** Reserved for a Pause control; the Sidebar does not currently use it. */
  onPause?: () => void;
}
/** Tab labels shown at the top of the sidebar, in display order. */
const SIDEBAR_TABS = ["Settings", "Optimize", "Usage"] as const;
type SidebarTab = (typeof SIDEBAR_TABS)[number];

/** Dropdown entry that lets the user type an arbitrary function expression. */
const CUSTOM_FUNCTION_OPTION = "--Custom--";

/** Algorithms for which each hyperparameter input is shown on the Settings tab. */
const ALGORITHMS_WITH_LEARNING_RATE: readonly string[] = ["Gradient Descent", "Nesterov", "Adam", "Adagrad", "RMSProp", "Adadelta"];
const ALGORITHMS_WITH_MOMENTUM: readonly string[] = ["Gradient Descent", "Nesterov"];
const ALGORITHMS_WITH_BETA: readonly string[] = ["RMSProp", "Adadelta"];
const ALGORITHMS_WITH_EPSILON: readonly string[] = ["Adam", "Adagrad", "RMSProp", "Adadelta"];

/**
 * Returns the preset-label → expression table for a problem mode.
 * Any mode other than "Bivariate" is treated as univariate.
 */
function getFunctionOptions(mode: SettingsUi["mode"]): Readonly<Record<string, string>> {
  return mode === "Bivariate" ? BIVARIATE_FUNCTION_OPTIONS : UNIVARIATE_FUNCTION_OPTIONS;
}

/** Returns the last element of `values`, or `null` when it is missing or empty. */
function getLastValue<T>(values?: readonly T[] | null): T | null {
  if (!values || values.length === 0) {
    return null;
  }
  return values[values.length - 1] ?? null;
}

/** Formats a readout to 4 decimals; `null` (no trajectory value) renders as an empty string. */
function formatScalar(value: number | null): string {
  return value === null ? "" : value.toFixed(4);
}

/** Formats a vector as "[a, b, ...]" with 4 decimals per entry; `null` renders as "". */
function formatVector(vector: readonly number[] | null): string {
  return vector === null ? "" : `[${vector.map((v) => v.toFixed(4)).join(", ")}]`;
}

/** Formats a matrix as "[[a, b], [c, d]]" with 4 decimals per entry; `null` renders as "". */
function formatMatrix(matrix: readonly (readonly number[])[] | null): string {
  return matrix === null ? "" : `[${matrix.map((row) => formatVector(row)).join(", ")}]`;
}

/**
 * Sidebar with three tabs:
 * - "Settings": problem mode, function, algorithm, initial point and hyperparameters;
 * - "Optimize": readouts of the latest trajectory point plus step/reset controls;
 * - "Usage": the bundled usage markdown.
 *
 * Every edit is reported through `setSettings` as a complete new settings object.
 */
export default function Sidebar({
  settings,
  setSettings,
  onRandomInitialPoint,
  trajectoryValues,
  onReset,
  onNextStep,
  onPrevStep,
  // onPlay and onPause are not used yet; their controls are commented out below.
  // onPlay,
  // onPause,
}: SidebarProps) {
  const [activeTab, setActiveTab] = useState<SidebarTab>("Settings");
  // Label selected in the "Function" dropdown; the expression itself lives in `settings`.
  const [functionOption, setFunctionOption] = useState<string>(CUSTOM_FUNCTION_OPTION);

  const isBivariate = settings.mode === "Bivariate";
  const functionOptions = getFunctionOptions(settings.mode);

  /** Writes one settings field; choosing an algorithm also applies its default hyperparameters. */
  function updateSettings(key: keyof SettingsUi, value: string) {
    if (key === "algorithm") {
      const defaults = DEFAULT_HYPERPARAMETERS[value as keyof typeof DEFAULT_HYPERPARAMETERS];
      setSettings({ ...settings, algorithm: value as SettingsUi["algorithm"], ...defaults });
    } else {
      setSettings({ ...settings, [key]: value });
    }
  }

  /** Selects a preset and copies its expression into the settings. */
  function handleFunctionOptionChange(option: string) {
    const expr = functionOptions[option];
    // Labels come from Object.keys(functionOptions), but never write `undefined` into settings.
    if (expr === undefined) {
      return;
    }
    setFunctionOption(option);
    updateSettings("functionExpr", expr);
  }

  /** Switches problem mode and resets the function to the custom entry for that mode. */
  function handleModeChange(mode: SettingsUi["mode"]) {
    // Presets differ between modes, so fall back to the custom entry with the new mode's default expression.
    const expr = mode === "Bivariate"
      ? BIVARIATE_FUNCTION_OPTIONS[CUSTOM_FUNCTION_OPTION]
      : UNIVARIATE_FUNCTION_OPTIONS[CUSTOM_FUNCTION_OPTION];
    // Use a single setSettings call: each call spreads the same `settings` snapshot,
    // so a second call would overwrite the first one's change.
    setSettings({ ...settings, mode, functionExpr: expr });
    setFunctionOption(CUSTOM_FUNCTION_OPTION);
  }

  const currentX = getLastValue(trajectoryValues?.x);
  const currentY = getLastValue(trajectoryValues?.y);
  // Univariate only.
  const currentDerivative = getLastValue(trajectoryValues?.derivative);
  const currentSecondDerivative = getLastValue(trajectoryValues?.secondDerivative);
  // Bivariate only.
  const currentZ = getLastValue(trajectoryValues?.z);
  const currentGradient = getLastValue(trajectoryValues?.gradient);
  const currentHessian = getLastValue(trajectoryValues?.hessian);

  return (
    <Card className="flex flex-col h-full p-4 gap-2 min-w-0 min-h-0 overflow-auto">
      <Tabs tabs={SIDEBAR_TABS} activeTab={activeTab} onChange={setActiveTab} />
      {/* Tab content */}
      <div className="flex flex-col gap-4 min-h-0 min-w-0 overflow-auto">
        {activeTab === "Settings" && (
          <>
            <Radio
              label="Problem Type"
              options={SUPPORTED_MODES}
              activeOption={settings.mode}
              onChange={handleModeChange}
            />
            <Dropdown
              label="Function"
              options={Object.keys(functionOptions)}
              activeOption={functionOption}
              onChange={handleFunctionOptionChange}
            />
            <InputField
              label="Function Expression"
              value={settings.functionExpr}
              onChange={(value) => updateSettings("functionExpr", value)}
              readonly={functionOption !== CUSTOM_FUNCTION_OPTION}
              rows={3}
            />
            <Dropdown
              label="Algorithm"
              options={SUPPORTED_ALGORITHMS}
              activeOption={settings.algorithm}
              onChange={(value: SettingsUi["algorithm"]) => updateSettings("algorithm", value)}
            />
            <div className={isBivariate ? "grid grid-cols-2 gap-2" : ""}>
              <InputField
                label="Initial X"
                value={settings.x0}
                onChange={(value) => updateSettings("x0", value)}
              />
              {isBivariate && (
                <InputField
                  label="Initial Y"
                  value={settings.y0 ?? ""}
                  onChange={(value) => updateSettings("y0", value)}
                />
              )}
            </div>
            <Button
              label="Random Initial Point"
              onClick={onRandomInitialPoint}
            />
            <div className="grid grid-cols-2 gap-2">
              {ALGORITHMS_WITH_LEARNING_RATE.includes(settings.algorithm) && (
                <InputField
                  label="Learning Rate"
                  value={settings.learningRate}
                  onChange={(value) => updateSettings("learningRate", value)}
                />
              )}
              {ALGORITHMS_WITH_MOMENTUM.includes(settings.algorithm) && (
                <InputField
                  label="Momentum"
                  value={settings.momentum}
                  onChange={(value) => updateSettings("momentum", value)}
                />
              )}
              {settings.algorithm === "Adam" && (
                <>
                  <InputField
                    label="Beta 1"
                    value={settings.beta1}
                    onChange={(value) => updateSettings("beta1", value)}
                  />
                  <InputField
                    label="Beta 2"
                    value={settings.beta2}
                    onChange={(value) => updateSettings("beta2", value)}
                  />
                </>
              )}
              {ALGORITHMS_WITH_BETA.includes(settings.algorithm) && (
                <InputField
                  label="Beta"
                  value={settings.beta}
                  onChange={(value) => updateSettings("beta", value)}
                />
              )}
              {ALGORITHMS_WITH_EPSILON.includes(settings.algorithm) && (
                <InputField
                  label="Epsilon"
                  value={settings.epsilon}
                  onChange={(value) => updateSettings("epsilon", value)}
                />
              )}
            </div>
          </>
        )}
        {activeTab === "Optimize" && (
          <>
            {/* Readouts are hidden below the `lg` breakpoint (`hidden lg:flex`). */}
            <div className="hidden lg:flex flex-col gap-4 min-h-0 min-w-0 overflow-auto">
              <InputField label="Current X" value={formatScalar(currentX)} readonly />
              <InputField label="Current Y" value={formatScalar(currentY)} readonly />
              {isBivariate ? (
                <>
                  <InputField label="Current Z" value={formatScalar(currentZ)} readonly />
                  <InputField label="Current Gradient" value={formatVector(currentGradient)} readonly />
                  <InputField label="Current Hessian" value={formatMatrix(currentHessian)} readonly />
                </>
              ) : (
                <>
                  <InputField label="Current Derivative" value={formatScalar(currentDerivative)} readonly />
                  <InputField label="Current Second Derivative" value={formatScalar(currentSecondDerivative)} readonly />
                </>
              )}
            </div>
            <div className="grid grid-cols-2 gap-2">
              <Button label="Next Step" onClick={onNextStep} />
              <Button label="Previous Step" onClick={onPrevStep} />
            </div>
            <Button label="Reset" onClick={onReset} />
            {/* Play/pause controls, commented out together with the onPlay/onPause props:
            <div className="grid grid-cols-2 gap-2">
              <Button label="Play" onClick={onPlay} />
              <InputField label="Steps per second" />
            </div>
            <div className="grid grid-cols-2 gap-2">
              <Button label="Pause" onClick={onPause} />
            </div> */}
          </>
        )}
        {activeTab === "Usage" && (
          <ReactMarkdown
            components={{
              h1: ({ children }) => <h1 className="text-2xl font-bold mt-4 mb-2">{children}</h1>,
              h2: ({ children }) => <h2 className="text-xl font-semibold mt-4 mb-2">{children}</h2>,
              h3: ({ children }) => <h3 className="text-lg font-semibold mt-4 mb-2">{children}</h3>,
              p: ({ children }) => <p className="leading-6 mb-3 last:mb-0">{children}</p>,
              ul: ({ children }) => <ul className="list-disc pl-5 mb-3">{children}</ul>,
              ol: ({ children }) => <ol className="list-decimal pl-5 mb-3">{children}</ol>,
              li: ({ children }) => <li className="mb-1">{children}</li>,
            }}
          >
            {usageMarkdown}
          </ReactMarkdown>
        )}
      </div>
    </Card>
  );
}