Nishan30 committed on
Commit
3ef49c3
·
verified ·
1 Parent(s): 0f83a0b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -59
app.py CHANGED
@@ -17,26 +17,6 @@ import re
17
  MODEL_REPO = "Nishan30/n8n-workflow-generator" # Update with your HF repo
18
  BASE_MODEL = "Qwen/Qwen2.5-Coder-1.5B-Instruct"
19
 
20
- SYSTEM_PROMPT = """You are an expert n8n workflow generator. Given a user's request, you MUST generate clean, functional TypeScript code that EXACTLY matches their specific requirements using the @n8n-generator/core DSL.
21
-
22
- CRITICAL: Generate code based ONLY on the user's request. Do NOT use example workflows. Create the workflow the user asks for.
23
-
24
- Your output should:
25
- - Only contain the code, no explanations
26
- - Use the Workflow class from @n8n-generator/core
27
- - Use workflow.add() to create nodes with appropriate parameters
28
- - Use .to() or workflow.connect() for connections
29
- - Match the user's specific requirements exactly
30
- - Be ready to compile directly to n8n JSON
31
-
32
- Format:
33
- ```typescript
34
- const workflow = new Workflow('Descriptive Name');
35
- const node1 = workflow.add('n8n-nodes-base.nodetype', { param: 'value' });
36
- const node2 = workflow.add('n8n-nodes-base.nodetype', { param: 'value' });
37
- node1.to(node2);
38
- ```"""
39
-
40
  # ==============================================================================
41
  # MODEL LOADING
42
  # ==============================================================================
@@ -120,41 +100,46 @@ def generate_workflow(prompt, temperature=0.5, max_tokens=1024):
120
  if not prompt.strip():
121
  return "Please enter a workflow description.", None, None
122
 
123
- # Format messages
124
- messages = [
125
- {"role": "system", "content": SYSTEM_PROMPT},
126
- {"role": "user", "content": prompt}
127
- ]
128
-
129
- # Apply chat template
130
- text = tokenizer.apply_chat_template(
131
- messages,
132
- tokenize=False,
133
- add_generation_prompt=True
134
- )
 
 
 
 
135
 
136
  # Debug: Print formatted prompt (first 500 chars)
137
  print(f"\n{'='*60}")
138
  print(f"User Prompt: {prompt}")
139
- print(f"Formatted Input (truncated):\n{text[:500]}...")
140
  print(f"{'='*60}\n")
141
 
142
  # Tokenize
143
- inputs = tokenizer(text, return_tensors="pt").to(model.device)
144
  input_length = inputs.input_ids.shape[1]
145
  print(f"Input tokens: {input_length}, Max new tokens: {max_tokens}")
146
 
147
- # Generate with better sampling parameters
148
  with torch.no_grad():
149
  outputs = model.generate(
150
  **inputs,
151
  max_new_tokens=max_tokens,
152
- temperature=max(temperature, 0.1), # Ensure minimum temperature
153
- do_sample=True, # Always sample for variety
154
- top_p=0.9,
155
- top_k=50, # Add top-k sampling
156
- repetition_penalty=1.15, # Increase to reduce repetition
157
- no_repeat_ngram_size=3 # Prevent repeating 3-grams
 
158
  )
159
 
160
  # Decode
@@ -162,10 +147,10 @@ def generate_workflow(prompt, temperature=0.5, max_tokens=1024):
162
 
163
  # Debug: Print generated text
164
  print(f"Generated text length: {len(generated_text)} chars")
165
- print(f"Generated text (first 300 chars):\n{generated_text[:300]}...\n")
166
 
167
- # Extract code from response
168
- code = extract_code(generated_text)
169
 
170
  # Convert to n8n JSON
171
  n8n_json = convert_to_n8n_json(code)
@@ -175,25 +160,33 @@ def generate_workflow(prompt, temperature=0.5, max_tokens=1024):
175
 
176
  return code, json.dumps(n8n_json, indent=2), visualization
177
 
178
- def extract_code(text):
179
- """Extract TypeScript code from generated text"""
180
 
181
- # Try to find code block
182
- code_match = re.search(r'```(?:typescript|ts)?\n(.*?)```', text, re.DOTALL)
 
 
 
 
 
 
 
 
 
 
 
183
  if code_match:
184
  return code_match.group(1).strip()
185
 
186
- # If no code block, look for code after assistant response
187
- if "assistant" in text.lower():
188
- parts = text.split("assistant", 1)
189
- if len(parts) > 1:
190
- # Remove any markdown code blocks
191
- code = parts[1].strip()
192
- code = re.sub(r'```(?:typescript|ts)?\n', '', code)
193
- code = re.sub(r'```', '', code)
194
- return code.strip()
195
 
