wuhp commited on
Commit
152392b
·
verified ·
1 Parent(s): f916579

Upload geminiService.ts

Browse files
Files changed (1) hide show
  1. services/geminiService.ts +468 -0
services/geminiService.ts ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { GoogleGenAI } from "@google/genai";
2
+ import { Node, Edge } from 'reactflow';
3
+ import { NodeData, LayerType } from '../types';
4
+ import { LAYER_DEFINITIONS } from '../constants';
5
+
6
+ // Key Management
7
+ let userApiKey = typeof window !== 'undefined' ? localStorage.getItem('gemini_api_key') || '' : '';
8
+ const defaultEnvKey = process.env.API_KEY || '';
9
+
10
+ export const setUserApiKey = (key: string) => {
11
+ userApiKey = key;
12
+ if (typeof window !== 'undefined') {
13
+ localStorage.setItem('gemini_api_key', key);
14
+ }
15
+ };
16
+
17
+ export const getUserApiKey = () => userApiKey || defaultEnvKey;
18
+
19
+ const getAiClient = () => {
20
+ const key = getUserApiKey();
21
+ if (!key) throw new Error("API Key is missing. Please add your Gemini API Key.");
22
+ return new GoogleGenAI({ apiKey: key });
23
+ };
24
+
25
+ const MODEL_NAME = 'gemini-2.5-flash';
26
+
27
+ export type AgentStatus = 'idle' | 'architect' | 'critic' | 'refiner' | 'complete' | 'error';
28
+
29
+ /**
30
+ * Internal helper to build the raw specification string.
31
+ * This ensures all data is captured before sending to the AI for refinement.
32
+ */
33
+ const buildRawPrompt = (nodes: Node<NodeData>[], edges: Edge[]): string => {
34
+ // Sort nodes by vertical position to list them in a logical flow
35
+ const sortedNodes = [...nodes].sort((a, b) => a.position.y - b.position.y);
36
+
37
+ let rawSpec = "### Raw Architecture Data\n\n";
38
+
39
+ rawSpec += "**1. Nodes (Layers):**\n";
40
+ sortedNodes.forEach(node => {
41
+ // Format parameters nicely, handling objects
42
+ const params = Object.entries(node.data.params)
43
+ .map(([k, v]) => {
44
+ if (typeof v === 'object' && v !== null) {
45
+ return `${k}=${JSON.stringify(v)}`;
46
+ }
47
+ return `${k}=${v}`;
48
+ })
49
+ .join(', ');
50
+
51
+ rawSpec += `- [ID: ${node.id}] TYPE: ${node.data.type} | LABEL: ${node.data.label}\n`;
52
+ if (params) {
53
+ rawSpec += ` PARAMS: ${params}\n`;
54
+ }
55
+
56
+ // Specific instruction for Custom Layers
57
+ if (node.data.type === LayerType.CUSTOM) {
58
+ rawSpec += ` CUSTOM_NOTE: Instantiate using class '${node.data.params.class_name}' with args '${node.data.params.args}'.\n`;
59
+ }
60
+ });
61
+
62
+ rawSpec += "\n**2. Connectivity (Edges):**\n";
63
+ if (edges.length === 0) {
64
+ rawSpec += "- No connections defined.\n";
65
+ } else {
66
+ edges.forEach(edge => {
67
+ const sourceNode = nodes.find(n => n.id === edge.source);
68
+ const targetNode = nodes.find(n => n.id === edge.target);
69
+ const sourceName = sourceNode ? `${sourceNode.data.label} (ID:${sourceNode.id})` : edge.source;
70
+ const targetName = targetNode ? `${targetNode.data.label} (ID:${targetNode.id})` : edge.target;
71
+
72
+ rawSpec += `- ${sourceName} -> ${targetName}\n`;
73
+ });
74
+ }
75
+
76
+ return rawSpec;
77
+ };
78
+
79
+ /**
80
+ * Generates a polished, professional prompt using the AI.
81
+ * It takes the raw hardcoded spec and asks the AI to format it perfectly for a coding LLM.
82
+ */
83
+ export const generateRefinedPrompt = async (nodes: Node<NodeData>[], edges: Edge[]): Promise<string> => {
84
+ const ai = getAiClient();
85
+ const rawSpec = buildRawPrompt(nodes, edges);
86
+
87
+ const systemPrompt = `
88
+ You are an expert AI Prompt Engineer for Deep Learning.
89
+ Your goal is to take a raw, technical neural network specification and rewrite it into a
90
+ perfect, professional, and detailed prompt that another AI (like a coding assistant) could use to write flawless PyTorch code.
91
+
92
+ Input Raw Specification:
93
+ ${rawSpec}
94
+
95
+ Instructions:
96
+ 1. Start the output with: "You are an expert Deep Learning Engineer. Please write a complete, runnable PyTorch model code for the following neural network architecture:"
97
+ 2. Create a section "Architecture Specification".
98
+ 3. List "1. Layers (Nodes)" cleanly. Include ID, Type, and Parameters.
99
+ 4. List "2. Connectivity (Forward Pass Flow)" cleanly.
100
+ - Explicitly describe merge points (e.g. "Node X receives inputs from A and B. Handle this merge...").
101
+ - Note specific handling for complex layers like CrossAttention (needs Query + Key/Value) or SAM Decoders.
102
+ 5. Create a section "Implementation Requirements" with standard PyTorch best practices (nn.Module, forward method, correct shapes).
103
+ 6. Do NOT write the Python code yourself. Write the PROMPT that asks for the code.
104
+ 7. Ensure the tone is technical and precise.
105
+
106
+ Return ONLY the generated prompt text.
107
+ `;
108
+
109
+ try {
110
+ const response = await ai.models.generateContent({
111
+ model: MODEL_NAME,
112
+ contents: systemPrompt,
113
+ });
114
+ return response.text.trim();
115
+ } catch (error) {
116
+ console.error("Prompt refinement failed:", error);
117
+ return buildRawPrompt(nodes, edges) + "\n\n(AI Refinement failed, showing raw spec)";
118
+ }
119
+ };
120
+
121
+ export const validateArchitecture = async (nodes: Node<NodeData>[], edges: Edge[]): Promise<string> => {
122
+ const ai = getAiClient();
123
+ const graphRepresentation = {
124
+ nodes: nodes.map(n => ({
125
+ id: n.id,
126
+ type: n.data.type,
127
+ parameters: n.data.params
128
+ })),
129
+ edges: edges.map(e => ({
130
+ source: e.source,
131
+ target: e.target
132
+ }))
133
+ };
134
+
135
+ const prompt = `
136
+ Analyze this neural network architecture graph for validity.
137
+ Graph: ${JSON.stringify(graphRepresentation)}
138
+
139
+ Check for:
140
+ 1. Shape mismatches (e.g., Conv2D output to Linear without Flatten).
141
+ 2. Disconnected components.
142
+ 3. Logical errors (e.g., MaxPool after Output).
143
+ 4. Merge layer correctness (Concat/Add needs multiple inputs).
144
+ 5. GenAI correctness (e.g., CrossAttention needs 2 inputs, VLM projection dims match).
145
+
146
+ Return a concise report. If valid, say "Architecture is valid.". If invalid, list specific errors and suggest fixes.
147
+ `;
148
+
149
+ try {
150
+ const response = await ai.models.generateContent({
151
+ model: MODEL_NAME,
152
+ contents: prompt,
153
+ });
154
+ return response.text;
155
+ } catch (error) {
156
+ return "Error validating architecture.";
157
+ }
158
+ }
159
+
160
+ /**
161
+ * Gets AI suggestions for improving the architecture.
162
+ */
163
+ export const getArchitectureSuggestions = async (nodes: Node<NodeData>[], edges: Edge[]): Promise<string> => {
164
+ const ai = getAiClient();
165
+ const graphRepresentation = {
166
+ nodes: nodes.map(n => ({
167
+ id: n.id,
168
+ type: n.data.type,
169
+ parameters: n.data.params,
170
+ label: n.data.label
171
+ })),
172
+ edges: edges.map(e => ({
173
+ source: e.source,
174
+ target: e.target
175
+ }))
176
+ };
177
+
178
+ const prompt = `
179
+ You are a Senior Deep Learning Architect. Review the following neural network architecture graph.
180
+
181
+ Graph Structure:
182
+ ${JSON.stringify(graphRepresentation, null, 2)}
183
+
184
+ Task: Provide 3 to 5 concrete, actionable suggestions to improve this model.
185
+ Focus on:
186
+ - Modern best practices (e.g., using LayerNorm vs BatchNorm in Transformers, SwiGLU vs ReLU).
187
+ - Architecture efficiency and parameter count optimization.
188
+ - Potential bottlenecks or vanishing gradient risks.
189
+ - Adding residuals or skip connections if the model is deep.
190
+
191
+ Format the output as a clean bulleted list. Keep it concise, professional and helpful.
192
+ `;
193
+
194
+ try {
195
+ const response = await ai.models.generateContent({
196
+ model: MODEL_NAME,
197
+ contents: prompt,
198
+ });
199
+ return response.text;
200
+ } catch (error) {
201
+ return "Error generating suggestions.";
202
+ }
203
+ }
204
+
205
+ /**
206
+ * Implements the suggestions automatically.
207
+ */
208
+ export const implementArchitectureSuggestions = async (
209
+ nodes: Node<NodeData>[],
210
+ edges: Edge[],
211
+ suggestions: string
212
+ ): Promise<{ nodes: any[], edges: any[] }> => {
213
+ const ai = getAiClient();
214
+
215
+ const graphRepresentation = {
216
+ nodes: nodes.map(n => ({
217
+ id: n.id,
218
+ type: n.data.type,
219
+ parameters: n.data.params,
220
+ label: n.data.label,
221
+ position: n.position
222
+ })),
223
+ edges: edges.map(e => ({
224
+ source: e.source,
225
+ target: e.target
226
+ }))
227
+ };
228
+
229
+ const prompt = `
230
+ You are a Senior Implementation Engineer.
231
+ Task: Apply the following architectural suggestions to the provided graph JSON.
232
+
233
+ Current Graph:
234
+ ${JSON.stringify(graphRepresentation)}
235
+
236
+ Suggestions to Implement:
237
+ "${suggestions}"
238
+
239
+ Instructions:
240
+ 1. Modify the nodes and edges to incorporate the suggestions.
241
+ 2. Maintain the layout (x, y positions) as best as possible, offsetting new nodes if added.
242
+ 3. Ensure all LayerTypes are valid from the standard schema.
243
+ 4. Return the complete, updated JSON with "nodes" and "edges" arrays.
244
+ 5. Return ONLY raw JSON.
245
+ `;
246
+
247
+ try {
248
+ const response = await ai.models.generateContent({
249
+ model: MODEL_NAME,
250
+ contents: prompt,
251
+ config: { responseMimeType: "application/json" }
252
+ });
253
+ return JSON.parse(response.text.trim());
254
+ } catch (error) {
255
+ throw new Error("Failed to implement suggestions.");
256
+ }
257
+ };
258
+
259
+
260
+ /**
261
+ * Multi-agent graph generation.
262
+ * 1. Architect Agent: Drafts the layout.
263
+ * 2. Critic Agent: Reviews for errors.
264
+ * 3. Refiner Agent: Produces final JSON.
265
+ */
266
+ export const generateGraphWithAgents = async (
267
+ userPrompt: string,
268
+ currentNodes: Node<NodeData>[] = [],
269
+ onStatusUpdate: (status: AgentStatus, log: string) => void
270
+ ): Promise<{ nodes: any[], edges: any[] } | null> => {
271
+ const ai = getAiClient();
272
+
273
+ const layerSchema = Object.values(LAYER_DEFINITIONS).map(l => ({
274
+ type: l.type,
275
+ params: l.parameters.map(p => ({ name: p.name, type: p.type, options: p.options }))
276
+ }));
277
+ const schemaStr = JSON.stringify(layerSchema);
278
+
279
+ // --- Step 1: Architect ---
280
+ onStatusUpdate('architect', 'Architect is drafting initial layout...');
281
+
282
+ const context = currentNodes.length > 0
283
+ ? `Current Graph Context: ${JSON.stringify(currentNodes.map(n => ({ id: n.id, type: n.data.type, label: n.data.label })))}`
284
+ : "Starting from scratch.";
285
+
286
+ const architectPrompt = `
287
+ Role: Senior Neural Network Architect.
288
+ Task: Create a preliminary graph layout (JSON) for the user request.
289
+ User Request: "${userPrompt}"
290
+ Context: ${context}
291
+
292
+ Available Layers: ${Object.keys(LAYER_DEFINITIONS).join(', ')}
293
+ Schema Reference: ${schemaStr}
294
+
295
+ Requirements:
296
+ 1. Output valid JSON with "nodes" and "edges" arrays.
297
+ 2. "nodes": { id, type='custom', position:{x,y}, data:{ type: LayerType, label: string, params: {} } }
298
+ 3. Use correct LayerTypes from enum.
299
+ 4. Layout nodes vertically (y+150 each step).
300
+ 5. Connect edges logically.
301
+ 6. If multi-input/output, arrange horizontally.
302
+
303
+ Return ONLY raw JSON.
304
+ `;
305
+
306
+ let draftJsonStr = "";
307
+ try {
308
+ const response = await ai.models.generateContent({
309
+ model: MODEL_NAME,
310
+ contents: architectPrompt,
311
+ config: { responseMimeType: "application/json" }
312
+ });
313
+ draftJsonStr = response.text.trim();
314
+ } catch (e) {
315
+ throw new Error("Architect agent failed to generate draft.");
316
+ }
317
+
318
+ // --- Step 2: Critic ---
319
+ onStatusUpdate('critic', 'Critic is reviewing architecture for flaws...');
320
+
321
+ const criticPrompt = `
322
+ Role: Senior Lead Reviewer.
323
+ Task: Critique the following neural network architecture draft.
324
+ User Request: "${userPrompt}"
325
+ Draft Architecture: ${draftJsonStr}
326
+
327
+ Check for:
328
+ - Shape mismatches (e.g. 3D output into 2D input without flattening)
329
+ - Logical connection errors
330
+ - Missing essential layers (e.g. Activations, Normalization)
331
+ - Parameter errors (e.g. kernel size too large)
332
+ - Compliance with user request
333
+
334
+ Output a concise paragraph describing specific improvements needed. If perfect, say "No changes needed".
335
+ `;
336
+
337
+ let critique = "";
338
+ try {
339
+ const response = await ai.models.generateContent({
340
+ model: MODEL_NAME,
341
+ contents: criticPrompt,
342
+ });
343
+ critique = response.text.trim();
344
+ } catch (e) {
345
+ console.warn("Critic agent failed, proceeding with draft.");
346
+ critique = "No critique available.";
347
+ }
348
+
349
+ // --- Step 3: Refiner ---
350
+ onStatusUpdate('refiner', 'Refiner is applying fixes and finalizing...');
351
+
352
+ const refinerPrompt = `
353
+ Role: Lead Engineer.
354
+ Task: Finalize the JSON architecture based on the critique.
355
+
356
+ Draft: ${draftJsonStr}
357
+ Critique: "${critique}"
358
+
359
+ Instructions:
360
+ 1. Apply the fixes suggested in the critique.
361
+ 2. Ensure the JSON structure is strictly { "nodes": [...], "edges": [...] }.
362
+ 3. Ensure all node IDs are unique strings.
363
+ 4. Ensure parameter values match the schema types.
364
+ 5. Ensure "type" in top level node object is always 'custom'.
365
+
366
+ Return ONLY the final JSON.
367
+ `;
368
+
369
+ try {
370
+ const response = await ai.models.generateContent({
371
+ model: MODEL_NAME,
372
+ contents: refinerPrompt,
373
+ config: { responseMimeType: "application/json" }
374
+ });
375
+ const finalJson = JSON.parse(response.text.trim());
376
+ onStatusUpdate('complete', 'Architecture built successfully!');
377
+ return finalJson;
378
+ } catch (e) {
379
+ throw new Error("Refiner agent failed to parse final JSON.");
380
+ }
381
+ };
382
+
383
+ /**
384
+ * Multi-agent code generation.
385
+ * Uses the detailed prompt to write, review, and polish PyTorch code.
386
+ */
387
+ export const generateCodeWithAgents = async (
388
+ promptText: string,
389
+ onStatusUpdate: (status: AgentStatus, log: string) => void
390
+ ): Promise<string> => {
391
+ const ai = getAiClient();
392
+
393
+ // --- Step 1: Coder (Architect) ---
394
+ onStatusUpdate('architect', 'Coder Agent is writing initial PyTorch implementation...');
395
+
396
+ const coderPrompt = `
397
+ Role: Senior Deep Learning Engineer.
398
+ Task: Write complete, runnable PyTorch code based on the following architecture prompt.
399
+ Prompt: "${promptText}"
400
+
401
+ Requirements:
402
+ - Use torch.nn.Module
403
+ - Include all necessary imports
404
+ - Handle forward pass logic exactly as described
405
+ - Include a 'if __name__ == "__main__":' block to test with dummy data
406
+
407
+ Return ONLY the Python code. No markdown formatting.
408
+ `;
409
+
410
+ let draftCode = "";
411
+ try {
412
+ const response = await ai.models.generateContent({
413
+ model: MODEL_NAME,
414
+ contents: coderPrompt
415
+ });
416
+ draftCode = response.text.trim().replace(/```python/g, '').replace(/```/g, '');
417
+ } catch(e) {
418
+ throw new Error("Coder agent failed.");
419
+ }
420
+
421
+ // --- Step 2: Reviewer (Critic) ---
422
+ onStatusUpdate('critic', 'Reviewer Agent is analyzing code for bugs and optimization...');
423
+ const reviewPrompt = `
424
+ Role: Code Reviewer.
425
+ Task: Review the following PyTorch code for errors, shape mismatches, or style issues.
426
+ Code:
427
+ ${draftCode}
428
+
429
+ Original Prompt Request: "${promptText}"
430
+
431
+ Output a concise critique. If perfect, say "No changes needed".
432
+ `;
433
+
434
+ let critique = "";
435
+ try {
436
+ const response = await ai.models.generateContent({
437
+ model: MODEL_NAME,
438
+ contents: reviewPrompt
439
+ });
440
+ critique = response.text.trim();
441
+ } catch(e) {
442
+ critique = "No critique available.";
443
+ }
444
+
445
+ // --- Step 3: Polisher (Refiner) ---
446
+ onStatusUpdate('refiner', 'Polisher Agent is finalizing the codebase...');
447
+ const polisherPrompt = `
448
+ Role: Senior Software Engineer.
449
+ Task: Refine the PyTorch code based on the critique.
450
+
451
+ Draft Code: ${draftCode}
452
+ Critique: ${critique}
453
+
454
+ Return ONLY the final Python code. No markdown formatting.
455
+ `;
456
+
457
+ try {
458
+ const response = await ai.models.generateContent({
459
+ model: MODEL_NAME,
460
+ contents: polisherPrompt
461
+ });
462
+ let finalCode = response.text.trim().replace(/```python/g, '').replace(/```/g, '');
463
+ onStatusUpdate('complete', 'Code generation complete!');
464
+ return finalCode;
465
+ } catch(e) {
466
+ throw new Error("Polisher agent failed.");
467
+ }
468
+ };