Spaces:
Running
Running
Update services/geminiService.ts
Browse files- services/geminiService.ts +404 -276
services/geminiService.ts
CHANGED
|
@@ -1,8 +1,67 @@
|
|
| 1 |
-
import { GoogleGenAI } from "@google/genai";
|
| 2 |
import { Node, Edge } from 'reactflow';
|
| 3 |
import { NodeData, LayerType } from '../types';
|
| 4 |
import { LAYER_DEFINITIONS } from '../constants';
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
// Key Management
|
| 7 |
const getEnvKey = () => {
|
| 8 |
// In Development: Returns key from .env file if available
|
|
@@ -33,7 +92,7 @@ const getAiClient = () => {
|
|
| 33 |
return new GoogleGenAI({ apiKey: key });
|
| 34 |
};
|
| 35 |
|
| 36 |
-
const
|
| 37 |
|
| 38 |
export type AgentStatus = 'idle' | 'architect' | 'critic' | 'refiner' | 'debugger' | 'patcher' | 'complete' | 'error';
|
| 39 |
|
|
@@ -100,46 +159,44 @@ const buildRawPrompt = (nodes: Node<NodeData>[], edges: Edge[]): string => {
|
|
| 100 |
* Generates a polished, professional prompt using the AI.
|
| 101 |
* It takes the raw hardcoded spec and asks the AI to format it perfectly for a coding LLM.
|
| 102 |
*/
|
| 103 |
-
export const generateRefinedPrompt = async (nodes: Node<NodeData>[], edges: Edge[]): Promise<string> => {
|
| 104 |
const ai = getAiClient();
|
| 105 |
const rawSpec = buildRawPrompt(nodes, edges);
|
| 106 |
|
| 107 |
-
const systemPrompt = `
|
| 108 |
-
You are an expert AI Prompt Engineer for Deep Learning.
|
| 109 |
-
Your goal is to take a raw, technical neural network specification and rewrite it into a
|
| 110 |
-
perfect, professional, and detailed prompt that another AI (like a coding assistant) could use to write flawless PyTorch code.
|
| 111 |
-
|
| 112 |
-
Input Raw Specification:
|
| 113 |
-
${rawSpec}
|
| 114 |
-
|
| 115 |
-
Instructions:
|
| 116 |
-
1. Start the output with: "You are an expert Deep Learning Engineer. Please write a complete, runnable PyTorch model code for the following neural network architecture:"
|
| 117 |
-
2. Create a section "Architecture Specification".
|
| 118 |
-
3. List "1. Layers (Nodes)" cleanly. Include ID, Type, and Parameters.
|
| 119 |
-
4. List "2. Connectivity (Forward Pass Flow)" cleanly.
|
| 120 |
-
- Explicitly describe merge points (e.g. "Node X receives inputs from A and B. Handle this merge...").
|
| 121 |
-
- Note specific handling for complex layers like CrossAttention (needs Query + Key/Value) or SAM Decoders.
|
| 122 |
-
5. Create a section "Implementation Requirements" with standard PyTorch best practices (nn.Module, forward method, correct shapes).
|
| 123 |
-
6. If CUSTOM_CODE_DEFINITION or CUSTOM_IMPORTS are present, explicitly instruct the coder to include them verbatim or use them as reference.
|
| 124 |
-
7. Do NOT write the Python code yourself. Write the PROMPT that asks for the code.
|
| 125 |
-
8. Ensure the tone is technical and precise.
|
| 126 |
-
|
| 127 |
-
Return ONLY the generated prompt text.
|
| 128 |
-
`;
|
| 129 |
-
|
| 130 |
try {
|
| 131 |
const response = await ai.models.generateContent({
|
| 132 |
-
model:
|
| 133 |
-
contents:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
});
|
| 135 |
return response.text.trim();
|
| 136 |
} catch (error) {
|
| 137 |
console.error("Prompt refinement failed:", error);
|
| 138 |
-
throw error;
|
| 139 |
}
|
| 140 |
};
|
| 141 |
|
| 142 |
-
export const validateArchitecture = async (nodes: Node<NodeData>[], edges: Edge[]): Promise<string> => {
|
| 143 |
const ai = getAiClient();
|
| 144 |
const graphRepresentation = {
|
| 145 |
nodes: nodes.map(n => ({
|
|
@@ -153,24 +210,24 @@ export const validateArchitecture = async (nodes: Node<NodeData>[], edges: Edge[
|
|
| 153 |
}))
|
| 154 |
};
|
| 155 |
|
| 156 |
-
const prompt = `
|
| 157 |
-
Analyze this neural network architecture graph for validity.
|
| 158 |
-
Graph: ${JSON.stringify(graphRepresentation)}
|
| 159 |
-
|
| 160 |
-
Check for:
|
| 161 |
-
1. Shape mismatches (e.g., Conv2D output to Linear without Flatten).
|
| 162 |
-
2. Disconnected components.
|
| 163 |
-
3. Logical errors (e.g., MaxPool after Output).
|
| 164 |
-
4. Merge layer correctness (Concat/Add needs multiple inputs).
|
| 165 |
-
5. GenAI correctness (e.g., CrossAttention needs 2 inputs, VLM projection dims match).
|
| 166 |
-
|
| 167 |
-
Return a concise report. If valid, say "Architecture is valid.". If invalid, list specific errors and suggest fixes.
|
| 168 |
-
`;
|
| 169 |
-
|
| 170 |
try {
|
| 171 |
const response = await ai.models.generateContent({
|
| 172 |
-
model:
|
| 173 |
-
contents:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
});
|
| 175 |
return response.text;
|
| 176 |
} catch (error) {
|
|
@@ -181,7 +238,7 @@ export const validateArchitecture = async (nodes: Node<NodeData>[], edges: Edge[
|
|
| 181 |
/**
|
| 182 |
* Gets AI suggestions for improving the architecture.
|
| 183 |
*/
|
| 184 |
-
export const getArchitectureSuggestions = async (nodes: Node<NodeData>[], edges: Edge[]): Promise<string> => {
|
| 185 |
const ai = getAiClient();
|
| 186 |
const graphRepresentation = {
|
| 187 |
nodes: nodes.map(n => ({
|
|
@@ -196,26 +253,24 @@ export const getArchitectureSuggestions = async (nodes: Node<NodeData>[], edges:
|
|
| 196 |
}))
|
| 197 |
};
|
| 198 |
|
| 199 |
-
const prompt = `
|
| 200 |
-
You are a Senior Deep Learning Architect. Review the following neural network architecture graph.
|
| 201 |
-
|
| 202 |
-
Graph Structure:
|
| 203 |
-
${JSON.stringify(graphRepresentation, null, 2)}
|
| 204 |
-
|
| 205 |
-
Task: Provide 3 to 5 concrete, actionable suggestions to improve this model.
|
| 206 |
-
Focus on:
|
| 207 |
-
- Modern best practices (e.g., using LayerNorm vs BatchNorm in Transformers, SwiGLU vs ReLU).
|
| 208 |
-
- Architecture efficiency and parameter count optimization.
|
| 209 |
-
- Potential bottlenecks or vanishing gradient risks.
|
| 210 |
-
- Adding residuals or skip connections if the model is deep.
|
| 211 |
-
|
| 212 |
-
Format the output as a clean bulleted list. Keep it concise, professional and helpful.
|
| 213 |
-
`;
|
| 214 |
-
|
| 215 |
try {
|
| 216 |
const response = await ai.models.generateContent({
|
| 217 |
-
model:
|
| 218 |
-
contents:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
});
|
| 220 |
return response.text;
|
| 221 |
} catch (error) {
|
|
@@ -229,7 +284,8 @@ export const getArchitectureSuggestions = async (nodes: Node<NodeData>[], edges:
|
|
| 229 |
export const implementArchitectureSuggestions = async (
|
| 230 |
nodes: Node<NodeData>[],
|
| 231 |
edges: Edge[],
|
| 232 |
-
suggestions: string
|
|
|
|
| 233 |
): Promise<{ nodes: any[], edges: any[] }> => {
|
| 234 |
const ai = getAiClient();
|
| 235 |
|
|
@@ -247,29 +303,23 @@ export const implementArchitectureSuggestions = async (
|
|
| 247 |
}))
|
| 248 |
};
|
| 249 |
|
| 250 |
-
const prompt = `
|
| 251 |
-
You are a Senior Implementation Engineer.
|
| 252 |
-
Task: Apply the following architectural suggestions to the provided graph JSON.
|
| 253 |
-
|
| 254 |
-
Current Graph:
|
| 255 |
-
${JSON.stringify(graphRepresentation)}
|
| 256 |
-
|
| 257 |
-
Suggestions to Implement:
|
| 258 |
-
"${suggestions}"
|
| 259 |
-
|
| 260 |
-
Instructions:
|
| 261 |
-
1. Modify the nodes and edges to incorporate the suggestions.
|
| 262 |
-
2. Maintain the layout (x, y positions) as best as possible, offsetting new nodes if added.
|
| 263 |
-
3. Ensure all LayerTypes are valid from the standard schema.
|
| 264 |
-
4. Return the complete, updated JSON with "nodes" and "edges" arrays.
|
| 265 |
-
5. Return ONLY raw JSON.
|
| 266 |
-
`;
|
| 267 |
-
|
| 268 |
try {
|
| 269 |
const response = await ai.models.generateContent({
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
});
|
| 274 |
const result = JSON.parse(response.text.trim());
|
| 275 |
return sanitizeGraph(result);
|
|
@@ -284,47 +334,72 @@ export const implementArchitectureSuggestions = async (
|
|
| 284 |
* Checks if 'type' is valid. If not, converts it to a CUSTOM layer to prevent crashes.
|
| 285 |
*/
|
| 286 |
const sanitizeGraph = (graphJson: { nodes: any[], edges: any[] }) => {
|
| 287 |
-
if (!graphJson
|
| 288 |
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
return {
|
| 302 |
...node,
|
| 303 |
-
type: 'custom',
|
| 304 |
data: {
|
| 305 |
...node.data,
|
| 306 |
-
type:
|
| 307 |
-
|
| 308 |
-
params: {
|
| 309 |
-
...(node.data?.params || {}),
|
| 310 |
-
class_name: rawType, // Store original type name here
|
| 311 |
-
args: JSON.stringify(node.data?.params || {}).slice(0, 100) // Rough preserve of args
|
| 312 |
-
}
|
| 313 |
}
|
| 314 |
};
|
| 315 |
-
}
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
}
|
| 326 |
-
|
| 327 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
|
| 329 |
return graphJson;
|
| 330 |
};
|
|
@@ -339,7 +414,8 @@ const sanitizeGraph = (graphJson: { nodes: any[], edges: any[] }) => {
|
|
| 339 |
export const generateGraphWithAgents = async (
|
| 340 |
userPrompt: string,
|
| 341 |
currentNodes: Node<NodeData>[] = [],
|
| 342 |
-
onStatusUpdate: (status: AgentStatus, log: string) => void
|
|
|
|
| 343 |
): Promise<{ nodes: any[], edges: any[] } | null> => {
|
| 344 |
const ai = getAiClient();
|
| 345 |
|
|
@@ -356,34 +432,34 @@ export const generateGraphWithAgents = async (
|
|
| 356 |
? `Current Graph Context: ${JSON.stringify(currentNodes.map(n => ({ id: n.id, type: n.data.type, label: n.data.label })))}`
|
| 357 |
: "Starting from scratch.";
|
| 358 |
|
| 359 |
-
const
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
5. Connect edges logically.
|
| 374 |
-
6. If multi-input/output, arrange horizontally.
|
| 375 |
-
|
| 376 |
-
Return ONLY raw JSON.
|
| 377 |
-
`;
|
| 378 |
|
| 379 |
let draftJsonStr = "";
|
|
|
|
| 380 |
try {
|
| 381 |
-
|
| 382 |
-
model:
|
| 383 |
-
contents:
|
| 384 |
-
config: {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
});
|
| 386 |
-
draftJsonStr =
|
| 387 |
} catch (e) {
|
| 388 |
throw e;
|
| 389 |
}
|
|
@@ -391,29 +467,35 @@ export const generateGraphWithAgents = async (
|
|
| 391 |
// --- Step 2: Critic ---
|
| 392 |
onStatusUpdate('critic', 'Critic is reviewing architecture for flaws...');
|
| 393 |
|
| 394 |
-
const
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
- Compliance with user request
|
| 406 |
-
|
| 407 |
-
Output a concise paragraph describing specific improvements needed. If perfect, say "No changes needed".
|
| 408 |
-
`;
|
| 409 |
|
| 410 |
let critique = "";
|
|
|
|
| 411 |
try {
|
| 412 |
-
|
| 413 |
-
model:
|
| 414 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 415 |
});
|
| 416 |
-
critique =
|
| 417 |
} catch (e) {
|
| 418 |
console.warn("Critic agent failed, proceeding with draft.");
|
| 419 |
critique = "No critique available.";
|
|
@@ -422,33 +504,35 @@ export const generateGraphWithAgents = async (
|
|
| 422 |
// --- Step 3: Refiner ---
|
| 423 |
onStatusUpdate('refiner', 'Refiner is applying fixes and finalizing...');
|
| 424 |
|
| 425 |
-
const
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
4. Ensure parameter values match the schema types.
|
| 437 |
-
5. Ensure "type" in top level node object is always 'custom'.
|
| 438 |
-
|
| 439 |
-
Return ONLY the final JSON.
|
| 440 |
-
`;
|
| 441 |
|
| 442 |
try {
|
| 443 |
const response = await ai.models.generateContent({
|
| 444 |
-
model:
|
| 445 |
-
|
| 446 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 447 |
});
|
| 448 |
-
const finalJson = JSON.parse(response.text
|
| 449 |
onStatusUpdate('complete', 'Architecture built successfully!');
|
| 450 |
|
| 451 |
-
// SANITIZE: Prevent UI crashes by handling hallucinated types
|
| 452 |
return sanitizeGraph(finalJson);
|
| 453 |
} catch (e) {
|
| 454 |
throw new Error("Refiner agent failed to parse final JSON.");
|
|
@@ -461,80 +545,100 @@ export const generateGraphWithAgents = async (
|
|
| 461 |
*/
|
| 462 |
export const generateCodeWithAgents = async (
|
| 463 |
promptText: string,
|
| 464 |
-
onStatusUpdate: (status: AgentStatus, log: string) => void
|
|
|
|
| 465 |
): Promise<string> => {
|
| 466 |
const ai = getAiClient();
|
| 467 |
|
| 468 |
-
// --- Step 1: Coder
|
| 469 |
onStatusUpdate('architect', 'Coder Agent is writing initial PyTorch implementation...');
|
| 470 |
|
| 471 |
-
const
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
`;
|
| 484 |
|
| 485 |
let draftCode = "";
|
| 486 |
try {
|
| 487 |
const response = await ai.models.generateContent({
|
| 488 |
-
model:
|
| 489 |
-
contents:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 490 |
});
|
| 491 |
draftCode = response.text.trim().replace(/```python/g, '').replace(/```/g, '');
|
| 492 |
} catch(e) {
|
| 493 |
throw e;
|
| 494 |
}
|
| 495 |
|
| 496 |
-
// --- Step 2: Reviewer
|
| 497 |
onStatusUpdate('critic', 'Reviewer Agent is analyzing code for bugs and optimization...');
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
|
|
|
| 509 |
let critique = "";
|
| 510 |
try {
|
| 511 |
const response = await ai.models.generateContent({
|
| 512 |
-
model:
|
| 513 |
-
contents:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 514 |
});
|
| 515 |
critique = response.text.trim();
|
| 516 |
} catch(e) {
|
| 517 |
critique = "No critique available.";
|
| 518 |
}
|
| 519 |
|
| 520 |
-
// --- Step 3: Polisher
|
| 521 |
onStatusUpdate('refiner', 'Polisher Agent is finalizing the codebase...');
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
|
|
|
|
|
|
| 531 |
|
| 532 |
try {
|
| 533 |
const response = await ai.models.generateContent({
|
| 534 |
-
model:
|
| 535 |
-
contents:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 536 |
});
|
| 537 |
-
let finalCode = response.text
|
| 538 |
onStatusUpdate('complete', 'Code generation complete!');
|
| 539 |
return finalCode;
|
| 540 |
} catch(e) {
|
|
@@ -549,7 +653,8 @@ export const fixArchitectureErrors = async (
|
|
| 549 |
nodes: Node<NodeData>[],
|
| 550 |
edges: Edge[],
|
| 551 |
errorMsg: string,
|
| 552 |
-
onStatusUpdate: (status: AgentStatus, log: string) => void
|
|
|
|
| 553 |
): Promise<{ nodes: any[], edges: any[] } | null> => {
|
| 554 |
const ai = getAiClient();
|
| 555 |
const graphJson = JSON.stringify({
|
|
@@ -560,21 +665,24 @@ export const fixArchitectureErrors = async (
|
|
| 560 |
// --- Step 1: Debugger ---
|
| 561 |
onStatusUpdate('debugger', 'Debugger Agent is analyzing the error trace...');
|
| 562 |
|
| 563 |
-
const debuggerPrompt = `
|
| 564 |
-
Role: Senior Systems Debugger.
|
| 565 |
-
Task: Analyze the architecture graph and the reported error to pinpoint the root cause.
|
| 566 |
-
|
| 567 |
-
Graph: ${graphJson}
|
| 568 |
-
Error Message: "${errorMsg}"
|
| 569 |
-
|
| 570 |
-
Output a technical analysis of exactly what is wrong (e.g. "Node A connects to Node B but shapes [X] and [Y] are incompatible").
|
| 571 |
-
`;
|
| 572 |
-
|
| 573 |
let debugAnalysis = "";
|
| 574 |
try {
|
| 575 |
const response = await ai.models.generateContent({
|
| 576 |
-
model:
|
| 577 |
-
contents:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 578 |
});
|
| 579 |
debugAnalysis = response.text.trim();
|
| 580 |
} catch (e) {
|
|
@@ -584,57 +692,77 @@ export const fixArchitectureErrors = async (
|
|
| 584 |
// --- Step 2: Architect ---
|
| 585 |
onStatusUpdate('architect', 'Architect Agent is planning the fix...');
|
| 586 |
|
| 587 |
-
const architectPrompt = `
|
| 588 |
-
Role: Solution Architect.
|
| 589 |
-
Task: Propose a specific fix for the identified issue.
|
| 590 |
-
|
| 591 |
-
Issue Analysis: ${debugAnalysis}
|
| 592 |
-
|
| 593 |
-
Instructions:
|
| 594 |
-
- Determine if nodes need to be added (e.g. Flatten, Reshape), removed, or reconnected.
|
| 595 |
-
- Determine if parameters need changing.
|
| 596 |
-
|
| 597 |
-
Output the plan in clear steps.
|
| 598 |
-
`;
|
| 599 |
-
|
| 600 |
let fixPlan = "";
|
| 601 |
try {
|
| 602 |
const response = await ai.models.generateContent({
|
| 603 |
-
model:
|
| 604 |
-
contents:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 605 |
});
|
| 606 |
fixPlan = response.text.trim();
|
| 607 |
} catch (e) {
|
| 608 |
fixPlan = "Apply necessary structural corrections.";
|
| 609 |
}
|
| 610 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 611 |
// --- Step 3: Patcher ---
|
| 612 |
onStatusUpdate('patcher', 'Patcher Agent is applying the fix to the graph...');
|
| 613 |
|
| 614 |
-
const patcherPrompt = `
|
| 615 |
-
Role: DevOps Engineer.
|
| 616 |
-
Task: Apply the fix to the graph JSON.
|
| 617 |
-
|
| 618 |
-
Current Graph: ${graphJson}
|
| 619 |
-
Fix Plan: ${fixPlan}
|
| 620 |
-
|
| 621 |
-
Requirements:
|
| 622 |
-
1. Return the complete, valid JSON with "nodes" and "edges".
|
| 623 |
-
2. Maintain existing node positions where possible, offset new nodes if added.
|
| 624 |
-
3. Ensure all LayerTypes are valid.
|
| 625 |
-
4. Return ONLY raw JSON.
|
| 626 |
-
`;
|
| 627 |
-
|
| 628 |
try {
|
| 629 |
const response = await ai.models.generateContent({
|
| 630 |
-
model:
|
| 631 |
-
contents:
|
| 632 |
-
config: {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 633 |
});
|
| 634 |
-
const finalJson = JSON.parse(response.text
|
| 635 |
onStatusUpdate('complete', 'Fix applied successfully!');
|
| 636 |
return sanitizeGraph(finalJson);
|
| 637 |
} catch (e) {
|
| 638 |
throw new Error("Patcher agent failed to generate valid JSON.");
|
| 639 |
}
|
| 640 |
-
};
|
|
|
|
| 1 |
+
import { GoogleGenAI, Type } from "@google/genai";
|
| 2 |
import { Node, Edge } from 'reactflow';
|
| 3 |
import { NodeData, LayerType } from '../types';
|
| 4 |
import { LAYER_DEFINITIONS } from '../constants';
|
| 5 |
|
| 6 |
+
// --- SCHEMAS ---
|
| 7 |
+
const graphResponseSchema = {
|
| 8 |
+
type: Type.OBJECT,
|
| 9 |
+
properties: {
|
| 10 |
+
reasoning: {
|
| 11 |
+
type: Type.STRING,
|
| 12 |
+
description: "Brief explanation of architectural choices or fixes applied."
|
| 13 |
+
},
|
| 14 |
+
nodes: {
|
| 15 |
+
type: Type.ARRAY,
|
| 16 |
+
items: {
|
| 17 |
+
type: Type.OBJECT,
|
| 18 |
+
properties: {
|
| 19 |
+
id: { type: Type.STRING },
|
| 20 |
+
type: { type: Type.STRING, description: "Must be 'custom' for ReactFlow compatibility." },
|
| 21 |
+
position: {
|
| 22 |
+
type: Type.OBJECT,
|
| 23 |
+
properties: {
|
| 24 |
+
x: { type: Type.NUMBER },
|
| 25 |
+
y: { type: Type.NUMBER }
|
| 26 |
+
},
|
| 27 |
+
required: ["x", "y"]
|
| 28 |
+
},
|
| 29 |
+
data: {
|
| 30 |
+
type: Type.OBJECT,
|
| 31 |
+
properties: {
|
| 32 |
+
type: { type: Type.STRING, description: "The LayerType enum value." },
|
| 33 |
+
label: { type: Type.STRING },
|
| 34 |
+
params: { type: Type.OBJECT, description: "Hyperparameters for the layer." }
|
| 35 |
+
},
|
| 36 |
+
required: ["type", "label"]
|
| 37 |
+
}
|
| 38 |
+
},
|
| 39 |
+
required: ["id", "type", "position", "data"]
|
| 40 |
+
}
|
| 41 |
+
},
|
| 42 |
+
edges: {
|
| 43 |
+
type: Type.ARRAY,
|
| 44 |
+
items: {
|
| 45 |
+
type: Type.OBJECT,
|
| 46 |
+
properties: {
|
| 47 |
+
id: { type: Type.STRING },
|
| 48 |
+
source: { type: Type.STRING },
|
| 49 |
+
target: { type: Type.STRING }
|
| 50 |
+
},
|
| 51 |
+
required: ["source", "target"]
|
| 52 |
+
}
|
| 53 |
+
}
|
| 54 |
+
},
|
| 55 |
+
required: ["nodes", "edges"]
|
| 56 |
+
};
|
| 57 |
+
|
| 58 |
+
/**
|
| 59 |
+
* Helper: Cleans a JSON string that might be wrapped in markdown backticks.
|
| 60 |
+
*/
|
| 61 |
+
const cleanJsonString = (str: string): string => {
|
| 62 |
+
return str.replace(/```json/g, '').replace(/```/g, '').trim();
|
| 63 |
+
};
|
| 64 |
+
|
| 65 |
// Key Management
|
| 66 |
const getEnvKey = () => {
|
| 67 |
// In Development: Returns key from .env file if available
|
|
|
|
| 92 |
return new GoogleGenAI({ apiKey: key });
|
| 93 |
};
|
| 94 |
|
| 95 |
+
export const DEFAULT_MODEL = 'gemini-2.5-flash';
|
| 96 |
|
| 97 |
export type AgentStatus = 'idle' | 'architect' | 'critic' | 'refiner' | 'debugger' | 'patcher' | 'complete' | 'error';
|
| 98 |
|
|
|
|
| 159 |
* Generates a polished, professional prompt using the AI.
|
| 160 |
* It takes the raw hardcoded spec and asks the AI to format it perfectly for a coding LLM.
|
| 161 |
*/
|
| 162 |
+
export const generateRefinedPrompt = async (nodes: Node<NodeData>[], edges: Edge[], model: string = DEFAULT_MODEL): Promise<string> => {
|
| 163 |
const ai = getAiClient();
|
| 164 |
const rawSpec = buildRawPrompt(nodes, edges);
|
| 165 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
try {
|
| 167 |
const response = await ai.models.generateContent({
|
| 168 |
+
model: model,
|
| 169 |
+
contents: `Input Raw Specification:\n${rawSpec}`,
|
| 170 |
+
config: {
|
| 171 |
+
systemInstruction: `You are an expert AI Prompt Engineer specializing in Deep Learning Architecture.
|
| 172 |
+
Your goal is to transform a raw graph specification into a high-fidelity, professional prompt for a coding LLM.
|
| 173 |
+
You have access to Google Search to look up the latest PyTorch best practices or SOTA implementation details to include in the prompt.
|
| 174 |
+
|
| 175 |
+
Instructions:
|
| 176 |
+
1. Start with: "You are a world-class Deep Learning Engineer. Implement the following PyTorch model with precision:"
|
| 177 |
+
2. Section "Architecture Details": List all layers with their IDs, Types, and specific hyperparameters.
|
| 178 |
+
3. Section "Data Flow & Connectivity": Describe the forward pass step-by-step.
|
| 179 |
+
- Explicitly mention skip connections, residual additions, and concatenation points.
|
| 180 |
+
- For multi-input layers (like CrossAttention or Add), specify exactly which node IDs provide the inputs.
|
| 181 |
+
4. Section "Implementation Requirements":
|
| 182 |
+
- Use idiomatic PyTorch (nn.Module).
|
| 183 |
+
- Ensure input/output shapes are documented in comments.
|
| 184 |
+
- Include proper initialization (e.g., Kaiming or Xavier) where appropriate.
|
| 185 |
+
- Handle potential shape mismatches (e.g., adding a Flatten layer before Linear if needed).
|
| 186 |
+
5. If custom code is provided in the spec, ensure it is integrated correctly.
|
| 187 |
+
6. Do NOT write the code yourself. Write the PROMPT that will guide another AI to write the code.
|
| 188 |
+
7. Maintain a highly technical, rigorous, and clear tone.`,
|
| 189 |
+
tools: [{ googleSearch: {} }]
|
| 190 |
+
}
|
| 191 |
});
|
| 192 |
return response.text.trim();
|
| 193 |
} catch (error) {
|
| 194 |
console.error("Prompt refinement failed:", error);
|
| 195 |
+
throw error;
|
| 196 |
}
|
| 197 |
};
|
| 198 |
|
| 199 |
+
export const validateArchitecture = async (nodes: Node<NodeData>[], edges: Edge[], model: string = DEFAULT_MODEL): Promise<string> => {
|
| 200 |
const ai = getAiClient();
|
| 201 |
const graphRepresentation = {
|
| 202 |
nodes: nodes.map(n => ({
|
|
|
|
| 210 |
}))
|
| 211 |
};
|
| 212 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
try {
|
| 214 |
const response = await ai.models.generateContent({
|
| 215 |
+
model: model,
|
| 216 |
+
contents: `Graph to Analyze: ${JSON.stringify(graphRepresentation)}`,
|
| 217 |
+
config: {
|
| 218 |
+
systemInstruction: `You are a Senior Deep Learning Validator. Your role is to find structural and logical flaws in neural network architectures.
|
| 219 |
+
|
| 220 |
+
Checklist:
|
| 221 |
+
1. Dimensional Consistency: Do Conv2D/Conv3D layers have appropriate pooling or flattening before Linear layers?
|
| 222 |
+
2. Connectivity: Are there any orphaned nodes? Is there a path from Input to Output?
|
| 223 |
+
3. Layer Logic: Are activations placed correctly? (e.g., no ReLU after a Softmax).
|
| 224 |
+
4. Merge Operations: Do Add/Concat layers have at least 2 inputs? Do they have compatible shapes?
|
| 225 |
+
5. GenAI Patterns: Does Attention have Query/Key/Value paths? Do VLM projections match the LLM backbone dimensions?
|
| 226 |
+
|
| 227 |
+
Output Format:
|
| 228 |
+
- If valid: "Architecture is valid."
|
| 229 |
+
- If invalid: A bulleted list of "CRITICAL ERRORS" and "SUGGESTED FIXES".`
|
| 230 |
+
}
|
| 231 |
});
|
| 232 |
return response.text;
|
| 233 |
} catch (error) {
|
|
|
|
| 238 |
/**
|
| 239 |
* Gets AI suggestions for improving the architecture.
|
| 240 |
*/
|
| 241 |
+
export const getArchitectureSuggestions = async (nodes: Node<NodeData>[], edges: Edge[], model: string = DEFAULT_MODEL): Promise<string> => {
|
| 242 |
const ai = getAiClient();
|
| 243 |
const graphRepresentation = {
|
| 244 |
nodes: nodes.map(n => ({
|
|
|
|
| 253 |
}))
|
| 254 |
};
|
| 255 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
try {
|
| 257 |
const response = await ai.models.generateContent({
|
| 258 |
+
model: model,
|
| 259 |
+
contents: `Current Graph Architecture: ${JSON.stringify(graphRepresentation, null, 2)}`,
|
| 260 |
+
config: {
|
| 261 |
+
systemInstruction: `You are a Senior Deep Learning Architect. Your goal is to provide elite-level optimization suggestions for neural network architectures.
|
| 262 |
+
You have access to Google Search to research the latest SOTA (State of the Art) components and research papers.
|
| 263 |
+
|
| 264 |
+
Focus Areas:
|
| 265 |
+
1. Efficiency: Reduce parameter count or computational complexity without sacrificing accuracy.
|
| 266 |
+
2. Modernity: Suggest state-of-the-art components (e.g., FlashAttention, SwiGLU, RMSNorm).
|
| 267 |
+
3. Robustness: Identify risks of vanishing/exploding gradients and suggest residuals or normalization.
|
| 268 |
+
4. Scalability: Suggest ways to make the model more modular or scalable.
|
| 269 |
+
|
| 270 |
+
Format:
|
| 271 |
+
Provide 3-5 concise, bulleted suggestions. Each suggestion should include a "Why" (the technical benefit).`,
|
| 272 |
+
tools: [{ googleSearch: {} }]
|
| 273 |
+
}
|
| 274 |
});
|
| 275 |
return response.text;
|
| 276 |
} catch (error) {
|
|
|
|
| 284 |
export const implementArchitectureSuggestions = async (
|
| 285 |
nodes: Node<NodeData>[],
|
| 286 |
edges: Edge[],
|
| 287 |
+
suggestions: string,
|
| 288 |
+
model: string = DEFAULT_MODEL
|
| 289 |
): Promise<{ nodes: any[], edges: any[] }> => {
|
| 290 |
const ai = getAiClient();
|
| 291 |
|
|
|
|
| 303 |
}))
|
| 304 |
};
|
| 305 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
try {
|
| 307 |
const response = await ai.models.generateContent({
|
| 308 |
+
model: model,
|
| 309 |
+
contents: `Current Graph Architecture: ${JSON.stringify(graphRepresentation)}\nSuggestions to Implement: "${suggestions}"`,
|
| 310 |
+
config: {
|
| 311 |
+
systemInstruction: `You are a Senior Implementation Engineer specializing in Deep Learning.
|
| 312 |
+
Your task is to modify the provided neural network graph JSON to incorporate specific architectural improvements.
|
| 313 |
+
|
| 314 |
+
Rules:
|
| 315 |
+
1. Maintain the relative layout (x, y positions). If adding nodes, place them logically between existing ones.
|
| 316 |
+
2. Use ONLY valid LayerTypes from the schema.
|
| 317 |
+
3. Ensure all new nodes have unique IDs.
|
| 318 |
+
4. Ensure the resulting graph is fully connected and logically sound.
|
| 319 |
+
5. Preserve existing node parameters unless the suggestion explicitly requires changing them.`,
|
| 320 |
+
responseMimeType: "application/json",
|
| 321 |
+
responseSchema: graphResponseSchema
|
| 322 |
+
}
|
| 323 |
});
|
| 324 |
const result = JSON.parse(response.text.trim());
|
| 325 |
return sanitizeGraph(result);
|
|
|
|
| 334 |
* Checks if 'type' is valid. If not, converts it to a CUSTOM layer to prevent crashes.
|
| 335 |
*/
|
| 336 |
const sanitizeGraph = (graphJson: { nodes: any[], edges: any[] }) => {
|
| 337 |
+
if (!graphJson) return graphJson;
|
| 338 |
|
| 339 |
+
if (graphJson.nodes) {
|
| 340 |
+
graphJson.nodes = graphJson.nodes.map(node => {
|
| 341 |
+
// AI might return type in data.type or top-level type
|
| 342 |
+
let rawType = node.data?.type || node.type || 'Identity';
|
| 343 |
+
|
| 344 |
+
// Ensure rawType is a string
|
| 345 |
+
if (typeof rawType !== 'string') rawType = 'Identity';
|
| 346 |
+
|
| 347 |
+
// Check if this type exists in our known definitions
|
| 348 |
+
const isValid = Object.values(LayerType).includes(rawType as LayerType) && LAYER_DEFINITIONS[rawType as LayerType];
|
| 349 |
+
|
| 350 |
+
if (!isValid) {
|
| 351 |
+
console.warn(`Sanitizing unknown layer type: ${rawType}. Converting to CustomLayer.`);
|
| 352 |
+
return {
|
| 353 |
+
...node,
|
| 354 |
+
type: 'custom', // ReactFlow type
|
| 355 |
+
data: {
|
| 356 |
+
...node.data,
|
| 357 |
+
type: LayerType.CUSTOM,
|
| 358 |
+
label: node.data?.label || rawType,
|
| 359 |
+
params: {
|
| 360 |
+
...(node.data?.params || {}),
|
| 361 |
+
class_name: rawType, // Store original type name here
|
| 362 |
+
args: JSON.stringify(node.data?.params || {}).slice(0, 100) // Rough preserve of args
|
| 363 |
+
}
|
| 364 |
+
}
|
| 365 |
+
};
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
// Ensure standard structure for valid nodes
|
| 369 |
return {
|
| 370 |
...node,
|
| 371 |
+
type: 'custom',
|
| 372 |
data: {
|
| 373 |
...node.data,
|
| 374 |
+
type: rawType,
|
| 375 |
+
params: node.data?.params || {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
}
|
| 377 |
};
|
| 378 |
+
});
|
| 379 |
+
}
|
| 380 |
+
|
| 381 |
+
if (graphJson.edges) {
|
| 382 |
+
const seenIds = new Set<string>();
|
| 383 |
+
graphJson.edges = graphJson.edges.map((edge, idx) => {
|
| 384 |
+
// AI might return edge without ID or with duplicate ID
|
| 385 |
+
let baseId = edge.id || `e-${edge.source}-${edge.target}`;
|
| 386 |
+
|
| 387 |
+
// Ensure uniqueness within this graph
|
| 388 |
+
let uniqueId = baseId;
|
| 389 |
+
let counter = 1;
|
| 390 |
+
while (seenIds.has(uniqueId)) {
|
| 391 |
+
uniqueId = `${baseId}-${idx}-${counter++}`;
|
| 392 |
}
|
| 393 |
+
seenIds.add(uniqueId);
|
| 394 |
+
|
| 395 |
+
return {
|
| 396 |
+
...edge,
|
| 397 |
+
id: uniqueId,
|
| 398 |
+
animated: true,
|
| 399 |
+
style: { stroke: '#94a3b8' }
|
| 400 |
+
};
|
| 401 |
+
});
|
| 402 |
+
}
|
| 403 |
|
| 404 |
return graphJson;
|
| 405 |
};
|
|
|
|
| 414 |
export const generateGraphWithAgents = async (
|
| 415 |
userPrompt: string,
|
| 416 |
currentNodes: Node<NodeData>[] = [],
|
| 417 |
+
onStatusUpdate: (status: AgentStatus, log: string) => void,
|
| 418 |
+
model: string = DEFAULT_MODEL
|
| 419 |
): Promise<{ nodes: any[], edges: any[] } | null> => {
|
| 420 |
const ai = getAiClient();
|
| 421 |
|
|
|
|
| 432 |
? `Current Graph Context: ${JSON.stringify(currentNodes.map(n => ({ id: n.id, type: n.data.type, label: n.data.label })))}`
|
| 433 |
: "Starting from scratch.";
|
| 434 |
|
| 435 |
+
const architectSystemInstruction = `You are a Senior Neural Network Architect.
|
| 436 |
+
Your task is to design a high-quality neural network graph based on user requirements.
|
| 437 |
+
|
| 438 |
+
Design Principles:
|
| 439 |
+
1. Hierarchical Flow: Arrange nodes from top (Input) to bottom (Output) with y-offsets of +150.
|
| 440 |
+
2. Modern Patterns: Use residuals/skip connections for deep networks. Prefer LayerNorm/RMSNorm for Transformers.
|
| 441 |
+
3. Spatial Clarity: If a layer has multiple branches, spread them horizontally (x-offsets).
|
| 442 |
+
4. Schema Compliance: Use ONLY valid LayerTypes and parameter names provided in the schema.
|
| 443 |
+
|
| 444 |
+
Output Format:
|
| 445 |
+
Return a JSON object with:
|
| 446 |
+
- "reasoning": A brief string explaining your architectural choices.
|
| 447 |
+
- "nodes": Array of { id, type='custom', position:{x,y}, data:{ type, label, params:{} } }
|
| 448 |
+
- "edges": Array of { id, source, target }`;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 449 |
|
| 450 |
let draftJsonStr = "";
|
| 451 |
+
let architectResponse;
|
| 452 |
try {
|
| 453 |
+
architectResponse = await ai.models.generateContent({
|
| 454 |
+
model: model,
|
| 455 |
+
contents: `User Request: "${userPrompt}"\nContext: ${context}\nSchema: ${schemaStr}`,
|
| 456 |
+
config: {
|
| 457 |
+
systemInstruction: architectSystemInstruction,
|
| 458 |
+
// REMOVED responseMimeType and responseSchema as they conflict with tools (googleSearch/urlContext)
|
| 459 |
+
tools: [{ googleSearch: {} }, { urlContext: {} }]
|
| 460 |
+
}
|
| 461 |
});
|
| 462 |
+
draftJsonStr = cleanJsonString(architectResponse.text);
|
| 463 |
} catch (e) {
|
| 464 |
throw e;
|
| 465 |
}
|
|
|
|
| 467 |
// --- Step 2: Critic ---
|
| 468 |
onStatusUpdate('critic', 'Critic is reviewing architecture for flaws...');
|
| 469 |
|
| 470 |
+
const criticSystemInstruction = `You are a Lead Architecture Reviewer.
|
| 471 |
+
Critique the provided neural network draft for technical soundness, efficiency, and adherence to the user's request.
|
| 472 |
+
You have access to Google Search to verify SOTA (State of the Art) practices if needed.
|
| 473 |
+
|
| 474 |
+
Check for:
|
| 475 |
+
- Missing critical components (e.g., no activation after Conv, no pooling before FC).
|
| 476 |
+
- Bottlenecks (e.g., massive dimensionality jumps).
|
| 477 |
+
- Redundancy or inefficient paths.
|
| 478 |
+
- Logical errors in connectivity.
|
| 479 |
+
|
| 480 |
+
If the architecture is excellent, say "No changes needed". Otherwise, provide a concise list of required improvements.`;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 481 |
|
| 482 |
let critique = "";
|
| 483 |
+
let criticResponse;
|
| 484 |
try {
|
| 485 |
+
criticResponse = await ai.models.generateContent({
|
| 486 |
+
model: model,
|
| 487 |
+
// Shared Context: Pass the architect's response as a previous turn to be token efficient
|
| 488 |
+
contents: [
|
| 489 |
+
{ role: 'user', parts: [{ text: `User Request: "${userPrompt}"\nSchema: ${schemaStr}` }] },
|
| 490 |
+
{ role: 'model', parts: [{ text: draftJsonStr }] },
|
| 491 |
+
{ role: 'user', parts: [{ text: "Please critique this architecture." }] }
|
| 492 |
+
],
|
| 493 |
+
config: {
|
| 494 |
+
systemInstruction: criticSystemInstruction,
|
| 495 |
+
tools: [{ googleSearch: {} }]
|
| 496 |
+
}
|
| 497 |
});
|
| 498 |
+
critique = criticResponse.text.trim();
|
| 499 |
} catch (e) {
|
| 500 |
console.warn("Critic agent failed, proceeding with draft.");
|
| 501 |
critique = "No critique available.";
|
|
|
|
| 504 |
// --- Step 3: Refiner ---
|
| 505 |
onStatusUpdate('refiner', 'Refiner is applying fixes and finalizing...');
|
| 506 |
|
| 507 |
+
const refinerSystemInstruction = `You are a Lead Implementation Engineer.
|
| 508 |
+
Finalize the neural network JSON by applying the Critic's feedback to the Architect's draft.
|
| 509 |
+
|
| 510 |
+
Requirements:
|
| 511 |
+
1. Strictly follow the JSON schema: { "nodes": [...], "edges": [...] }.
|
| 512 |
+
2. Ensure all node IDs are unique.
|
| 513 |
+
3. Ensure all LayerTypes match the enum exactly.
|
| 514 |
+
4. Fix any shape or logic errors identified in the critique.
|
| 515 |
+
5. Maintain a clean, readable visual layout.
|
| 516 |
+
|
| 517 |
+
Return ONLY the final JSON object.`;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 518 |
|
| 519 |
try {
|
| 520 |
const response = await ai.models.generateContent({
|
| 521 |
+
model: model,
|
| 522 |
+
// Shared Context: Pass the entire history for maximum consistency
|
| 523 |
+
contents: [
|
| 524 |
+
{ role: 'user', parts: [{ text: `User Request: "${userPrompt}"\nSchema: ${schemaStr}` }] },
|
| 525 |
+
{ role: 'model', parts: [{ text: draftJsonStr }] },
|
| 526 |
+
{ role: 'user', parts: [{ text: `Critique: "${critique}"\n\nApply the critique and provide the final JSON.` }] }
|
| 527 |
+
],
|
| 528 |
+
config: {
|
| 529 |
+
systemInstruction: refinerSystemInstruction,
|
| 530 |
+
// REMOVED responseMimeType and responseSchema as they conflict with previous tool use in history
|
| 531 |
+
}
|
| 532 |
});
|
| 533 |
+
const finalJson = JSON.parse(cleanJsonString(response.text));
|
| 534 |
onStatusUpdate('complete', 'Architecture built successfully!');
|
| 535 |
|
|
|
|
| 536 |
return sanitizeGraph(finalJson);
|
| 537 |
} catch (e) {
|
| 538 |
throw new Error("Refiner agent failed to parse final JSON.");
|
|
|
|
| 545 |
*/
|
| 546 |
export const generateCodeWithAgents = async (
|
| 547 |
promptText: string,
|
| 548 |
+
onStatusUpdate: (status: AgentStatus, log: string) => void,
|
| 549 |
+
model: string = DEFAULT_MODEL
|
| 550 |
): Promise<string> => {
|
| 551 |
const ai = getAiClient();
|
| 552 |
|
| 553 |
+
// --- Step 1: Coder ---
|
| 554 |
onStatusUpdate('architect', 'Coder Agent is writing initial PyTorch implementation...');
|
| 555 |
|
| 556 |
+
const coderSystemInstruction = `You are a Senior Deep Learning Engineer.
|
| 557 |
+
Your task is to write clean, modular, and production-ready PyTorch code.
|
| 558 |
+
|
| 559 |
+
Coding Standards:
|
| 560 |
+
1. Use nn.Module for the main model.
|
| 561 |
+
2. Include docstrings for the class and forward method.
|
| 562 |
+
3. Add comments indicating the expected tensor shapes at each major step.
|
| 563 |
+
4. Use descriptive variable names.
|
| 564 |
+
5. Include a robust 'if __name__ == "__main__":' block that instantiates the model and runs a dummy forward pass with torch.randn.
|
| 565 |
+
6. Handle custom layer logic or external imports if specified.
|
| 566 |
+
|
| 567 |
+
Return ONLY the Python code. No markdown backticks.`;
|
|
|
|
| 568 |
|
| 569 |
let draftCode = "";
|
| 570 |
try {
|
| 571 |
const response = await ai.models.generateContent({
|
| 572 |
+
model: model,
|
| 573 |
+
contents: `Request: "${promptText}"`,
|
| 574 |
+
config: {
|
| 575 |
+
systemInstruction: coderSystemInstruction,
|
| 576 |
+
tools: [{ googleSearch: {} }, { urlContext: {} }]
|
| 577 |
+
}
|
| 578 |
});
|
| 579 |
draftCode = response.text.trim().replace(/```python/g, '').replace(/```/g, '');
|
| 580 |
} catch(e) {
|
| 581 |
throw e;
|
| 582 |
}
|
| 583 |
|
| 584 |
+
// --- Step 2: Reviewer ---
|
| 585 |
onStatusUpdate('critic', 'Reviewer Agent is analyzing code for bugs and optimization...');
|
| 586 |
+
|
| 587 |
+
const reviewerSystemInstruction = `You are a Senior Code Reviewer.
|
| 588 |
+
Analyze the provided PyTorch code for:
|
| 589 |
+
- Syntax errors or missing imports.
|
| 590 |
+
- Logical bugs (e.g., wrong dimension in cat/stack).
|
| 591 |
+
- Missing super().__init__() calls.
|
| 592 |
+
- Inefficient implementations.
|
| 593 |
+
- Adherence to the original architecture request.
|
| 594 |
+
You have access to Google Search to verify PyTorch documentation or SOTA implementation details.
|
| 595 |
+
|
| 596 |
+
Provide a concise, technical critique. If the code is perfect, say "No changes needed".`;
|
| 597 |
+
|
| 598 |
let critique = "";
|
| 599 |
try {
|
| 600 |
const response = await ai.models.generateContent({
|
| 601 |
+
model: model,
|
| 602 |
+
contents: [
|
| 603 |
+
{ role: 'user', parts: [{ text: `Original Request: "${promptText}"` }] },
|
| 604 |
+
{ role: 'model', parts: [{ text: draftCode }] },
|
| 605 |
+
{ role: 'user', parts: [{ text: "Please review this code." }] }
|
| 606 |
+
],
|
| 607 |
+
config: {
|
| 608 |
+
systemInstruction: reviewerSystemInstruction,
|
| 609 |
+
tools: [{ googleSearch: {} }]
|
| 610 |
+
}
|
| 611 |
});
|
| 612 |
critique = response.text.trim();
|
| 613 |
} catch(e) {
|
| 614 |
critique = "No critique available.";
|
| 615 |
}
|
| 616 |
|
| 617 |
+
// --- Step 3: Polisher ---
|
| 618 |
onStatusUpdate('refiner', 'Polisher Agent is finalizing the codebase...');
|
| 619 |
+
|
| 620 |
+
const polisherSystemInstruction = `You are a Principal Software Engineer.
|
| 621 |
+
Refine the PyTorch code by incorporating the Reviewer's feedback.
|
| 622 |
+
|
| 623 |
+
Goals:
|
| 624 |
+
1. Fix all bugs and style issues.
|
| 625 |
+
2. Ensure the code is strictly valid Python.
|
| 626 |
+
3. Maintain all docstrings and shape comments.
|
| 627 |
+
4. Ensure the test block works perfectly.
|
| 628 |
+
|
| 629 |
+
Return ONLY the final Python code. No markdown backticks.`;
|
| 630 |
|
| 631 |
try {
|
| 632 |
const response = await ai.models.generateContent({
|
| 633 |
+
model: model,
|
| 634 |
+
contents: [
|
| 635 |
+
{ role: 'user', parts: [{ text: `Original Request: "${promptText}"` }] },
|
| 636 |
+
{ role: 'model', parts: [{ text: draftCode }] },
|
| 637 |
+
{ role: 'user', parts: [{ text: `Critique: ${critique}\n\nApply the critique and provide the final polished code.` }] }
|
| 638 |
+
],
|
| 639 |
+
config: { systemInstruction: polisherSystemInstruction }
|
| 640 |
});
|
| 641 |
+
let finalCode = cleanJsonString(response.text);
|
| 642 |
onStatusUpdate('complete', 'Code generation complete!');
|
| 643 |
return finalCode;
|
| 644 |
} catch(e) {
|
|
|
|
| 653 |
nodes: Node<NodeData>[],
|
| 654 |
edges: Edge[],
|
| 655 |
errorMsg: string,
|
| 656 |
+
onStatusUpdate: (status: AgentStatus, log: string) => void,
|
| 657 |
+
model: string = DEFAULT_MODEL
|
| 658 |
): Promise<{ nodes: any[], edges: any[] } | null> => {
|
| 659 |
const ai = getAiClient();
|
| 660 |
const graphJson = JSON.stringify({
|
|
|
|
| 665 |
// --- Step 1: Debugger ---
|
| 666 |
onStatusUpdate('debugger', 'Debugger Agent is analyzing the error trace...');
|
| 667 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 668 |
let debugAnalysis = "";
|
| 669 |
try {
|
| 670 |
const response = await ai.models.generateContent({
|
| 671 |
+
model: model,
|
| 672 |
+
contents: `Graph: ${graphJson}\nError Message: "${errorMsg}"`,
|
| 673 |
+
config: {
|
| 674 |
+
systemInstruction: `You are a Senior Systems Debugger specializing in Deep Learning.
|
| 675 |
+
Analyze the provided graph and error message to identify the root cause of the failure.
|
| 676 |
+
You have access to Google Search to look up specific PyTorch error traces or layer compatibility issues.
|
| 677 |
+
|
| 678 |
+
Focus on:
|
| 679 |
+
- Shape mismatches between connected layers.
|
| 680 |
+
- Incorrect parameter values.
|
| 681 |
+
- Missing required layers (e.g., Flatten before Linear).
|
| 682 |
+
|
| 683 |
+
Output a concise, technical explanation of the bug.`,
|
| 684 |
+
tools: [{ googleSearch: {} }]
|
| 685 |
+
}
|
| 686 |
});
|
| 687 |
debugAnalysis = response.text.trim();
|
| 688 |
} catch (e) {
|
|
|
|
| 692 |
// --- Step 2: Architect ---
|
| 693 |
onStatusUpdate('architect', 'Architect Agent is planning the fix...');
|
| 694 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 695 |
let fixPlan = "";
|
| 696 |
try {
|
| 697 |
const response = await ai.models.generateContent({
|
| 698 |
+
model: model,
|
| 699 |
+
contents: `Issue Analysis: ${debugAnalysis}`,
|
| 700 |
+
config: {
|
| 701 |
+
systemInstruction: `You are a Solution Architect.
|
| 702 |
+
Based on the debugger's analysis, propose a minimal and effective structural fix.
|
| 703 |
+
|
| 704 |
+
Instructions:
|
| 705 |
+
- Specify which nodes to add, remove, or modify.
|
| 706 |
+
- Specify which edges to reconnect.
|
| 707 |
+
- Ensure the fix doesn't introduce new errors.
|
| 708 |
+
|
| 709 |
+
Output the plan in clear, actionable steps.`
|
| 710 |
+
}
|
| 711 |
});
|
| 712 |
fixPlan = response.text.trim();
|
| 713 |
} catch (e) {
|
| 714 |
fixPlan = "Apply necessary structural corrections.";
|
| 715 |
}
|
| 716 |
|
| 717 |
+
// --- Step 2.5: Critic (New Step) ---
|
| 718 |
+
onStatusUpdate('critic', 'Critic Agent is reviewing the fix plan...');
|
| 719 |
+
|
| 720 |
+
let fixCritique = "";
|
| 721 |
+
try {
|
| 722 |
+
const response = await ai.models.generateContent({
|
| 723 |
+
model: model,
|
| 724 |
+
contents: `Error: "${errorMsg}"\nAnalysis: ${debugAnalysis}\nProposed Fix: ${fixPlan}`,
|
| 725 |
+
config: {
|
| 726 |
+
systemInstruction: `You are a Senior Reviewer.
|
| 727 |
+
Review the proposed fix plan. Does it actually solve the root cause?
|
| 728 |
+
Does it introduce new shape mismatches?
|
| 729 |
+
Is it the most efficient way to fix the error?
|
| 730 |
+
You have access to Google Search to verify SOTA fixes for similar issues.
|
| 731 |
+
|
| 732 |
+
Output a brief critique or "Approved" if perfect.`,
|
| 733 |
+
tools: [{ googleSearch: {} }]
|
| 734 |
+
}
|
| 735 |
+
});
|
| 736 |
+
fixCritique = response.text.trim();
|
| 737 |
+
} catch (e) {
|
| 738 |
+
fixCritique = "Approved";
|
| 739 |
+
}
|
| 740 |
+
|
| 741 |
// --- Step 3: Patcher ---
|
| 742 |
onStatusUpdate('patcher', 'Patcher Agent is applying the fix to the graph...');
|
| 743 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 744 |
try {
|
| 745 |
const response = await ai.models.generateContent({
|
| 746 |
+
model: model,
|
| 747 |
+
contents: `Current Graph: ${graphJson}\nFix Plan: ${fixPlan}\nCritique: ${fixCritique}`,
|
| 748 |
+
config: {
|
| 749 |
+
systemInstruction: `You are a Senior Patcher.
|
| 750 |
+
Apply the proposed fix plan (considering the critique) to the graph JSON.
|
| 751 |
+
|
| 752 |
+
Requirements:
|
| 753 |
+
1. Return a valid JSON object with "nodes" and "edges".
|
| 754 |
+
2. Use ONLY valid LayerTypes.
|
| 755 |
+
3. Maintain existing node IDs and positions where possible.
|
| 756 |
+
4. Ensure all new nodes have unique IDs.
|
| 757 |
+
|
| 758 |
+
Return ONLY the final JSON object.`,
|
| 759 |
+
// REMOVED responseMimeType and responseSchema as they conflict with previous tool use in history
|
| 760 |
+
}
|
| 761 |
});
|
| 762 |
+
const finalJson = JSON.parse(cleanJsonString(response.text));
|
| 763 |
onStatusUpdate('complete', 'Fix applied successfully!');
|
| 764 |
return sanitizeGraph(finalJson);
|
| 765 |
} catch (e) {
|
| 766 |
throw new Error("Patcher agent failed to generate valid JSON.");
|
| 767 |
}
|
| 768 |
+
};
|