"""
n8n Workflow Generator - Gradio Web Interface

Deploy this to Hugging Face Spaces.

Loads a LoRA adapter on top of a Qwen2.5-Coder base model, asks it to turn a
natural-language request into TypeScript DSL code, and converts that DSL into
an n8n workflow JSON document.
"""

import json
import os
import re
import shutil
import tempfile

import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# ==============================================================================
# CONFIGURATION
# ==============================================================================

MODEL_REPO = "Nishan30/n8n-workflow-generator"  # Update with your HF repo
BASE_MODEL = "Qwen/Qwen2.5-Coder-1.5B-Instruct"

# Memory optimization: set to True for 8-bit quantization (uses less memory but slower)
USE_8BIT = False  # Change to True if you get out-of-memory errors

# adapter_config.json keys that the installed PEFT version may reject with
# "TypeError: ... unexpected keyword argument"; they are stripped before a retry.
_UNSUPPORTED_ADAPTER_KEYS = ("alora_invocation_tokens", "alora_invocation_token_ids")

# ==============================================================================
# MODEL LOADING
# ==============================================================================


def _load_adapter_with_filtered_config(base_model, error):
    """Retry loading the LoRA adapter from MODEL_REPO with a cleaned config.

    Downloads ``adapter_config.json``, removes the keys listed in
    ``_UNSUPPORTED_ADAPTER_KEYS`` plus the keyword named in *error* (if the
    message names one), copies the first adapter weights file that can be
    downloaded next to it in a temporary directory, and loads from there.

    Args:
        base_model: The already-loaded base causal LM.
        error: The TypeError raised by the first ``PeftModel.from_pretrained``.

    Returns:
        The ``PeftModel`` wrapping *base_model*.
    """
    from huggingface_hub import hf_hub_download

    config_path = hf_hub_download(repo_id=MODEL_REPO, filename="adapter_config.json")
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    unsupported = list(_UNSUPPORTED_ADAPTER_KEYS)
    # The TypeError message names the offending keyword; strip it too so that
    # keys not in the hard-coded list are also handled.
    named = re.search(r"unexpected keyword argument '(\w+)'", str(error))
    if named and named.group(1) not in unsupported:
        unsupported.append(named.group(1))

    for param in unsupported:
        if param in config:
            print(f"Removing unsupported parameter: {param}")
            del config[param]

    # TemporaryDirectory removes the directory even if loading below raises.
    with tempfile.TemporaryDirectory() as temp_dir:
        with open(os.path.join(temp_dir, "adapter_config.json"), "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

        # Copy whichever adapter weights format the repo provides.
        for filename in ("adapter_model.safetensors", "adapter_model.bin"):
            try:
                src = hf_hub_download(repo_id=MODEL_REPO, filename=filename)
            except Exception as exc:  # best effort: file absent or fetch failed
                print(f"Could not download {filename}: {exc}")
                continue
            shutil.copy(src, os.path.join(temp_dir, filename))
            break

        return PeftModel.from_pretrained(base_model, temp_dir)


def load_model():
    """Load the base model with the LoRA adapter, plus the tokenizer.

    Called once at import time below; the result is kept in module globals.

    Returns:
        tuple: ``(model, tokenizer)``.

    Raises:
        TypeError: If adapter loading fails for a reason other than an
            unexpected keyword argument in the adapter config.
    """
    print("Loading model...")

    model_kwargs = {
        "device_map": "auto",
        "trust_remote_code": True,
        "low_cpu_mem_usage": True,
    }

    if USE_8BIT:
        # Passing load_in_8bit= directly to from_pretrained is deprecated;
        # BitsAndBytesConfig is the supported way to request 8-bit loading.
        from transformers import BitsAndBytesConfig

        print("Using 8-bit quantization for memory efficiency...")
        model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
    else:
        model_kwargs["torch_dtype"] = torch.float16

    base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, **model_kwargs)

    try:
        model = PeftModel.from_pretrained(base_model, MODEL_REPO)
    except TypeError as e:
        if "unexpected keyword argument" not in str(e):
            raise
        print(f"⚠️ Warning: {e}")
        print("Attempting to load with filtered config...")
        model = _load_adapter_with_filtered_config(base_model, e)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    # Generation below passes pad_token_id explicitly, so one must exist.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Model loaded successfully!")
    return model, tokenizer


