wuhp committed on
Commit
f3b8c5d
·
verified ·
1 Parent(s): 01bb374

Update services/geminiService.ts

Browse files
Files changed (1) hide show
  1. services/geminiService.ts +404 -276
services/geminiService.ts CHANGED
@@ -1,8 +1,67 @@
1
- import { GoogleGenAI } from "@google/genai";
2
  import { Node, Edge } from 'reactflow';
3
  import { NodeData, LayerType } from '../types';
4
  import { LAYER_DEFINITIONS } from '../constants';
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  // Key Management
7
  const getEnvKey = () => {
8
  // In Development: Returns key from .env file if available
@@ -33,7 +92,7 @@ const getAiClient = () => {
33
  return new GoogleGenAI({ apiKey: key });
34
  };
35
 
36
- const MODEL_NAME = 'gemini-2.5-flash';
37
 
38
  export type AgentStatus = 'idle' | 'architect' | 'critic' | 'refiner' | 'debugger' | 'patcher' | 'complete' | 'error';
39
 
@@ -100,46 +159,44 @@ const buildRawPrompt = (nodes: Node<NodeData>[], edges: Edge[]): string => {
100
  * Generates a polished, professional prompt using the AI.
101
  * It takes the raw hardcoded spec and asks the AI to format it perfectly for a coding LLM.
102
  */
103
- export const generateRefinedPrompt = async (nodes: Node<NodeData>[], edges: Edge[]): Promise<string> => {
104
  const ai = getAiClient();
105
  const rawSpec = buildRawPrompt(nodes, edges);
106
 
107
- const systemPrompt = `
108
- You are an expert AI Prompt Engineer for Deep Learning.
109
- Your goal is to take a raw, technical neural network specification and rewrite it into a
110
- perfect, professional, and detailed prompt that another AI (like a coding assistant) could use to write flawless PyTorch code.
111
-
112
- Input Raw Specification:
113
- ${rawSpec}
114
-
115
- Instructions:
116
- 1. Start the output with: "You are an expert Deep Learning Engineer. Please write a complete, runnable PyTorch model code for the following neural network architecture:"
117
- 2. Create a section "Architecture Specification".
118
- 3. List "1. Layers (Nodes)" cleanly. Include ID, Type, and Parameters.
119
- 4. List "2. Connectivity (Forward Pass Flow)" cleanly.
120
- - Explicitly describe merge points (e.g. "Node X receives inputs from A and B. Handle this merge...").
121
- - Note specific handling for complex layers like CrossAttention (needs Query + Key/Value) or SAM Decoders.
122
- 5. Create a section "Implementation Requirements" with standard PyTorch best practices (nn.Module, forward method, correct shapes).
123
- 6. If CUSTOM_CODE_DEFINITION or CUSTOM_IMPORTS are present, explicitly instruct the coder to include them verbatim or use them as reference.
124
- 7. Do NOT write the Python code yourself. Write the PROMPT that asks for the code.
125
- 8. Ensure the tone is technical and precise.
126
-
127
- Return ONLY the generated prompt text.
128
- `;
129
-
130
  try {
131
  const response = await ai.models.generateContent({
132
- model: MODEL_NAME,
133
- contents: systemPrompt,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  });
135
  return response.text.trim();
136
  } catch (error) {
137
  console.error("Prompt refinement failed:", error);
138
- throw error; // Re-throw so caller can handle auth errors
139
  }
140
  };
141
 
142
- export const validateArchitecture = async (nodes: Node<NodeData>[], edges: Edge[]): Promise<string> => {
143
  const ai = getAiClient();
144
  const graphRepresentation = {
145
  nodes: nodes.map(n => ({
@@ -153,24 +210,24 @@ export const validateArchitecture = async (nodes: Node<NodeData>[], edges: Edge[
153
  }))
154
  };
155
 
156
- const prompt = `
157
- Analyze this neural network architecture graph for validity.
158
- Graph: ${JSON.stringify(graphRepresentation)}
159
-
160
- Check for:
161
- 1. Shape mismatches (e.g., Conv2D output to Linear without Flatten).
162
- 2. Disconnected components.
163
- 3. Logical errors (e.g., MaxPool after Output).
164
- 4. Merge layer correctness (Concat/Add needs multiple inputs).
165
- 5. GenAI correctness (e.g., CrossAttention needs 2 inputs, VLM projection dims match).
166
-
167
- Return a concise report. If valid, say "Architecture is valid.". If invalid, list specific errors and suggest fixes.
168
- `;
169
-
170
  try {
171
  const response = await ai.models.generateContent({
172
- model: MODEL_NAME,
173
- contents: prompt,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  });
175
  return response.text;
176
  } catch (error) {
@@ -181,7 +238,7 @@ export const validateArchitecture = async (nodes: Node<NodeData>[], edges: Edge[
181
  /**
182
  * Gets AI suggestions for improving the architecture.
183
  */
184
- export const getArchitectureSuggestions = async (nodes: Node<NodeData>[], edges: Edge[]): Promise<string> => {
185
  const ai = getAiClient();
186
  const graphRepresentation = {
187
  nodes: nodes.map(n => ({
@@ -196,26 +253,24 @@ export const getArchitectureSuggestions = async (nodes: Node<NodeData>[], edges:
196
  }))
197
  };
198
 
199
- const prompt = `
200
- You are a Senior Deep Learning Architect. Review the following neural network architecture graph.
201
-
202
- Graph Structure:
203
- ${JSON.stringify(graphRepresentation, null, 2)}
204
-
205
- Task: Provide 3 to 5 concrete, actionable suggestions to improve this model.
206
- Focus on:
207
- - Modern best practices (e.g., using LayerNorm vs BatchNorm in Transformers, SwiGLU vs ReLU).
208
- - Architecture efficiency and parameter count optimization.
209
- - Potential bottlenecks or vanishing gradient risks.
210
- - Adding residuals or skip connections if the model is deep.
211
-
212
- Format the output as a clean bulleted list. Keep it concise, professional and helpful.
213
- `;
214
-
215
  try {
216
  const response = await ai.models.generateContent({
217
- model: MODEL_NAME,
218
- contents: prompt,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  });
220
  return response.text;
221
  } catch (error) {
@@ -229,7 +284,8 @@ export const getArchitectureSuggestions = async (nodes: Node<NodeData>[], edges:
229
  export const implementArchitectureSuggestions = async (
230
  nodes: Node<NodeData>[],
231
  edges: Edge[],
232
- suggestions: string
 
233
  ): Promise<{ nodes: any[], edges: any[] }> => {
234
  const ai = getAiClient();
235
 
@@ -247,29 +303,23 @@ export const implementArchitectureSuggestions = async (
247
  }))
248
  };
249
 
250
- const prompt = `
251
- You are a Senior Implementation Engineer.
252
- Task: Apply the following architectural suggestions to the provided graph JSON.
253
-
254
- Current Graph:
255
- ${JSON.stringify(graphRepresentation)}
256
-
257
- Suggestions to Implement:
258
- "${suggestions}"
259
-
260
- Instructions:
261
- 1. Modify the nodes and edges to incorporate the suggestions.
262
- 2. Maintain the layout (x, y positions) as best as possible, offsetting new nodes if added.
263
- 3. Ensure all LayerTypes are valid from the standard schema.
264
- 4. Return the complete, updated JSON with "nodes" and "edges" arrays.
265
- 5. Return ONLY raw JSON.
266
- `;
267
-
268
  try {
269
  const response = await ai.models.generateContent({
270
- model: MODEL_NAME,
271
- contents: prompt,
272
- config: { responseMimeType: "application/json" }
 
 
 
 
 
 
 
 
 
 
 
 
273
  });
274
  const result = JSON.parse(response.text.trim());
275
  return sanitizeGraph(result);
@@ -284,47 +334,72 @@ export const implementArchitectureSuggestions = async (
284
  * Checks if 'type' is valid. If not, converts it to a CUSTOM layer to prevent crashes.
285
  */
286
  const sanitizeGraph = (graphJson: { nodes: any[], edges: any[] }) => {
287
- if (!graphJson || !graphJson.nodes) return graphJson;
288
 
289
- graphJson.nodes = graphJson.nodes.map(node => {
290
- // AI might return type in data.type or top-level type
291
- let rawType = node.data?.type || node.type || 'Identity';
292
-
293
- // Ensure rawType is a string
294
- if (typeof rawType !== 'string') rawType = 'Identity';
295
-
296
- // Check if this type exists in our known definitions
297
- const isValid = Object.values(LayerType).includes(rawType as LayerType) && LAYER_DEFINITIONS[rawType as LayerType];
298
-
299
- if (!isValid) {
300
- console.warn(`Sanitizing unknown layer type: ${rawType}. Converting to CustomLayer.`);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
  return {
302
  ...node,
303
- type: 'custom', // ReactFlow type
304
  data: {
305
  ...node.data,
306
- type: LayerType.CUSTOM,
307
- label: node.data?.label || rawType,
308
- params: {
309
- ...(node.data?.params || {}),
310
- class_name: rawType, // Store original type name here
311
- args: JSON.stringify(node.data?.params || {}).slice(0, 100) // Rough preserve of args
312
- }
313
  }
314
  };
315
- }
316
-
317
- // Ensure standard structure for valid nodes
318
- return {
319
- ...node,
320
- type: 'custom',
321
- data: {
322
- ...node.data,
323
- type: rawType,
324
- params: node.data?.params || {}
 
 
 
 
325
  }
326
- };
327
- });
 
 
 
 
 
 
 
 
328
 
329
  return graphJson;
330
  };
@@ -339,7 +414,8 @@ const sanitizeGraph = (graphJson: { nodes: any[], edges: any[] }) => {
339
  export const generateGraphWithAgents = async (
340
  userPrompt: string,
341
  currentNodes: Node<NodeData>[] = [],
342
- onStatusUpdate: (status: AgentStatus, log: string) => void
 
343
  ): Promise<{ nodes: any[], edges: any[] } | null> => {
344
  const ai = getAiClient();
345
 
@@ -356,34 +432,34 @@ export const generateGraphWithAgents = async (
356
  ? `Current Graph Context: ${JSON.stringify(currentNodes.map(n => ({ id: n.id, type: n.data.type, label: n.data.label })))}`
357
  : "Starting from scratch.";
358
 
359
- const architectPrompt = `
360
- Role: Senior Neural Network Architect.
361
- Task: Create a preliminary graph layout (JSON) for the user request.
362
- User Request: "${userPrompt}"
363
- Context: ${context}
364
-
365
- Available Layers: ${Object.keys(LAYER_DEFINITIONS).join(', ')}
366
- Schema Reference: ${schemaStr}
367
-
368
- Requirements:
369
- 1. Output valid JSON with "nodes" and "edges" arrays.
370
- 2. "nodes": { id, type='custom', position:{x,y}, data:{ type: LayerType, label: string, params: {} } }
371
- 3. Use correct LayerTypes from enum.
372
- 4. Layout nodes vertically (y+150 each step).
373
- 5. Connect edges logically.
374
- 6. If multi-input/output, arrange horizontally.
375
-
376
- Return ONLY raw JSON.
377
- `;
378
 
379
  let draftJsonStr = "";
 
380
  try {
381
- const response = await ai.models.generateContent({
382
- model: MODEL_NAME,
383
- contents: architectPrompt,
384
- config: { responseMimeType: "application/json" }
 
 
 
 
385
  });
386
- draftJsonStr = response.text.trim();
387
  } catch (e) {
388
  throw e;
389
  }
@@ -391,29 +467,35 @@ export const generateGraphWithAgents = async (
391
  // --- Step 2: Critic ---
392
  onStatusUpdate('critic', 'Critic is reviewing architecture for flaws...');
393
 
394
- const criticPrompt = `
395
- Role: Senior Lead Reviewer.
396
- Task: Critique the following neural network architecture draft.
397
- User Request: "${userPrompt}"
398
- Draft Architecture: ${draftJsonStr}
399
-
400
- Check for:
401
- - Shape mismatches (e.g. 3D output into 2D input without flattening)
402
- - Logical connection errors
403
- - Missing essential layers (e.g. Activations, Normalization)
404
- - Parameter errors (e.g. kernel size too large)
405
- - Compliance with user request
406
-
407
- Output a concise paragraph describing specific improvements needed. If perfect, say "No changes needed".
408
- `;
409
 
410
  let critique = "";
 
411
  try {
412
- const response = await ai.models.generateContent({
413
- model: MODEL_NAME,
414
- contents: criticPrompt,
 
 
 
 
 
 
 
 
 
415
  });
416
- critique = response.text.trim();
417
  } catch (e) {
418
  console.warn("Critic agent failed, proceeding with draft.");
419
  critique = "No critique available.";
@@ -422,33 +504,35 @@ export const generateGraphWithAgents = async (
422
  // --- Step 3: Refiner ---
423
  onStatusUpdate('refiner', 'Refiner is applying fixes and finalizing...');
424
 
425
- const refinerPrompt = `
426
- Role: Lead Engineer.
427
- Task: Finalize the JSON architecture based on the critique.
428
-
429
- Draft: ${draftJsonStr}
430
- Critique: "${critique}"
431
-
432
- Instructions:
433
- 1. Apply the fixes suggested in the critique.
434
- 2. Ensure the JSON structure is strictly { "nodes": [...], "edges": [...] }.
435
- 3. Ensure all node IDs are unique strings.
436
- 4. Ensure parameter values match the schema types.
437
- 5. Ensure "type" in top level node object is always 'custom'.
438
-
439
- Return ONLY the final JSON.
440
- `;
441
 
442
  try {
443
  const response = await ai.models.generateContent({
444
- model: MODEL_NAME,
445
- contents: refinerPrompt,
446
- config: { responseMimeType: "application/json" }
 
 
 
 
 
 
 
 
447
  });
448
- const finalJson = JSON.parse(response.text.trim());
449
  onStatusUpdate('complete', 'Architecture built successfully!');
450
 
451
- // SANITIZE: Prevent UI crashes by handling hallucinated types
452
  return sanitizeGraph(finalJson);
453
  } catch (e) {
454
  throw new Error("Refiner agent failed to parse final JSON.");
@@ -461,80 +545,100 @@ export const generateGraphWithAgents = async (
461
  */
462
  export const generateCodeWithAgents = async (
463
  promptText: string,
464
- onStatusUpdate: (status: AgentStatus, log: string) => void
 
465
  ): Promise<string> => {
466
  const ai = getAiClient();
467
 
468
- // --- Step 1: Coder (Architect) ---
469
  onStatusUpdate('architect', 'Coder Agent is writing initial PyTorch implementation...');
470
 
471
- const coderPrompt = `
472
- Role: Senior Deep Learning Engineer.
473
- Task: Write complete, runnable PyTorch code based on the following architecture prompt.
474
- Prompt: "${promptText}"
475
-
476
- Requirements:
477
- - Use torch.nn.Module
478
- - Include all necessary imports
479
- - Handle forward pass logic exactly as described
480
- - Include a 'if __name__ == "__main__":' block to test with dummy data
481
-
482
- Return ONLY the Python code. No markdown formatting.
483
- `;
484
 
485
  let draftCode = "";
486
  try {
487
  const response = await ai.models.generateContent({
488
- model: MODEL_NAME,
489
- contents: coderPrompt
 
 
 
 
490
  });
491
  draftCode = response.text.trim().replace(/```python/g, '').replace(/```/g, '');
492
  } catch(e) {
493
  throw e;
494
  }
495
 
496
- // --- Step 2: Reviewer (Critic) ---
497
  onStatusUpdate('critic', 'Reviewer Agent is analyzing code for bugs and optimization...');
498
- const reviewPrompt = `
499
- Role: Code Reviewer.
500
- Task: Review the following PyTorch code for errors, shape mismatches, or style issues.
501
- Code:
502
- ${draftCode}
503
-
504
- Original Prompt Request: "${promptText}"
505
-
506
- Output a concise critique. If perfect, say "No changes needed".
507
- `;
508
-
 
509
  let critique = "";
510
  try {
511
  const response = await ai.models.generateContent({
512
- model: MODEL_NAME,
513
- contents: reviewPrompt
 
 
 
 
 
 
 
 
514
  });
515
  critique = response.text.trim();
516
  } catch(e) {
517
  critique = "No critique available.";
518
  }
519
 
520
- // --- Step 3: Polisher (Refiner) ---
521
  onStatusUpdate('refiner', 'Polisher Agent is finalizing the codebase...');
522
- const polisherPrompt = `
523
- Role: Senior Software Engineer.
524
- Task: Refine the PyTorch code based on the critique.
525
-
526
- Draft Code: ${draftCode}
527
- Critique: ${critique}
528
-
529
- Return ONLY the final Python code. No markdown formatting.
530
- `;
 
 
531
 
532
  try {
533
  const response = await ai.models.generateContent({
534
- model: MODEL_NAME,
535
- contents: polisherPrompt
 
 
 
 
 
536
  });
537
- let finalCode = response.text.trim().replace(/```python/g, '').replace(/```/g, '');
538
  onStatusUpdate('complete', 'Code generation complete!');
539
  return finalCode;
540
  } catch(e) {
@@ -549,7 +653,8 @@ export const fixArchitectureErrors = async (
549
  nodes: Node<NodeData>[],
550
  edges: Edge[],
551
  errorMsg: string,
552
- onStatusUpdate: (status: AgentStatus, log: string) => void
 
553
  ): Promise<{ nodes: any[], edges: any[] } | null> => {
554
  const ai = getAiClient();
555
  const graphJson = JSON.stringify({
@@ -560,21 +665,24 @@ export const fixArchitectureErrors = async (
560
  // --- Step 1: Debugger ---
561
  onStatusUpdate('debugger', 'Debugger Agent is analyzing the error trace...');
562
 
563
- const debuggerPrompt = `
564
- Role: Senior Systems Debugger.
565
- Task: Analyze the architecture graph and the reported error to pinpoint the root cause.
566
-
567
- Graph: ${graphJson}
568
- Error Message: "${errorMsg}"
569
-
570
- Output a technical analysis of exactly what is wrong (e.g. "Node A connects to Node B but shapes [X] and [Y] are incompatible").
571
- `;
572
-
573
  let debugAnalysis = "";
574
  try {
575
  const response = await ai.models.generateContent({
576
- model: MODEL_NAME,
577
- contents: debuggerPrompt,
 
 
 
 
 
 
 
 
 
 
 
 
 
578
  });
579
  debugAnalysis = response.text.trim();
580
  } catch (e) {
@@ -584,57 +692,77 @@ export const fixArchitectureErrors = async (
584
  // --- Step 2: Architect ---
585
  onStatusUpdate('architect', 'Architect Agent is planning the fix...');
586
 
587
- const architectPrompt = `
588
- Role: Solution Architect.
589
- Task: Propose a specific fix for the identified issue.
590
-
591
- Issue Analysis: ${debugAnalysis}
592
-
593
- Instructions:
594
- - Determine if nodes need to be added (e.g. Flatten, Reshape), removed, or reconnected.
595
- - Determine if parameters need changing.
596
-
597
- Output the plan in clear steps.
598
- `;
599
-
600
  let fixPlan = "";
601
  try {
602
  const response = await ai.models.generateContent({
603
- model: MODEL_NAME,
604
- contents: architectPrompt,
 
 
 
 
 
 
 
 
 
 
 
605
  });
606
  fixPlan = response.text.trim();
607
  } catch (e) {
608
  fixPlan = "Apply necessary structural corrections.";
609
  }
610
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
611
  // --- Step 3: Patcher ---
612
  onStatusUpdate('patcher', 'Patcher Agent is applying the fix to the graph...');
613
 
614
- const patcherPrompt = `
615
- Role: DevOps Engineer.
616
- Task: Apply the fix to the graph JSON.
617
-
618
- Current Graph: ${graphJson}
619
- Fix Plan: ${fixPlan}
620
-
621
- Requirements:
622
- 1. Return the complete, valid JSON with "nodes" and "edges".
623
- 2. Maintain existing node positions where possible, offset new nodes if added.
624
- 3. Ensure all LayerTypes are valid.
625
- 4. Return ONLY raw JSON.
626
- `;
627
-
628
  try {
629
  const response = await ai.models.generateContent({
630
- model: MODEL_NAME,
631
- contents: patcherPrompt,
632
- config: { responseMimeType: "application/json" }
 
 
 
 
 
 
 
 
 
 
 
 
633
  });
634
- const finalJson = JSON.parse(response.text.trim());
635
  onStatusUpdate('complete', 'Fix applied successfully!');
636
  return sanitizeGraph(finalJson);
637
  } catch (e) {
638
  throw new Error("Patcher agent failed to generate valid JSON.");
639
  }
640
- };
 
1
+ import { GoogleGenAI, Type } from "@google/genai";
2
  import { Node, Edge } from 'reactflow';
3
  import { NodeData, LayerType } from '../types';
4
  import { LAYER_DEFINITIONS } from '../constants';
5
 
6
+ // --- SCHEMAS ---
7
// Structured-output schema describing the graph JSON the model must return:
// an optional free-text "reasoning" plus mandatory "nodes" and "edges" arrays.
// Built from small named pieces so each level of nesting can be read on its own.
const graphResponseSchema = (() => {
  // Canvas coordinates of a node; both axes are mandatory.
  const positionSchema = {
    type: Type.OBJECT,
    properties: {
      x: { type: Type.NUMBER },
      y: { type: Type.NUMBER }
    },
    required: ["x", "y"]
  };

  // Layer payload carried by every node.
  // NOTE(review): `params` is an OBJECT schema without `properties`; the Gemini API
  // may reject property-less OBJECT schemas in `responseSchema` — TODO confirm.
  const nodeDataSchema = {
    type: Type.OBJECT,
    properties: {
      type: { type: Type.STRING, description: "The LayerType enum value." },
      label: { type: Type.STRING },
      params: { type: Type.OBJECT, description: "Hyperparameters for the layer." }
    },
    required: ["type", "label"]
  };

  // A single graph node.
  const nodeSchema = {
    type: Type.OBJECT,
    properties: {
      id: { type: Type.STRING },
      type: { type: Type.STRING, description: "Must be 'custom' for ReactFlow compatibility." },
      position: positionSchema,
      data: nodeDataSchema
    },
    required: ["id", "type", "position", "data"]
  };

  // A directed connection between two node ids; `id` itself is optional.
  const edgeSchema = {
    type: Type.OBJECT,
    properties: {
      id: { type: Type.STRING },
      source: { type: Type.STRING },
      target: { type: Type.STRING }
    },
    required: ["source", "target"]
  };

  return {
    type: Type.OBJECT,
    properties: {
      reasoning: {
        type: Type.STRING,
        description: "Brief explanation of architectural choices or fixes applied."
      },
      nodes: { type: Type.ARRAY, items: nodeSchema },
      edges: { type: Type.ARRAY, items: edgeSchema }
    },
    required: ["nodes", "edges"]
  };
})();
57
+
58
/**
 * Helper: extracts the JSON payload from a model response.
 *
 * Models that are not constrained to a JSON MIME type frequently wrap the
 * payload in a markdown fence and/or surround it with explanatory prose.
 * Resolution order:
 *   1. If a complete fenced block (``` or ```json, any letter case) exists,
 *      return the trimmed contents of the first one.
 *   2. Otherwise remove stray fence markers (e.g. an unterminated fence) and,
 *      if a `{` or `[` is present, return the span from the first opener to
 *      the last matching closer so leading/trailing prose is discarded.
 *   3. Otherwise return the trimmed, fence-stripped text unchanged.
 *
 * @param str Raw text returned by the model (tolerates null/undefined at runtime).
 * @returns The best-effort JSON substring; callers still need to JSON.parse it.
 */
const cleanJsonString = (str: string): string => {
  const text = (str ?? '').trim();

  // 1. Prefer the body of the first complete fenced block.
  const fenced = /```(?:json)?\s*([\s\S]*?)```/i.exec(text);
  if (fenced) {
    return fenced[1].trim();
  }

  // 2. No complete fence: drop any dangling fence markers.
  const stripped = text.replace(/```json/gi, '').replace(/```/g, '').trim();

  // Cut away prose before the first opener and after the last matching closer.
  const start = stripped.search(/[\[{]/);
  if (start !== -1) {
    const closer = stripped[start] === '{' ? '}' : ']';
    const end = stripped.lastIndexOf(closer);
    if (end > start) {
      return stripped.slice(start, end + 1);
    }
  }

  // 3. Nothing that looks like JSON; hand back the cleaned text as-is.
  return stripped;
};
64
+
65
  // Key Management
66
  const getEnvKey = () => {
67
  // In Development: Returns key from .env file if available
 
92
  return new GoogleGenAI({ apiKey: key });
93
  };
94
 
95
+ export const DEFAULT_MODEL = 'gemini-2.5-flash';
96
 
97
  export type AgentStatus = 'idle' | 'architect' | 'critic' | 'refiner' | 'debugger' | 'patcher' | 'complete' | 'error';
98
 
 
159
  * Generates a polished, professional prompt using the AI.
160
  * It takes the raw hardcoded spec and asks the AI to format it perfectly for a coding LLM.
161
  */
162
export const generateRefinedPrompt = async (nodes: Node<NodeData>[], edges: Edge[], model: string = DEFAULT_MODEL): Promise<string> => {
  // Parameters:
  //   nodes / edges - the current graph, flattened into a raw text spec by buildRawPrompt.
  //   model         - model id to query; defaults to DEFAULT_MODEL.
  // Resolves with the trimmed prompt text produced by the model.
  // Rejects (after logging) with whatever the request throws, or with an Error
  // when the response carries no text part.
  const ai = getAiClient();
  const rawSpec = buildRawPrompt(nodes, edges);

  try {
    const response = await ai.models.generateContent({
      model,
      contents: `Input Raw Specification:\n${rawSpec}`,
      config: {
        systemInstruction: `You are an expert AI Prompt Engineer specializing in Deep Learning Architecture.
Your goal is to transform a raw graph specification into a high-fidelity, professional prompt for a coding LLM.
You have access to Google Search to look up the latest PyTorch best practices or SOTA implementation details to include in the prompt.

Instructions:
1. Start with: "You are a world-class Deep Learning Engineer. Implement the following PyTorch model with precision:"
2. Section "Architecture Details": List all layers with their IDs, Types, and specific hyperparameters.
3. Section "Data Flow & Connectivity": Describe the forward pass step-by-step.
- Explicitly mention skip connections, residual additions, and concatenation points.
- For multi-input layers (like CrossAttention or Add), specify exactly which node IDs provide the inputs.
4. Section "Implementation Requirements":
- Use idiomatic PyTorch (nn.Module).
- Ensure input/output shapes are documented in comments.
- Include proper initialization (e.g., Kaiming or Xavier) where appropriate.
- Handle potential shape mismatches (e.g., adding a Flatten layer before Linear if needed).
5. If custom code is provided in the spec, ensure it is integrated correctly.
6. Do NOT write the code yourself. Write the PROMPT that will guide another AI to write the code.
7. Maintain a highly technical, rigorous, and clear tone.`,
        tools: [{ googleSearch: {} }]
      }
    });

    // The SDK's `text` accessor can be undefined (no text part in the response);
    // fail with a descriptive error instead of a TypeError on `.trim()`.
    const text = response.text;
    if (typeof text !== 'string') {
      throw new Error("Prompt refinement returned no text content.");
    }
    return text.trim();
  } catch (error) {
    console.error("Prompt refinement failed:", error);
    // Re-throw so the caller can react (e.g. to authentication failures).
    throw error;
  }
};
198
 
199
+ export const validateArchitecture = async (nodes: Node<NodeData>[], edges: Edge[], model: string = DEFAULT_MODEL): Promise<string> => {
200
  const ai = getAiClient();
201
  const graphRepresentation = {
202
  nodes: nodes.map(n => ({
 
210
  }))
211
  };
212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  try {
214
  const response = await ai.models.generateContent({
215
+ model: model,
216
+ contents: `Graph to Analyze: ${JSON.stringify(graphRepresentation)}`,
217
+ config: {
218
+ systemInstruction: `You are a Senior Deep Learning Validator. Your role is to find structural and logical flaws in neural network architectures.
219
+
220
+ Checklist:
221
+ 1. Dimensional Consistency: Do Conv2D/Conv3D layers have appropriate pooling or flattening before Linear layers?
222
+ 2. Connectivity: Are there any orphaned nodes? Is there a path from Input to Output?
223
+ 3. Layer Logic: Are activations placed correctly? (e.g., no ReLU after a Softmax).
224
+ 4. Merge Operations: Do Add/Concat layers have at least 2 inputs? Do they have compatible shapes?
225
+ 5. GenAI Patterns: Does Attention have Query/Key/Value paths? Do VLM projections match the LLM backbone dimensions?
226
+
227
+ Output Format:
228
+ - If valid: "Architecture is valid."
229
+ - If invalid: A bulleted list of "CRITICAL ERRORS" and "SUGGESTED FIXES".`
230
+ }
231
  });
232
  return response.text;
233
  } catch (error) {
 
238
  /**
239
  * Gets AI suggestions for improving the architecture.
240
  */
241
+ export const getArchitectureSuggestions = async (nodes: Node<NodeData>[], edges: Edge[], model: string = DEFAULT_MODEL): Promise<string> => {
242
  const ai = getAiClient();
243
  const graphRepresentation = {
244
  nodes: nodes.map(n => ({
 
253
  }))
254
  };
255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  try {
257
  const response = await ai.models.generateContent({
258
+ model: model,
259
+ contents: `Current Graph Architecture: ${JSON.stringify(graphRepresentation, null, 2)}`,
260
+ config: {
261
+ systemInstruction: `You are a Senior Deep Learning Architect. Your goal is to provide elite-level optimization suggestions for neural network architectures.
262
+ You have access to Google Search to research the latest SOTA (State of the Art) components and research papers.
263
+
264
+ Focus Areas:
265
+ 1. Efficiency: Reduce parameter count or computational complexity without sacrificing accuracy.
266
+ 2. Modernity: Suggest state-of-the-art components (e.g., FlashAttention, SwiGLU, RMSNorm).
267
+ 3. Robustness: Identify risks of vanishing/exploding gradients and suggest residuals or normalization.
268
+ 4. Scalability: Suggest ways to make the model more modular or scalable.
269
+
270
+ Format:
271
+ Provide 3-5 concise, bulleted suggestions. Each suggestion should include a "Why" (the technical benefit).`,
272
+ tools: [{ googleSearch: {} }]
273
+ }
274
  });
275
  return response.text;
276
  } catch (error) {
 
284
  export const implementArchitectureSuggestions = async (
285
  nodes: Node<NodeData>[],
286
  edges: Edge[],
287
+ suggestions: string,
288
+ model: string = DEFAULT_MODEL
289
  ): Promise<{ nodes: any[], edges: any[] }> => {
290
  const ai = getAiClient();
291
 
 
303
  }))
304
  };
305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
  try {
307
  const response = await ai.models.generateContent({
308
+ model: model,
309
+ contents: `Current Graph Architecture: ${JSON.stringify(graphRepresentation)}\nSuggestions to Implement: "${suggestions}"`,
310
+ config: {
311
+ systemInstruction: `You are a Senior Implementation Engineer specializing in Deep Learning.
312
+ Your task is to modify the provided neural network graph JSON to incorporate specific architectural improvements.
313
+
314
+ Rules:
315
+ 1. Maintain the relative layout (x, y positions). If adding nodes, place them logically between existing ones.
316
+ 2. Use ONLY valid LayerTypes from the schema.
317
+ 3. Ensure all new nodes have unique IDs.
318
+ 4. Ensure the resulting graph is fully connected and logically sound.
319
+ 5. Preserve existing node parameters unless the suggestion explicitly requires changing them.`,
320
+ responseMimeType: "application/json",
321
+ responseSchema: graphResponseSchema
322
+ }
323
  });
324
  const result = JSON.parse(response.text.trim());
325
  return sanitizeGraph(result);
 
334
  * Checks if 'type' is valid. If not, converts it to a CUSTOM layer to prevent crashes.
335
  */
336
const sanitizeGraph = (graphJson: { nodes: any[], edges: any[] }) => {
  // Normalizes model-produced graph JSON in place and returns the same object.
  //   - nodes: non-object entries are dropped; every node gets the ReactFlow type
  //     'custom'; unknown layer types are rewritten to LayerType.CUSTOM with the
  //     original type name kept in params.class_name.
  //   - edges: non-object entries and edges lacking source/target are dropped;
  //     ids are made unique strings and the default styling is applied.
  if (!graphJson) return graphJson;

  // The model occasionally emits null or primitive array entries; reading
  // properties from them would throw, so they are filtered out first.
  const isRecord = (value: unknown): boolean => typeof value === 'object' && value !== null;

  if (graphJson.nodes) {
    graphJson.nodes = graphJson.nodes.filter(isRecord).map(node => {
      // AI might return type in data.type or top-level type
      const candidate = node.data?.type || node.type || 'Identity';
      const rawType: string = typeof candidate === 'string' ? candidate : 'Identity';
      const params = node.data?.params || {};

      // A type is usable only if it is both an enum member and has a definition.
      const isValid =
        Object.values(LayerType).includes(rawType as LayerType) &&
        Boolean(LAYER_DEFINITIONS[rawType as LayerType]);

      if (!isValid) {
        console.warn(`Sanitizing unknown layer type: ${rawType}. Converting to CustomLayer.`);
        return {
          ...node,
          type: 'custom', // ReactFlow type
          data: {
            ...node.data,
            type: LayerType.CUSTOM,
            label: node.data?.label || rawType,
            params: {
              ...params,
              class_name: rawType, // Store original type name here
              args: JSON.stringify(params).slice(0, 100) // Rough preserve of args
            }
          }
        };
      }

      // Ensure standard structure for valid nodes
      return {
        ...node,
        type: 'custom',
        data: {
          ...node.data,
          type: rawType,
          params
        }
      };
    });
  }

  if (graphJson.edges) {
    const seenIds = new Set<string>();
    graphJson.edges = graphJson.edges
      // An edge without both endpoints cannot be drawn and would otherwise
      // receive the meaningless id "e-undefined-undefined".
      .filter(edge => isRecord(edge) && edge.source != null && edge.target != null)
      .map((edge, idx) => {
        // AI might return edge without ID, with a numeric ID, or with a duplicate ID
        const baseId = String(edge.id || `e-${edge.source}-${edge.target}`);

        // Ensure uniqueness within this graph
        let uniqueId = baseId;
        let counter = 1;
        while (seenIds.has(uniqueId)) {
          uniqueId = `${baseId}-${idx}-${counter++}`;
        }
        seenIds.add(uniqueId);

        return {
          ...edge,
          id: uniqueId,
          animated: true,
          style: { stroke: '#94a3b8' }
        };
      });
  }

  return graphJson;
};
 
414
  export const generateGraphWithAgents = async (
415
  userPrompt: string,
416
  currentNodes: Node<NodeData>[] = [],
417
+ onStatusUpdate: (status: AgentStatus, log: string) => void,
418
+ model: string = DEFAULT_MODEL
419
  ): Promise<{ nodes: any[], edges: any[] } | null> => {
420
  // Pipeline overview (from the visible code): three sequential generateContent calls on the
  // same model. The Architect drafts graph JSON, the Critic reviews that draft, and the Refiner
  // emits the final JSON, which is passed through sanitizeGraph and returned.
  const ai = getAiClient();
421
 
 
432
  // NOTE(review): the lines between the client creation and this ternary tail are not shown in
  // this view; presumably they define schemaStr and the context string used below -- confirm.
  ? `Current Graph Context: ${JSON.stringify(currentNodes.map(n => ({ id: n.id, type: n.data.type, label: n.data.label })))}`
433
  : "Starting from scratch.";
434
 
435
+ const architectSystemInstruction = `You are a Senior Neural Network Architect.
436
+ Your task is to design a high-quality neural network graph based on user requirements.
437
+
438
+ Design Principles:
439
+ 1. Hierarchical Flow: Arrange nodes from top (Input) to bottom (Output) with y-offsets of +150.
440
+ 2. Modern Patterns: Use residuals/skip connections for deep networks. Prefer LayerNorm/RMSNorm for Transformers.
441
+ 3. Spatial Clarity: If a layer has multiple branches, spread them horizontally (x-offsets).
442
+ 4. Schema Compliance: Use ONLY valid LayerTypes and parameter names provided in the schema.
443
+
444
+ Output Format:
445
+ Return a JSON object with:
446
+ - "reasoning": A brief string explaining your architectural choices.
447
+ - "nodes": Array of { id, type='custom', position:{x,y}, data:{ type, label, params:{} } }
448
+ - "edges": Array of { id, source, target }`;
 
 
 
 
 
449
 
450
  // Cleaned Architect output; it is replayed as the model turn in the Critic and Refiner calls below.
  let draftJsonStr = "";
451
+ let architectResponse;
452
  try {
453
+ architectResponse = await ai.models.generateContent({
454
+ model: model,
455
+ contents: `User Request: "${userPrompt}"\nContext: ${context}\nSchema: ${schemaStr}`,
456
+ config: {
457
+ systemInstruction: architectSystemInstruction,
458
+ // REMOVED responseMimeType and responseSchema as they conflict with tools (googleSearch/urlContext)
459
+ tools: [{ googleSearch: {} }, { urlContext: {} }]
460
+ }
461
  });
462
+ draftJsonStr = cleanJsonString(architectResponse.text);
463
  } catch (e) {
464
  // NOTE(review): catching only to rethrow is a no-op; an Architect failure aborts the whole pipeline.
  throw e;
465
  }
 
467
  // --- Step 2: Critic ---
468
  onStatusUpdate('critic', 'Critic is reviewing architecture for flaws...');
469
 
470
+ const criticSystemInstruction = `You are a Lead Architecture Reviewer.
471
+ Critique the provided neural network draft for technical soundness, efficiency, and adherence to the user's request.
472
+ You have access to Google Search to verify SOTA (State of the Art) practices if needed.
473
+
474
+ Check for:
475
+ - Missing critical components (e.g., no activation after Conv, no pooling before FC).
476
+ - Bottlenecks (e.g., massive dimensionality jumps).
477
+ - Redundancy or inefficient paths.
478
+ - Logical errors in connectivity.
479
+
480
+ If the architecture is excellent, say "No changes needed". Otherwise, provide a concise list of required improvements.`;
 
 
 
 
481
 
482
  let critique = "";
483
+ let criticResponse;
484
  try {
485
+ criticResponse = await ai.models.generateContent({
486
+ model: model,
487
+ // Shared Context: Pass the architect's response as a previous turn to be token efficient
488
+ contents: [
489
+ { role: 'user', parts: [{ text: `User Request: "${userPrompt}"\nSchema: ${schemaStr}` }] },
490
+ { role: 'model', parts: [{ text: draftJsonStr }] },
491
+ { role: 'user', parts: [{ text: "Please critique this architecture." }] }
492
+ ],
493
+ config: {
494
+ systemInstruction: criticSystemInstruction,
495
+ tools: [{ googleSearch: {} }]
496
+ }
497
  });
498
  // NOTE(review): if text can be undefined (confirm against the SDK typings), trim() throws here
  // and the failure is absorbed by the catch below.
+ critique = criticResponse.text.trim();
499
  } catch (e) {
500
  // Critic failure is non-fatal: a warning is logged and a placeholder critique is used instead.
  console.warn("Critic agent failed, proceeding with draft.");
501
  critique = "No critique available.";
 
504
  // --- Step 3: Refiner ---
505
  onStatusUpdate('refiner', 'Refiner is applying fixes and finalizing...');
506
 
507
+ const refinerSystemInstruction = `You are a Lead Implementation Engineer.
508
+ Finalize the neural network JSON by applying the Critic's feedback to the Architect's draft.
509
+
510
+ Requirements:
511
+ 1. Strictly follow the JSON schema: { "nodes": [...], "edges": [...] }.
512
+ 2. Ensure all node IDs are unique.
513
+ 3. Ensure all LayerTypes match the enum exactly.
514
+ 4. Fix any shape or logic errors identified in the critique.
515
+ 5. Maintain a clean, readable visual layout.
516
+
517
+ Return ONLY the final JSON object.`;
 
 
 
 
 
518
 
519
  try {
520
  const response = await ai.models.generateContent({
521
+ model: model,
522
+ // Shared Context: Pass the entire history for maximum consistency
523
+ contents: [
524
+ { role: 'user', parts: [{ text: `User Request: "${userPrompt}"\nSchema: ${schemaStr}` }] },
525
+ { role: 'model', parts: [{ text: draftJsonStr }] },
526
+ { role: 'user', parts: [{ text: `Critique: "${critique}"\n\nApply the critique and provide the final JSON.` }] }
527
+ ],
528
+ config: {
529
+ systemInstruction: refinerSystemInstruction,
530
+ // REMOVED responseMimeType and responseSchema as they conflict with previous tool use in history
531
+ }
532
  });
533
  // The reply is free text, so it is cleaned and parsed by hand before sanitizeGraph runs on it.
+ const finalJson = JSON.parse(cleanJsonString(response.text));
534
  onStatusUpdate('complete', 'Architecture built successfully!');
535
 
 
536
  return sanitizeGraph(finalJson);
537
  } catch (e) {
538
  // NOTE(review): this catch also covers the generateContent call, so API or network failures
  // are reported as a parse failure and the original error e is dropped.
  throw new Error("Refiner agent failed to parse final JSON.");
 
545
  */
546
  export const generateCodeWithAgents = async (
547
  promptText: string,
548
+ onStatusUpdate: (status: AgentStatus, log: string) => void,
549
+ model: string = DEFAULT_MODEL
550
  ): Promise<string> => {
551
  // Pipeline overview (from the visible code): Coder drafts the code, Reviewer critiques it, and
  // Polisher produces the final code. Progress is reported through onStatusUpdate using the
  // architect, critic and refiner status values.
  const ai = getAiClient();
552
 
553
+ // --- Step 1: Coder ---
554
  onStatusUpdate('architect', 'Coder Agent is writing initial PyTorch implementation...');
555
 
556
+ const coderSystemInstruction = `You are a Senior Deep Learning Engineer.
557
+ Your task is to write clean, modular, and production-ready PyTorch code.
558
+
559
+ Coding Standards:
560
+ 1. Use nn.Module for the main model.
561
+ 2. Include docstrings for the class and forward method.
562
+ 3. Add comments indicating the expected tensor shapes at each major step.
563
+ 4. Use descriptive variable names.
564
+ 5. Include a robust 'if __name__ == "__main__":' block that instantiates the model and runs a dummy forward pass with torch.randn.
565
+ 6. Handle custom layer logic or external imports if specified.
566
+
567
+ Return ONLY the Python code. No markdown backticks.`;
 
568
 
569
  let draftCode = "";
570
  try {
571
  const response = await ai.models.generateContent({
572
+ model: model,
573
+ contents: `Request: "${promptText}"`,
574
+ config: {
575
+ systemInstruction: coderSystemInstruction,
576
+ tools: [{ googleSearch: {} }, { urlContext: {} }]
577
+ }
578
  });
579
  // Trims the reply, then removes python-tagged fence markers and any remaining bare fence markers.
  draftCode = response.text.trim().replace(/```python/g, '').replace(/```/g, '');
580
  } catch(e) {
581
  // NOTE(review): catching only to rethrow is a no-op; a Coder failure aborts the whole pipeline.
  throw e;
582
  }
583
 
584
+ // --- Step 2: Reviewer ---
585
  onStatusUpdate('critic', 'Reviewer Agent is analyzing code for bugs and optimization...');
586
+
587
+ const reviewerSystemInstruction = `You are a Senior Code Reviewer.
588
+ Analyze the provided PyTorch code for:
589
+ - Syntax errors or missing imports.
590
+ - Logical bugs (e.g., wrong dimension in cat/stack).
591
+ - Missing super().__init__() calls.
592
+ - Inefficient implementations.
593
+ - Adherence to the original architecture request.
594
+ You have access to Google Search to verify PyTorch documentation or SOTA implementation details.
595
+
596
+ Provide a concise, technical critique. If the code is perfect, say "No changes needed".`;
597
+
598
  let critique = "";
599
  try {
600
  const response = await ai.models.generateContent({
601
+ model: model,
602
+ contents: [
603
+ { role: 'user', parts: [{ text: `Original Request: "${promptText}"` }] },
604
+ { role: 'model', parts: [{ text: draftCode }] },
605
+ { role: 'user', parts: [{ text: "Please review this code." }] }
606
+ ],
607
+ config: {
608
+ systemInstruction: reviewerSystemInstruction,
609
+ tools: [{ googleSearch: {} }]
610
+ }
611
  });
612
  critique = response.text.trim();
613
  } catch(e) {
614
  // Reviewer failure is non-fatal and is not logged: the Polisher runs with a placeholder critique.
  critique = "No critique available.";
615
  }
616
 
617
+ // --- Step 3: Polisher ---
618
  onStatusUpdate('refiner', 'Polisher Agent is finalizing the codebase...');
619
+
620
+ const polisherSystemInstruction = `You are a Principal Software Engineer.
621
+ Refine the PyTorch code by incorporating the Reviewer's feedback.
622
+
623
+ Goals:
624
+ 1. Fix all bugs and style issues.
625
+ 2. Ensure the code is strictly valid Python.
626
+ 3. Maintain all docstrings and shape comments.
627
+ 4. Ensure the test block works perfectly.
628
+
629
+ Return ONLY the final Python code. No markdown backticks.`;
630
 
631
  try {
632
  const response = await ai.models.generateContent({
633
+ model: model,
634
+ contents: [
635
+ { role: 'user', parts: [{ text: `Original Request: "${promptText}"` }] },
636
+ { role: 'model', parts: [{ text: draftCode }] },
637
+ { role: 'user', parts: [{ text: `Critique: ${critique}\n\nApply the critique and provide the final polished code.` }] }
638
+ ],
639
+ config: { systemInstruction: polisherSystemInstruction }
640
  });
641
  // NOTE(review): the draft above strips python-tagged fences inline, while the final code goes
  // through cleanJsonString (defined outside this view). Confirm that helper also removes
  // python-tagged fences; otherwise the returned code may still contain them.
+ let finalCode = cleanJsonString(response.text);
642
  onStatusUpdate('complete', 'Code generation complete!');
643
  return finalCode;
644
  } catch(e) {
 
653
  nodes: Node<NodeData>[],
654
  edges: Edge[],
655
  errorMsg: string,
656
+ onStatusUpdate: (status: AgentStatus, log: string) => void,
657
+ model: string = DEFAULT_MODEL
658
  ): Promise<{ nodes: any[], edges: any[] } | null> => {
659
  const ai = getAiClient();
660
  const graphJson = JSON.stringify({
 
665
  // --- Step 1: Debugger ---
666
  onStatusUpdate('debugger', 'Debugger Agent is analyzing the error trace...');
667
 
 
 
 
 
 
 
 
 
 
 
668
  let debugAnalysis = "";
669
  try {
670
  const response = await ai.models.generateContent({
671
+ model: model,
672
+ contents: `Graph: ${graphJson}\nError Message: "${errorMsg}"`,
673
+ config: {
674
+ systemInstruction: `You are a Senior Systems Debugger specializing in Deep Learning.
675
+ Analyze the provided graph and error message to identify the root cause of the failure.
676
+ You have access to Google Search to look up specific PyTorch error traces or layer compatibility issues.
677
+
678
+ Focus on:
679
+ - Shape mismatches between connected layers.
680
+ - Incorrect parameter values.
681
+ - Missing required layers (e.g., Flatten before Linear).
682
+
683
+ Output a concise, technical explanation of the bug.`,
684
+ tools: [{ googleSearch: {} }]
685
+ }
686
  });
687
  debugAnalysis = response.text.trim();
688
  } catch (e) {
 
692
  // --- Step 2: Architect ---
693
  onStatusUpdate('architect', 'Architect Agent is planning the fix...');
694
 
 
 
 
 
 
 
 
 
 
 
 
 
 
695
  let fixPlan = "";
696
  try {
697
  const response = await ai.models.generateContent({
698
+ model: model,
699
+ contents: `Issue Analysis: ${debugAnalysis}`,
700
+ config: {
701
+ systemInstruction: `You are a Solution Architect.
702
+ Based on the debugger's analysis, propose a minimal and effective structural fix.
703
+
704
+ Instructions:
705
+ - Specify which nodes to add, remove, or modify.
706
+ - Specify which edges to reconnect.
707
+ - Ensure the fix doesn't introduce new errors.
708
+
709
+ Output the plan in clear, actionable steps.`
710
+ }
711
  });
712
  fixPlan = response.text.trim();
713
  } catch (e) {
714
  fixPlan = "Apply necessary structural corrections.";
715
  }
716
 
717
+ // --- Step 2.5: Critic (New Step) ---
718
+ onStatusUpdate('critic', 'Critic Agent is reviewing the fix plan...');
719
+
720
+ let fixCritique = "";
721
+ try {
722
+ const response = await ai.models.generateContent({
723
+ model: model,
724
+ contents: `Error: "${errorMsg}"\nAnalysis: ${debugAnalysis}\nProposed Fix: ${fixPlan}`,
725
+ config: {
726
+ systemInstruction: `You are a Senior Reviewer.
727
+ Review the proposed fix plan. Does it actually solve the root cause?
728
+ Does it introduce new shape mismatches?
729
+ Is it the most efficient way to fix the error?
730
+ You have access to Google Search to verify SOTA fixes for similar issues.
731
+
732
+ Output a brief critique or "Approved" if perfect.`,
733
+ tools: [{ googleSearch: {} }]
734
+ }
735
+ });
736
+ fixCritique = response.text.trim();
737
+ } catch (e) {
738
+ fixCritique = "Approved";
739
+ }
740
+
741
  // --- Step 3: Patcher ---
742
  onStatusUpdate('patcher', 'Patcher Agent is applying the fix to the graph...');
743
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
744
  try {
745
  const response = await ai.models.generateContent({
746
+ model: model,
747
+ contents: `Current Graph: ${graphJson}\nFix Plan: ${fixPlan}\nCritique: ${fixCritique}`,
748
+ config: {
749
+ systemInstruction: `You are a Senior Patcher.
750
+ Apply the proposed fix plan (considering the critique) to the graph JSON.
751
+
752
+ Requirements:
753
+ 1. Return a valid JSON object with "nodes" and "edges".
754
+ 2. Use ONLY valid LayerTypes.
755
+ 3. Maintain existing node IDs and positions where possible.
756
+ 4. Ensure all new nodes have unique IDs.
757
+
758
+ Return ONLY the final JSON object.`,
759
+ // responseMimeType and responseSchema are intentionally not set; the reply is parsed from plain text via cleanJsonString and JSON.parse below.
760
+ }
761
  });
762
+ const finalJson = JSON.parse(cleanJsonString(response.text));
763
  onStatusUpdate('complete', 'Fix applied successfully!');
764
  return sanitizeGraph(finalJson);
765
  } catch (e) {
766
  throw new Error("Patcher agent failed to generate valid JSON.");
767
  }
768
+ };