Spaces:
Running
Running
| import { GoogleGenAI } from "@google/genai"; | |
| import { Node, Edge } from 'reactflow'; | |
| import { NodeData, LayerType } from '../types'; | |
| import { LAYER_DEFINITIONS } from '../constants'; | |
// Key Management
/**
 * Reads the build-time Gemini API key injected by Vite.
 *
 * NOTE(review): the exact `import.meta.env.VITE_GEMINI_API_KEY` expression is
 * what Vite statically replaces at build time — do not refactor it into a
 * destructured or computed form.
 *
 * @returns The key from the build environment, or '' when absent.
 */
const getEnvKey = () => {
  // In Development: Returns key from .env file if available
  // In Production (Docker): Returns undefined/empty because Dockerfile does not inject the ARG
  // @ts-ignore
  if (typeof import.meta !== 'undefined' && import.meta.env && import.meta.env.VITE_GEMINI_API_KEY) {
    // @ts-ignore
    return import.meta.env.VITE_GEMINI_API_KEY;
  }
  return '';
};
| let userApiKey = typeof window !== 'undefined' ? localStorage.getItem('gemini_api_key') || '' : ''; | |
| const defaultEnvKey = getEnvKey(); | |
| export const setUserApiKey = (key: string) => { | |
| userApiKey = key; | |
| if (typeof window !== 'undefined') { | |
| localStorage.setItem('gemini_api_key', key); | |
| } | |
| }; | |
| export const getUserApiKey = () => userApiKey || defaultEnvKey; | |
| const getAiClient = () => { | |
| const key = getUserApiKey(); | |
| if (!key) throw new Error("API Key is missing. Please add your Gemini API Key."); | |
| return new GoogleGenAI({ apiKey: key }); | |
| }; | |
// Gemini model used for every agent call in this module.
const MODEL_NAME = 'gemini-2.5-flash';
// Lifecycle phases reported to the UI while multi-agent pipelines run.
export type AgentStatus = 'idle' | 'architect' | 'critic' | 'refiner' | 'debugger' | 'patcher' | 'complete' | 'error';
| /** | |
| * Internal helper to build the raw specification string. | |
| * This ensures all data is captured before sending to the AI for refinement. | |
| */ | |
| const buildRawPrompt = (nodes: Node<NodeData>[], edges: Edge[]): string => { | |
| // Sort nodes by vertical position to list them in a logical flow | |
| const sortedNodes = [...nodes].sort((a, b) => a.position.y - b.position.y); | |
| let rawSpec = "### Raw Architecture Data\n\n"; | |
| rawSpec += "**1. Nodes (Layers):**\n"; | |
| sortedNodes.forEach(node => { | |
| // Format parameters nicely, handling objects | |
| const params = Object.entries(node.data.params) | |
| .map(([k, v]) => { | |
| if (typeof v === 'object' && v !== null) { | |
| return `${k}=${JSON.stringify(v)}`; | |
| } | |
| // Don't clutter standard output with huge code blocks, handle separately | |
| if (k === 'definition_code' || k === 'imports') return null; | |
| return `${k}=${v}`; | |
| }) | |
| .filter(p => p !== null) | |
| .join(', '); | |
| rawSpec += `- [ID: ${node.id}] TYPE: ${node.data.type} | LABEL: ${node.data.label}\n`; | |
| if (params) { | |
| rawSpec += ` PARAMS: ${params}\n`; | |
| } | |
| // Specific instruction for Custom Layers | |
| if (node.data.type === LayerType.CUSTOM) { | |
| rawSpec += ` CUSTOM_NOTE: Instantiate using class '${node.data.params.class_name}' with args '${node.data.params.args}'.\n`; | |
| if (node.data.params.imports) { | |
| rawSpec += ` CUSTOM_IMPORTS: ${node.data.params.imports}\n`; | |
| } | |
| if (node.data.params.definition_code) { | |
| rawSpec += ` CUSTOM_CODE_DEFINITION:\n${node.data.params.definition_code}\n`; | |
| } | |
| } | |
| }); | |
| rawSpec += "\n**2. Connectivity (Edges):**\n"; | |
| if (edges.length === 0) { | |
| rawSpec += "- No connections defined.\n"; | |
| } else { | |
| edges.forEach(edge => { | |
| const sourceNode = nodes.find(n => n.id === edge.source); | |
| const targetNode = nodes.find(n => n.id === edge.target); | |
| const sourceName = sourceNode ? `${sourceNode.data.label} (ID:${sourceNode.id})` : edge.source; | |
| const targetName = targetNode ? `${targetNode.data.label} (ID:${targetNode.id})` : edge.target; | |
| rawSpec += `- ${sourceName} -> ${targetName}\n`; | |
| }); | |
| } | |
| return rawSpec; | |
| }; | |
| /** | |
| * Generates a polished, professional prompt using the AI. | |
| * It takes the raw hardcoded spec and asks the AI to format it perfectly for a coding LLM. | |
| */ | |
| export const generateRefinedPrompt = async (nodes: Node<NodeData>[], edges: Edge[]): Promise<string> => { | |
| const ai = getAiClient(); | |
| const rawSpec = buildRawPrompt(nodes, edges); | |
| const systemPrompt = ` | |
| You are an expert AI Prompt Engineer for Deep Learning. | |
| Your goal is to take a raw, technical neural network specification and rewrite it into a | |
| perfect, professional, and detailed prompt that another AI (like a coding assistant) could use to write flawless PyTorch code. | |
| Input Raw Specification: | |
| ${rawSpec} | |
| Instructions: | |
| 1. Start the output with: "You are an expert Deep Learning Engineer. Please write a complete, runnable PyTorch model code for the following neural network architecture:" | |
| 2. Create a section "Architecture Specification". | |
| 3. List "1. Layers (Nodes)" cleanly. Include ID, Type, and Parameters. | |
| 4. List "2. Connectivity (Forward Pass Flow)" cleanly. | |
| - Explicitly describe merge points (e.g. "Node X receives inputs from A and B. Handle this merge..."). | |
| - Note specific handling for complex layers like CrossAttention (needs Query + Key/Value) or SAM Decoders. | |
| 5. Create a section "Implementation Requirements" with standard PyTorch best practices (nn.Module, forward method, correct shapes). | |
| 6. If CUSTOM_CODE_DEFINITION or CUSTOM_IMPORTS are present, explicitly instruct the coder to include them verbatim or use them as reference. | |
| 7. Do NOT write the Python code yourself. Write the PROMPT that asks for the code. | |
| 8. Ensure the tone is technical and precise. | |
| Return ONLY the generated prompt text. | |
| `; | |
| try { | |
| const response = await ai.models.generateContent({ | |
| model: MODEL_NAME, | |
| contents: systemPrompt, | |
| }); | |
| return response.text.trim(); | |
| } catch (error) { | |
| console.error("Prompt refinement failed:", error); | |
| throw error; // Re-throw so caller can handle auth errors | |
| } | |
| }; | |
| export const validateArchitecture = async (nodes: Node<NodeData>[], edges: Edge[]): Promise<string> => { | |
| const ai = getAiClient(); | |
| const graphRepresentation = { | |
| nodes: nodes.map(n => ({ | |
| id: n.id, | |
| type: n.data.type, | |
| parameters: n.data.params | |
| })), | |
| edges: edges.map(e => ({ | |
| source: e.source, | |
| target: e.target | |
| })) | |
| }; | |
| const prompt = ` | |
| Analyze this neural network architecture graph for validity. | |
| Graph: ${JSON.stringify(graphRepresentation)} | |
| Check for: | |
| 1. Shape mismatches (e.g., Conv2D output to Linear without Flatten). | |
| 2. Disconnected components. | |
| 3. Logical errors (e.g., MaxPool after Output). | |
| 4. Merge layer correctness (Concat/Add needs multiple inputs). | |
| 5. GenAI correctness (e.g., CrossAttention needs 2 inputs, VLM projection dims match). | |
| Return a concise report. If valid, say "Architecture is valid.". If invalid, list specific errors and suggest fixes. | |
| `; | |
| try { | |
| const response = await ai.models.generateContent({ | |
| model: MODEL_NAME, | |
| contents: prompt, | |
| }); | |
| return response.text; | |
| } catch (error) { | |
| throw error; | |
| } | |
| } | |
| /** | |
| * Gets AI suggestions for improving the architecture. | |
| */ | |
| export const getArchitectureSuggestions = async (nodes: Node<NodeData>[], edges: Edge[]): Promise<string> => { | |
| const ai = getAiClient(); | |
| const graphRepresentation = { | |
| nodes: nodes.map(n => ({ | |
| id: n.id, | |
| type: n.data.type, | |
| parameters: n.data.params, | |
| label: n.data.label | |
| })), | |
| edges: edges.map(e => ({ | |
| source: e.source, | |
| target: e.target | |
| })) | |
| }; | |
| const prompt = ` | |
| You are a Senior Deep Learning Architect. Review the following neural network architecture graph. | |
| Graph Structure: | |
| ${JSON.stringify(graphRepresentation, null, 2)} | |
| Task: Provide 3 to 5 concrete, actionable suggestions to improve this model. | |
| Focus on: | |
| - Modern best practices (e.g., using LayerNorm vs BatchNorm in Transformers, SwiGLU vs ReLU). | |
| - Architecture efficiency and parameter count optimization. | |
| - Potential bottlenecks or vanishing gradient risks. | |
| - Adding residuals or skip connections if the model is deep. | |
| Format the output as a clean bulleted list. Keep it concise, professional and helpful. | |
| `; | |
| try { | |
| const response = await ai.models.generateContent({ | |
| model: MODEL_NAME, | |
| contents: prompt, | |
| }); | |
| return response.text; | |
| } catch (error) { | |
| throw error; | |
| } | |
| } | |
| /** | |
| * Implements the suggestions automatically. | |
| */ | |
| export const implementArchitectureSuggestions = async ( | |
| nodes: Node<NodeData>[], | |
| edges: Edge[], | |
| suggestions: string | |
| ): Promise<{ nodes: any[], edges: any[] }> => { | |
| const ai = getAiClient(); | |
| const graphRepresentation = { | |
| nodes: nodes.map(n => ({ | |
| id: n.id, | |
| type: n.data.type, | |
| parameters: n.data.params, | |
| label: n.data.label, | |
| position: n.position | |
| })), | |
| edges: edges.map(e => ({ | |
| source: e.source, | |
| target: e.target | |
| })) | |
| }; | |
| const prompt = ` | |
| You are a Senior Implementation Engineer. | |
| Task: Apply the following architectural suggestions to the provided graph JSON. | |
| Current Graph: | |
| ${JSON.stringify(graphRepresentation)} | |
| Suggestions to Implement: | |
| "${suggestions}" | |
| Instructions: | |
| 1. Modify the nodes and edges to incorporate the suggestions. | |
| 2. Maintain the layout (x, y positions) as best as possible, offsetting new nodes if added. | |
| 3. Ensure all LayerTypes are valid from the standard schema. | |
| 4. Return the complete, updated JSON with "nodes" and "edges" arrays. | |
| 5. Return ONLY raw JSON. | |
| `; | |
| try { | |
| const response = await ai.models.generateContent({ | |
| model: MODEL_NAME, | |
| contents: prompt, | |
| config: { responseMimeType: "application/json" } | |
| }); | |
| const result = JSON.parse(response.text.trim()); | |
| return sanitizeGraph(result); | |
| } catch (error) { | |
| throw new Error("Failed to implement suggestions: " + (error instanceof Error ? error.message : String(error))); | |
| } | |
| }; | |
| /** | |
| * Helper: Sanitizes the graph JSON from AI. | |
| * Checks if 'type' is valid. If not, converts it to a CUSTOM layer to prevent crashes. | |
| */ | |
| const sanitizeGraph = (graphJson: { nodes: any[], edges: any[] }) => { | |
| if (!graphJson || !graphJson.nodes) return graphJson; | |
| graphJson.nodes = graphJson.nodes.map(node => { | |
| // AI might return type in data.type or top-level type | |
| let rawType = node.data?.type || node.type || 'Identity'; | |
| // Ensure rawType is a string | |
| if (typeof rawType !== 'string') rawType = 'Identity'; | |
| // Check if this type exists in our known definitions | |
| const isValid = Object.values(LayerType).includes(rawType as LayerType) && LAYER_DEFINITIONS[rawType as LayerType]; | |
| if (!isValid) { | |
| console.warn(`Sanitizing unknown layer type: ${rawType}. Converting to CustomLayer.`); | |
| return { | |
| ...node, | |
| type: 'custom', // ReactFlow type | |
| data: { | |
| ...node.data, | |
| type: LayerType.CUSTOM, | |
| label: node.data?.label || rawType, | |
| params: { | |
| ...(node.data?.params || {}), | |
| class_name: rawType, // Store original type name here | |
| args: JSON.stringify(node.data?.params || {}).slice(0, 100) // Rough preserve of args | |
| } | |
| } | |
| }; | |
| } | |
| // Ensure standard structure for valid nodes | |
| return { | |
| ...node, | |
| type: 'custom', | |
| data: { | |
| ...node.data, | |
| type: rawType, | |
| params: node.data?.params || {} | |
| } | |
| }; | |
| }); | |
| return graphJson; | |
| }; | |
| /** | |
| * Multi-agent graph generation. | |
| * 1. Architect Agent: Drafts the layout. | |
| * 2. Critic Agent: Reviews for errors. | |
| * 3. Refiner Agent: Produces final JSON. | |
| */ | |
| export const generateGraphWithAgents = async ( | |
| userPrompt: string, | |
| currentNodes: Node<NodeData>[] = [], | |
| onStatusUpdate: (status: AgentStatus, log: string) => void | |
| ): Promise<{ nodes: any[], edges: any[] } | null> => { | |
| const ai = getAiClient(); | |
| const layerSchema = Object.values(LAYER_DEFINITIONS).map(l => ({ | |
| type: l.type, | |
| params: l.parameters.map(p => ({ name: p.name, type: p.type, options: p.options })) | |
| })); | |
| const schemaStr = JSON.stringify(layerSchema); | |
| // --- Step 1: Architect --- | |
| onStatusUpdate('architect', 'Architect is drafting initial layout...'); | |
| const context = currentNodes.length > 0 | |
| ? `Current Graph Context: ${JSON.stringify(currentNodes.map(n => ({ id: n.id, type: n.data.type, label: n.data.label })))}` | |
| : "Starting from scratch."; | |
| const architectPrompt = ` | |
| Role: Senior Neural Network Architect. | |
| Task: Create a preliminary graph layout (JSON) for the user request. | |
| User Request: "${userPrompt}" | |
| Context: ${context} | |
| Available Layers: ${Object.keys(LAYER_DEFINITIONS).join(', ')} | |
| Schema Reference: ${schemaStr} | |
| Requirements: | |
| 1. Output valid JSON with "nodes" and "edges" arrays. | |
| 2. "nodes": { id, type='custom', position:{x,y}, data:{ type: LayerType, label: string, params: {} } } | |
| 3. Use correct LayerTypes from enum. | |
| 4. Layout nodes vertically (y+150 each step). | |
| 5. Connect edges logically. | |
| 6. If multi-input/output, arrange horizontally. | |
| Return ONLY raw JSON. | |
| `; | |
| let draftJsonStr = ""; | |
| try { | |
| const response = await ai.models.generateContent({ | |
| model: MODEL_NAME, | |
| contents: architectPrompt, | |
| config: { responseMimeType: "application/json" } | |
| }); | |
| draftJsonStr = response.text.trim(); | |
| } catch (e) { | |
| throw e; | |
| } | |
| // --- Step 2: Critic --- | |
| onStatusUpdate('critic', 'Critic is reviewing architecture for flaws...'); | |
| const criticPrompt = ` | |
| Role: Senior Lead Reviewer. | |
| Task: Critique the following neural network architecture draft. | |
| User Request: "${userPrompt}" | |
| Draft Architecture: ${draftJsonStr} | |
| Check for: | |
| - Shape mismatches (e.g. 3D output into 2D input without flattening) | |
| - Logical connection errors | |
| - Missing essential layers (e.g. Activations, Normalization) | |
| - Parameter errors (e.g. kernel size too large) | |
| - Compliance with user request | |
| Output a concise paragraph describing specific improvements needed. If perfect, say "No changes needed". | |
| `; | |
| let critique = ""; | |
| try { | |
| const response = await ai.models.generateContent({ | |
| model: MODEL_NAME, | |
| contents: criticPrompt, | |
| }); | |
| critique = response.text.trim(); | |
| } catch (e) { | |
| console.warn("Critic agent failed, proceeding with draft."); | |
| critique = "No critique available."; | |
| } | |
| // --- Step 3: Refiner --- | |
| onStatusUpdate('refiner', 'Refiner is applying fixes and finalizing...'); | |
| const refinerPrompt = ` | |
| Role: Lead Engineer. | |
| Task: Finalize the JSON architecture based on the critique. | |
| Draft: ${draftJsonStr} | |
| Critique: "${critique}" | |
| Instructions: | |
| 1. Apply the fixes suggested in the critique. | |
| 2. Ensure the JSON structure is strictly { "nodes": [...], "edges": [...] }. | |
| 3. Ensure all node IDs are unique strings. | |
| 4. Ensure parameter values match the schema types. | |
| 5. Ensure "type" in top level node object is always 'custom'. | |
| Return ONLY the final JSON. | |
| `; | |
| try { | |
| const response = await ai.models.generateContent({ | |
| model: MODEL_NAME, | |
| contents: refinerPrompt, | |
| config: { responseMimeType: "application/json" } | |
| }); | |
| const finalJson = JSON.parse(response.text.trim()); | |
| onStatusUpdate('complete', 'Architecture built successfully!'); | |
| // SANITIZE: Prevent UI crashes by handling hallucinated types | |
| return sanitizeGraph(finalJson); | |
| } catch (e) { | |
| throw new Error("Refiner agent failed to parse final JSON."); | |
| } | |
| }; | |
| /** | |
| * Multi-agent code generation. | |
| * Uses the detailed prompt to write, review, and polish PyTorch code. | |
| */ | |
| export const generateCodeWithAgents = async ( | |
| promptText: string, | |
| onStatusUpdate: (status: AgentStatus, log: string) => void | |
| ): Promise<string> => { | |
| const ai = getAiClient(); | |
| // --- Step 1: Coder (Architect) --- | |
| onStatusUpdate('architect', 'Coder Agent is writing initial PyTorch implementation...'); | |
| const coderPrompt = ` | |
| Role: Senior Deep Learning Engineer. | |
| Task: Write complete, runnable PyTorch code based on the following architecture prompt. | |
| Prompt: "${promptText}" | |
| Requirements: | |
| - Use torch.nn.Module | |
| - Include all necessary imports | |
| - Handle forward pass logic exactly as described | |
| - Include a 'if __name__ == "__main__":' block to test with dummy data | |
| Return ONLY the Python code. No markdown formatting. | |
| `; | |
| let draftCode = ""; | |
| try { | |
| const response = await ai.models.generateContent({ | |
| model: MODEL_NAME, | |
| contents: coderPrompt | |
| }); | |
| draftCode = response.text.trim().replace(/```python/g, '').replace(/```/g, ''); | |
| } catch(e) { | |
| throw e; | |
| } | |
| // --- Step 2: Reviewer (Critic) --- | |
| onStatusUpdate('critic', 'Reviewer Agent is analyzing code for bugs and optimization...'); | |
| const reviewPrompt = ` | |
| Role: Code Reviewer. | |
| Task: Review the following PyTorch code for errors, shape mismatches, or style issues. | |
| Code: | |
| ${draftCode} | |
| Original Prompt Request: "${promptText}" | |
| Output a concise critique. If perfect, say "No changes needed". | |
| `; | |
| let critique = ""; | |
| try { | |
| const response = await ai.models.generateContent({ | |
| model: MODEL_NAME, | |
| contents: reviewPrompt | |
| }); | |
| critique = response.text.trim(); | |
| } catch(e) { | |
| critique = "No critique available."; | |
| } | |
| // --- Step 3: Polisher (Refiner) --- | |
| onStatusUpdate('refiner', 'Polisher Agent is finalizing the codebase...'); | |
| const polisherPrompt = ` | |
| Role: Senior Software Engineer. | |
| Task: Refine the PyTorch code based on the critique. | |
| Draft Code: ${draftCode} | |
| Critique: ${critique} | |
| Return ONLY the final Python code. No markdown formatting. | |
| `; | |
| try { | |
| const response = await ai.models.generateContent({ | |
| model: MODEL_NAME, | |
| contents: polisherPrompt | |
| }); | |
| let finalCode = response.text.trim().replace(/```python/g, '').replace(/```/g, ''); | |
| onStatusUpdate('complete', 'Code generation complete!'); | |
| return finalCode; | |
| } catch(e) { | |
| throw new Error("Polisher agent failed to generate code."); | |
| } | |
| }; | |
| /** | |
| * Fix Architecture Errors with Agents (Debugger -> Architect -> Patcher) | |
| */ | |
| export const fixArchitectureErrors = async ( | |
| nodes: Node<NodeData>[], | |
| edges: Edge[], | |
| errorMsg: string, | |
| onStatusUpdate: (status: AgentStatus, log: string) => void | |
| ): Promise<{ nodes: any[], edges: any[] } | null> => { | |
| const ai = getAiClient(); | |
| const graphJson = JSON.stringify({ | |
| nodes: nodes.map(n => ({ id: n.id, type: n.data.type, label: n.data.label, params: n.data.params, position: n.position })), | |
| edges: edges.map(e => ({ source: e.source, target: e.target })) | |
| }); | |
| // --- Step 1: Debugger --- | |
| onStatusUpdate('debugger', 'Debugger Agent is analyzing the error trace...'); | |
| const debuggerPrompt = ` | |
| Role: Senior Systems Debugger. | |
| Task: Analyze the architecture graph and the reported error to pinpoint the root cause. | |
| Graph: ${graphJson} | |
| Error Message: "${errorMsg}" | |
| Output a technical analysis of exactly what is wrong (e.g. "Node A connects to Node B but shapes [X] and [Y] are incompatible"). | |
| `; | |
| let debugAnalysis = ""; | |
| try { | |
| const response = await ai.models.generateContent({ | |
| model: MODEL_NAME, | |
| contents: debuggerPrompt, | |
| }); | |
| debugAnalysis = response.text.trim(); | |
| } catch (e) { | |
| throw e; | |
| } | |
| // --- Step 2: Architect --- | |
| onStatusUpdate('architect', 'Architect Agent is planning the fix...'); | |
| const architectPrompt = ` | |
| Role: Solution Architect. | |
| Task: Propose a specific fix for the identified issue. | |
| Issue Analysis: ${debugAnalysis} | |
| Instructions: | |
| - Determine if nodes need to be added (e.g. Flatten, Reshape), removed, or reconnected. | |
| - Determine if parameters need changing. | |
| Output the plan in clear steps. | |
| `; | |
| let fixPlan = ""; | |
| try { | |
| const response = await ai.models.generateContent({ | |
| model: MODEL_NAME, | |
| contents: architectPrompt, | |
| }); | |
| fixPlan = response.text.trim(); | |
| } catch (e) { | |
| fixPlan = "Apply necessary structural corrections."; | |
| } | |
| // --- Step 3: Patcher --- | |
| onStatusUpdate('patcher', 'Patcher Agent is applying the fix to the graph...'); | |
| const patcherPrompt = ` | |
| Role: DevOps Engineer. | |
| Task: Apply the fix to the graph JSON. | |
| Current Graph: ${graphJson} | |
| Fix Plan: ${fixPlan} | |
| Requirements: | |
| 1. Return the complete, valid JSON with "nodes" and "edges". | |
| 2. Maintain existing node positions where possible, offset new nodes if added. | |
| 3. Ensure all LayerTypes are valid. | |
| 4. Return ONLY raw JSON. | |
| `; | |
| try { | |
| const response = await ai.models.generateContent({ | |
| model: MODEL_NAME, | |
| contents: patcherPrompt, | |
| config: { responseMimeType: "application/json" } | |
| }); | |
| const finalJson = JSON.parse(response.text.trim()); | |
| onStatusUpdate('complete', 'Fix applied successfully!'); | |
| return sanitizeGraph(finalJson); | |
| } catch (e) { | |
| throw new Error("Patcher agent failed to generate valid JSON."); | |
| } | |
| }; |