# Load model at startup (global variable for caching)
print("🚀 Loading model at startup...")
model, tokenizer = load_model()
print("✅ Model loaded and ready!")

# ==============================================================================
# CODE GENERATION
# ==============================================================================


def generate_workflow(prompt, temperature=0.5, max_tokens=1024):
    """Generate n8n workflow code from a natural-language description.

    Args:
        prompt: The user's workflow description.
        temperature: Sampling temperature; values below 0.1 are raised to 0.1.
        max_tokens: Maximum number of new tokens to generate (coerced to int).

    Returns:
        tuple: ``(code, n8n_json_text, visualization)`` where ``n8n_json_text``
        is the converted workflow serialized with ``json.dumps(indent=2)`` and
        ``visualization`` is whatever ``create_visualization`` returns. For an
        empty prompt: ``("Please enter a workflow description.", None, None)``.
    """
    if not prompt or not prompt.strip():
        return "Please enter a workflow description.", None, None

    # ULTRA-OPTIMIZED: minimal tokens, maximal parameter guidance.
    # Doubled braces are literal braces in this f-string.
    formatted_prompt = f"""### System:
You are an expert n8n workflow generator.

## Core Rules:
1. Always start with a trigger node (webhook, scheduleTrigger, manualTrigger, formTrigger, emailTrigger)
2. **CRITICAL: Fill ALL required parameters with realistic values**
3. Connect nodes with .to() method
4. Use descriptive workflow names

## Parameter Rules (MUST FOLLOW):

**GitHub nodes:** {{"owner": "username", "repository": "repo-name", "resource": "issue", "operation": "getAll"}}
**Email (gmail/email):** {{"to": "user@example.com", "subject": "Subject", "message": "Content"}}
**Messaging (slack/telegram/discord):** {{"channel": "#general", "text": "Message"}}
**HTTP requests:** {{"url": "https://api.example.com/endpoint", "method": "GET"}}
**Schedule triggers:** {{"rule": {{"interval": [{{"field": "cronExpression", "expression": "0 9 * * *"}}]}}}}
**Database (postgres/mysql):** {{"operation": "select", "table": "table_name"}}
**Conditionals (if/switch):** {{"conditions": [{{"value1": "={{{{ $json.field }}}}", "operation": "equals", "value2": "value"}}]}}
**Error handlers:** {{"message": "Error: description"}}

## Common Nodes:
Triggers: webhook, scheduleTrigger, manualTrigger, formTrigger, emailTrigger
Actions: slack, gmail, telegram, discord, httpRequest, googleSheets, airtable, notion
Processing: if, switch, set, filter, merge, split, aggregate, sort, code, function
Utilities: wait, noOp, stopAndError

**IMPORTANT:** Infer parameter values from user request (e.g., "GitHub issues" → resource: "issue", operation: "getAll", "daily" → cron: "0 9 * * *")

Generate ONLY TypeScript DSL code in ```typescript blocks.

### Instruction:
{prompt}

### Response:
"""

    # Debug: print formatted prompt (first 500 chars)
    print(f"\n{'='*60}")
    print(f"User Prompt: {prompt}")
    print(f"Formatted Input (truncated):\n{formatted_prompt[:500]}...")
    print(f"{'='*60}\n")

    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
    input_length = inputs.input_ids.shape[1]
    # UI sliders may deliver floats; generate() expects an int budget.
    max_new_tokens = int(max_tokens)
    print(f"Input tokens: {input_length}, Max new tokens: {max_new_tokens}")

    # Sampling parameters matching training
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=max(temperature, 0.1),
            do_sample=True,
            top_p=0.95,
            top_k=50,
            repetition_penalty=1.1,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the newly generated tokens. The prompt itself contains
    # ```typescript and ### markers that must not feed into code extraction.
    generated_text = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)

    print(f"Generated text length: {len(generated_text)} chars")
    print(f"Generated text (first 500 chars):\n{generated_text[:500]}...\n")

    code = extract_code_from_instruction_format(generated_text)
    n8n_json = convert_to_n8n_json(code)
    visualization = create_visualization(n8n_json)

    return code, json.dumps(n8n_json, indent=2), visualization