196
- return text.strip()
 
 
 
 
197
 
198
  # ==============================================================================
199
  # N8N JSON CONVERSION
@@ -495,4 +488,4 @@ if __name__ == "__main__":
495
  server_name="0.0.0.0",
496
  server_port=7860,
497
  share=False
498
- )
 
17
  MODEL_REPO = "Nishan30/n8n-workflow-generator" # Update with your HF repo
18
  BASE_MODEL = "Qwen/Qwen2.5-Coder-1.5B-Instruct"
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  # ==============================================================================
21
  # MODEL LOADING
22
  # ==============================================================================
 
100
  if not prompt.strip():
101
  return "Please enter a workflow description.", None, None
102
 
103
+ # IMPORTANT: Use the exact format the model was trained with
104
+ formatted_prompt = f"""### System:
105
+ You are an expert n8n workflow generator. Given a user's request, you generate clean, functional TypeScript code using the @n8n-generator/core DSL.
106
+
107
+ Your output should:
108
+ - Only contain the code, no explanations
109
+ - Use the Workflow class from @n8n-generator/core
110
+ - Use workflow.add() to create nodes
111
+ - Use .to() or workflow.connect() for connections
112
+ - Be ready to compile directly to n8n JSON
113
+
114
+ ### Instruction:
115
+ {prompt}
116
+
117
+ ### Response:
118
+ """
119
 
120
  # Debug: Print formatted prompt (first 500 chars)
121
  print(f"\n{'='*60}")
122
  print(f"User Prompt: {prompt}")
123
+ print(f"Formatted Input (truncated):\n{formatted_prompt[:500]}...")
124
  print(f"{'='*60}\n")
125
 
126
  # Tokenize
127
+ inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
128
  input_length = inputs.input_ids.shape[1]
129
  print(f"Input tokens: {input_length}, Max new tokens: {max_tokens}")
130
 
131
+ # Generate with parameters matching training
132
  with torch.no_grad():
133
  outputs = model.generate(
134
  **inputs,
135
  max_new_tokens=max_tokens,
136
+ temperature=max(temperature, 0.1),
137
+ do_sample=True,
138
+ top_p=0.95,
139
+ top_k=50,
140
+ repetition_penalty=1.1,
141
+ eos_token_id=tokenizer.eos_token_id,
142
+ pad_token_id=tokenizer.pad_token_id,
143
  )
144
 
145
  # Decode
 
147
 
148
  # Debug: Print generated text
149
  print(f"Generated text length: {len(generated_text)} chars")
150
+ print(f"Generated text (first 500 chars):\n{generated_text[:500]}...\n")
151
 
152
+ # Extract code from response (handle ### Response: format)
153
+ code = extract_code_from_instruction_format(generated_text)
154
 
155
  # Convert to n8n JSON
156
  n8n_json = convert_to_n8n_json(code)
 
160
 
161
  return code, json.dumps(n8n_json, indent=2), visualization
162
 
163
def extract_code_from_instruction_format(text):
    """Extract TypeScript code from a model completion in '### Response:' format.

    The model is prompted with an Alpaca-style template, so the generated text
    contains the echoed prompt followed by '### Response:' and the answer.

    Args:
        text: Full decoded model output (prompt + completion).

    Returns:
        The extracted code as a string. Falls back to the cleaned response
        text (markdown fences stripped) when no fenced code block is found.
    """
    # Everything after the LAST '### Response:' marker is the model's answer.
    # str.split never raises here, so no try/except is needed (the previous
    # bare `except:` only masked unrelated errors).
    response_part = text.split("### Response:")[-1].strip()

    # Truncate at any hallucinated follow-up sections the model may emit
    # (a repeated instruction/system header, or a long blank run).
    for stop_marker in ["### Instruction:", "### System:", "\n\n\n\n"]:
        if stop_marker in response_part:
            response_part = response_part.split(stop_marker)[0].strip()

    # Prefer a fenced ```typescript / ```ts / ``` block when present.
    code_match = re.search(r'```(?:typescript|ts)?\n(.*?)```', response_part, re.DOTALL)
    if code_match:
        return code_match.group(1).strip()

    # No complete fenced block: strip any stray fence markers and return the rest.
    response_part = re.sub(r'```(?:typescript|ts)?', '', response_part)

    return response_part.strip()
186
+
187
def extract_code(text):
    """Backward-compatible alias; delegates to the instruction-format extractor."""
    extracted = extract_code_from_instruction_format(text)
    return extracted
190
 
191
  # ==============================================================================
192
  # N8N JSON CONVERSION
 
488
  server_name="0.0.0.0",
489
  server_port=7860,
490
  share=False
491
+ )