def extract_code_from_instruction_format(text):
    """Extract the DSL code from a model response.

    Takes the text after the last ``### Response:`` marker (all of *text* if
    there is none) and cuts it at the first ``### Instruction:``,
    ``### System:`` or run of four newlines. It then returns the body of a
    fenced code block, preferring ```typescript / ```ts fences over any other
    fence. If no complete fence exists (for example, generation hit the token
    limit), it returns the text with fence markers removed.
    """
    response_part = text.split("### Response:")[-1].strip()

    # Drop anything the model generated past the end of its answer.
    for stop_marker in ("### Instruction:", "### System:", "\n\n\n\n"):
        if stop_marker in response_part:
            response_part = response_part.split(stop_marker)[0].strip()

    for fence_pattern in (
        r"```(?:typescript|ts)[ \t]*\r?\n(.*?)```",  # explicit TypeScript fence
        r"```[\w+-]*[ \t]*\r?\n(.*?)```",            # any language tag, or none
    ):
        code_match = re.search(fence_pattern, response_part, re.DOTALL)
        if code_match:
            return code_match.group(1).strip()

    return re.sub(r"```[\w+-]*", "", response_part).strip()


def extract_code(text):
    """Legacy extraction function - kept for compatibility."""
    return extract_code_from_instruction_format(text)


# ==============================================================================
# N8N JSON CONVERSION
# ==============================================================================

# Tokens that must be rewritten to turn a JavaScript object literal into JSON.
# Strings are matched first so that text inside them is never altered.
_JS_LITERAL_TOKEN_RE = re.compile(
    r'"(?:\\.|[^"\\])*"'            # double-quoted string: already JSON-shaped
    r"|'((?:\\.|[^'\\])*)'"         # group 1: body of a single-quoted string
    r"|([A-Za-z_$][\w$]*)(?=\s*:)"  # group 2: unquoted object key
    r"|,(?=\s*[}\]])",              # trailing comma before a closing } or ]
    re.DOTALL,
)


def _requote_js_string(body):
    """Re-escape a single-quoted JS string body for use inside JSON double quotes."""
    def _fix(m):
        escaped = m.group(1)
        if escaped is None:  # a bare " must be escaped inside double quotes
            return '\\"'
        if escaped == "'":   # \' is not a valid JSON escape; ' needs none
            return "'"
        return m.group(0)    # keep every other escape sequence as written

    return re.sub(r'\\(.)|"', _fix, body, flags=re.DOTALL)


def _js_token_to_json(match):
    """re.sub callback: rewrite one _JS_LITERAL_TOKEN_RE match as JSON text."""
    token = match.group(0)
    if token.startswith('"'):
        return token
    if match.group(1) is not None:
        return '"' + _requote_js_string(match.group(1)) + '"'
    if match.group(2) is not None:
        return '"' + match.group(2) + '"'
    return ""  # trailing comma


def parse_js_object(js_obj_str):
    """Convert a JavaScript object literal to a Python value (normally a dict).

    The input is parsed as JSON first. If that fails, the function quotes
    unquoted keys, converts single-quoted strings to double-quoted ones and
    drops trailing commas, without touching text inside strings, so URLs like
    ``'https://…'`` and times like ``'09:00'`` survive. It then tries again.
    If the text still isn't valid JSON, it prints a warning and returns ``{}``.
    """
    if not js_obj_str or js_obj_str.strip() == "{}":
        return {}

    try:
        return json.loads(js_obj_str)
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        pass

    try:
        return json.loads(_JS_LITERAL_TOKEN_RE.sub(_js_token_to_json, js_obj_str))
    except ValueError as e:
        print(f"Warning: Could not parse parameters '{js_obj_str}': {e}")
        return {}


def extract_balanced_braces(text, start_pos):
    """Return the ``{...}`` span starting at *start_pos*, honoring nested braces.

    Braces inside single- or double-quoted strings are ignored, and a
    backslash skips the next character. Returns None if *start_pos* is not a
    ``{`` or the braces never balance (e.g. truncated output).
    """
    if start_pos >= len(text) or text[start_pos] != '{':
        return None

    brace_count = 0
    in_string = False
    escape_next = False
    string_char = None

    for i in range(start_pos, len(text)):
        char = text[i]

        if escape_next:
            escape_next = False
            continue

        if char == '\\':
            escape_next = True
            continue

        if char in ('"', "'") and not in_string:
            in_string = True
            string_char = char
        elif char == string_char and in_string:
            in_string = False
            string_char = None
        elif char == '{' and not in_string:
            brace_count += 1
        elif char == '}' and not in_string:
            brace_count -= 1
            if brace_count == 0:
                return text[start_pos:i + 1]

    return None


def convert_to_n8n_json(typescript_code):
    """Convert the generated TypeScript DSL into an n8n workflow dict.

    Recognised DSL forms:
        new Workflow('Name')                          -> workflow name
        const|let|var x = workflow.add('type', {...}) -> node named ``x``
        a.to(b)                                       -> connection a -> b

    Nodes get sequential string ids, ``typeVersion`` 1, and are laid out left
    to right 300px apart on one row. Parameter objects that cannot be parsed
    become ``{}``. Connections to or from undeclared variables are ignored,
    and repeated identical connections are emitted once.

    Returns:
        dict with keys ``name``, ``nodes``, ``connections``, ``active`` and
        ``settings``.
    """
    workflow_name = "Generated Workflow"
    name_match = re.search(r"new Workflow\(['\"](.*?)['\"]\)", typescript_code)
    if name_match:
        workflow_name = name_match.group(1)

    node_pattern = re.compile(
        r"""\b(?:const|let|var)\s+(\w+)\s*=\s*workflow\.add\(['"]([^'"]+)['"]"""
    )
    comma_pattern = re.compile(r"\s*,\s*")

    nodes = []
    node_names = set()  # DSL variable names that became nodes (node name == var name)
    position_x, position_y = 300, 250

    for match in node_pattern.finditer(typescript_code):
        var_name, node_type = match.group(1), match.group(2)

        # Optional parameter object: workflow.add('type', { ... })
        params_str = "{}"
        comma = comma_pattern.match(typescript_code, match.end())
        if comma and typescript_code.startswith("{", comma.end()):
            params_str = extract_balanced_braces(typescript_code, comma.end()) or "{}"

        nodes.append({
            "id": str(len(nodes)),
            "name": var_name,
            "type": node_type,
            "typeVersion": 1,
            "position": [position_x, position_y],
            "parameters": parse_js_object(params_str),
        })
        node_names.add(var_name)
        position_x += 300

    connections = {}
    for source, target in re.findall(r"(\w+)\.to\((\w+)\)", typescript_code):
        if source not in node_names or target not in node_names:
            continue
        targets = connections.setdefault(source, {"main": [[]]})["main"][0]
        edge = {"node": target, "type": "main", "index": 0}
        if edge not in targets:
            targets.append(edge)

    return {
        "name": workflow_name,
        "nodes": nodes,
        "connections": connections,
        "active": False,
        "settings": {},
    }
============================================================================== # VISUALIZATION # ============================================================================== def create_visualization(n8n_json): """Create HTML visualization of the workflow""" nodes = n8n_json.get("nodes", []) connections = n8n_json.get("connections", {}) if not nodes: return "
{node_type}
{value_str}