| """ |
| Minimal reproducible inference script for sweep-next-edit-v2-7B. |
| |
| This model predicts the next edit a developer will make given: |
| - the current file contents |
| - recent changes (diffs) |
| - the cursor position |
| - (optional) retrieval chunks from other files |
| |
| Usage: |
| python inference.py |
| |
| Requires: transformers, torch, accelerate |
| pip install transformers torch accelerate |
| """ |
|
|
| import torch |
| from dataclasses import dataclass |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
# Hugging Face Hub id of the model loaded by main().
MODEL_ID = "sweepai/sweep-next-edit-v2-7B"


# Prompt layout. Sections are separated by <|file_sep|> markers:
#   1. the file being edited ({initial_file}), immediately followed by any
#      retrieval chunks from other files ({retrieval_results}),
#   2. recent changes ({recent_changes}, entries formatted with DIFF_FORMAT),
#   3. original/<path>:<start>:<end>  - {prev_section},
#   4. current/<path>:<start>:<end>   - {code_block} (the caller passes the
#      block with a <|cursor|> marker inserted),
#   5. updated/<path>:<start>:<end>   - the section the model writes. The
#      template ends with {prefill}, so generation continues right after it.
PROMPT_TEMPLATE = """<|file_sep|>{file_path}
{initial_file}{retrieval_results}
{recent_changes}
<|file_sep|>original/{file_path}:{start_line}:{end_line}
{prev_section}
<|file_sep|>current/{file_path}:{start_line}:{end_line}
{code_block}
<|file_sep|>updated/{file_path}:{start_line}:{end_line}
{prefill}"""


# One recent-change entry: a header with the file path and line range,
# then the code before ("original:") and after ("updated:") the edit.
DIFF_FORMAT = """<|file_sep|>{file_path}:{start_line}:{end_line}
original:
{old_code}
updated:
{new_code}"""


# Generation stops at these markers and the decoded text is cut before them;
# a <|file_sep|> would start a new prompt section.
STOP_TOKENS = ["<|endoftext|>", "<|file_sep|>"]
# Upper bound on the number of tokens generated per request.
MAX_NEW_TOKENS = 1024
|
|
|
|
@dataclass
class FileChunk:
    """
    Snippet of another file supplied to the model as extra context.

    ``to_string`` renders it as a standalone ``<|file_sep|>`` section.
    """

    file_path: str  # path printed right after the <|file_sep|> marker
    content: str  # snippet text, emitted verbatim

    def to_string(self) -> str:
        """Return ``<|file_sep|>{file_path}\\n{content}\\n``."""
        header = "<|file_sep|>{}".format(self.file_path)
        return "{}\n{}\n".format(header, self.content)
|
|
|
|
def compute_prefill(
    code_block: str,
    relative_cursor: int,
    changes_above_cursor: bool = False,
) -> str:
    """
    Compute the prefill: the leading part of the updated code block that is
    appended to the prompt so the model only continues from the edit point.

    The model's job is to produce the whole "updated" code block, but most of
    it is unchanged. Prefilling the unchanged prefix means the model only has
    to generate the remainder.

    Only complete lines before the cursor line are ever prefilled. If the
    cursor is on the first line of ``code_block``, the prefill is empty in
    both modes, so the model is never forced to keep part of the cursor line.

    changes_above_cursor=True (last action was an insertion):
        Lines above the cursor may still need edits. Only the first line of
        the block is prefilled, plus any empty lines directly after it that
        lie before the cursor.
        Example: code_block = "a\\n\\nb\\nc\\n", cursor on "c"
        -> prefill = "a\\n\\n".

    changes_above_cursor=False (navigation, deletion, etc.):
        Lines above the cursor are assumed stable. Everything up to the last
        newline before the cursor is prefilled, so the model starts at the
        cursor line.
        Example: same code_block, cursor on "c" -> prefill = "a\\n\\nb\\n".

    Args:
        code_block: The editable block (without the ``<|cursor|>`` marker).
        relative_cursor: Cursor offset within ``code_block``.
        changes_above_cursor: Whether the user's last action was an insertion.

    Returns:
        The prefill string; always a prefix of ``code_block``.
    """
    prefix_before_cursor = code_block[:relative_cursor]
    if "\n" not in prefix_before_cursor:
        # Cursor is on the block's first line: there is no complete line to
        # prefill, and a partial cursor line would pin text the model may
        # need to rewrite.
        return ""

    if not changes_above_cursor:
        # Prefill every complete line above the cursor line.
        return prefix_before_cursor[: prefix_before_cursor.rfind("\n") + 1]

    # Insertion case: keep only the first line of the block ...
    num_lines_above = 1
    prefilled_lines = prefix_before_cursor.splitlines(True)
    before_split = "".join(prefilled_lines[:num_lines_above])
    after_split = "".join(prefilled_lines[num_lines_above:])

    # ... plus the run of empty lines that immediately follows it.
    blank_run = len(after_split) - len(after_split.lstrip("\n"))
    return before_split + "\n" * blank_run
|
|
|
|
def is_pure_insertion_above_cursor(
    code_block: str, completion: str, relative_cursor: int
) -> bool:
    """
    Detect completions that only insert new text above the cursor line.

    Such completions leave the cursor line and everything below it untouched
    and merely add text above it. They are low-value predictions (the model
    is guessing new code rather than fixing existing code), so callers reject
    them.

    Args:
        code_block: The editable block shown to the model (without the
            ``<|cursor|>`` marker).
        completion: The model's full predicted replacement for ``code_block``.
        relative_cursor: Cursor offset within ``code_block``.

    Returns:
        True only when ``completion`` equals ``prefix + inserted + cursor_line
        + suffix`` with a non-empty ``inserted``. Here ``prefix`` and
        ``suffix`` are the block's lines above and below the cursor line.
        False when the block is empty, the completion matches the block
        (ignoring surrounding whitespace), or the cursor line is blank.
    """
    code_block_lines = code_block.splitlines(True)
    if not code_block_lines:
        return False

    # Find the line containing the cursor. A cursor at column 0 belongs to
    # the line that starts there, not to the line before it.
    cursor_index = len(code_block_lines) - 1
    line_start = 0
    for index, line in enumerate(code_block_lines):
        if line_start + len(line) > relative_cursor:
            cursor_index = index
            break
        line_start += len(line)
    cursor_line = code_block_lines[cursor_index]

    if code_block.strip() == completion.strip():
        return False
    if not cursor_line.strip():
        return False

    prefix = "".join(code_block_lines[:cursor_index])
    suffix = "".join(code_block_lines[cursor_index + 1:])

    # The length check ensures prefix and cursor_line + suffix do not overlap
    # inside the completion, i.e. something was actually inserted between
    # them (an overlap would mean a deletion).
    return (
        len(completion) > len(code_block)
        and completion.startswith(prefix)
        and completion.endswith(cursor_line + suffix)
    )
|
|
|
|
def _find_line_index(lines: list[str], offset: int) -> int:
    """Return the index of the line in ``lines`` containing character ``offset``.

    ``lines`` must keep their line endings (as from ``splitlines(True)``). An
    offset at or past the end of the text maps to the last line, and an empty
    list maps to 0.
    """
    line_start = 0
    for index, line in enumerate(lines):
        if line_start + len(line) > offset:
            return index
        line_start += len(line)
    return max(0, len(lines) - 1)


def build_prompt(
    file_path: str,
    file_contents: str,
    cursor_position: int,
    recent_changes: str = "",
    retrieval_chunks: list[FileChunk] | None = None,
    file_chunks: list[FileChunk] | None = None,
    changes_above_cursor: bool = False,
    num_lines_before: int = 10,
    num_lines_after: int = 10,
    num_context_lines: int = 150,
) -> tuple[str, str, int, int]:
    """
    Build the model prompt from file contents and cursor position.

    The editable code block is the cursor line plus up to ``num_lines_before``
    lines above it and ``num_lines_after`` lines below it.

    Args:
        file_path: Path of the file being edited; used in every section header.
        file_contents: Full contents of the file after the user's latest edit.
        cursor_position: Character offset of the cursor in ``file_contents``.
            Must satisfy ``0 <= cursor_position <= len(file_contents)``.
        recent_changes: Formatted diff string of recent changes (use DIFF_FORMAT).
        retrieval_chunks: Cross-file context chunks (e.g. related functions from
            other files). They are inserted directly after the file section and
            BEFORE ``recent_changes``, following PROMPT_TEMPLATE's layout.
        file_chunks: Additional file context chunks, prepended to the prompt.
        changes_above_cursor: Whether the user's last action was an insertion.
            Controls the prefill strategy (see compute_prefill).
        num_lines_before: Lines to include above the cursor line in the block.
        num_lines_after: Lines to include below the cursor line in the block.
        num_context_lines: Size of the leading file section. It contains up to
            this many lines before the cursor line, the cursor line itself,
            and up to ``num_context_lines - 1`` lines after it.

    Returns:
        (formatted_prompt, code_block, block_start_index, relative_cursor).
        ``block_start_index`` is the character offset of ``code_block`` in
        ``file_contents``; ``relative_cursor`` is the cursor offset within
        ``code_block``.

    Raises:
        ValueError: If ``cursor_position`` lies outside ``file_contents``.
    """
    if not 0 <= cursor_position <= len(file_contents):
        raise ValueError(
            f"cursor_position {cursor_position} is outside file_contents "
            f"(length {len(file_contents)})"
        )

    lines = file_contents.splitlines(True)
    cursor_line = _find_line_index(lines, cursor_position)

    # Editable block: a window of whole lines around the cursor line.
    block_start = max(0, cursor_line - num_lines_before)
    block_end = min(len(lines), cursor_line + num_lines_after + 1)
    code_block = "".join(lines[block_start:block_end])
    block_start_index = sum(len(line) for line in lines[:block_start])
    relative_cursor = cursor_position - block_start_index

    code_block_with_cursor = (
        code_block[:relative_cursor] + "<|cursor|>" + code_block[relative_cursor:]
    )

    # NOTE(review): the original/ section receives the same text as current/
    # (minus the cursor marker). If the model expects the pre-edit version of
    # this block here, it would have to be supplied by the caller — confirm.
    prev_section = code_block

    prefill = compute_prefill(code_block, relative_cursor, changes_above_cursor)

    # Leading file section: a larger window around the cursor line.
    context_start = max(0, cursor_line - num_context_lines)
    context_end = min(len(lines), cursor_line + num_context_lines)
    initial_file = "".join(lines[context_start:context_end])

    retrieval_results = "".join(
        f"\n{chunk.to_string()}" for chunk in retrieval_chunks or ()
    )

    # Section headers use 1-based, inclusive line numbers.
    start_line = block_start + 1
    end_line = block_end

    formatted = PROMPT_TEMPLATE.format(
        file_path=file_path,
        initial_file=initial_file,
        retrieval_results=retrieval_results,
        recent_changes=recent_changes,
        prev_section=prev_section,
        code_block=code_block_with_cursor,
        start_line=start_line,
        end_line=end_line,
        prefill=prefill,
    )

    if file_chunks:
        formatted = "".join(chunk.to_string() for chunk in file_chunks) + formatted

    return formatted, code_block, block_start_index, relative_cursor
|
|
|
|
def generate(model, tokenizer, prompt: str, device: str = "cuda") -> str:
    """
    Run greedy decoding on ``prompt`` and return only the newly generated text.

    The prompt produced by build_prompt ends with the prefill. The returned
    string is therefore the continuation after that prefill, not the whole
    updated block.

    Args:
        model: A causal LM exposing ``generate`` (e.g. from
            ``AutoModelForCausalLM.from_pretrained``).
        tokenizer: The matching tokenizer.
        prompt: The fully formatted prompt.
        device: Device the input tensors are moved to.

    Returns:
        The decoded continuation, cut before the first stop marker
        (STOP_TOKENS or the tokenizer's EOS token).
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # get_vocab() builds a fresh dict on each call, so fetch it once.
    vocab = tokenizer.get_vocab()
    stop_ids = {
        tokenizer.convert_tokens_to_ids(t) for t in STOP_TOKENS if t in vocab
    }
    # Some tokenizers define no EOS token; a None id must not reach generate().
    if tokenizer.eos_token_id is not None:
        stop_ids.add(tokenizer.eos_token_id)
    # None lets generate() fall back to the model's generation config.
    eos_ids = sorted(stop_ids) or None

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            eos_token_id=eos_ids,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Drop the echoed prompt tokens and keep only the new ones.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    # Keep special tokens in the text so the stop markers can be located.
    completion = tokenizer.decode(new_tokens, skip_special_tokens=False)

    # Cut at the earliest stop marker, including the tokenizer's own EOS
    # string in case it is not one of STOP_TOKENS.
    stop_strings = list(STOP_TOKENS)
    if tokenizer.eos_token:
        stop_strings.append(tokenizer.eos_token)
    cut = len(completion)
    for stop in stop_strings:
        index = completion.find(stop)
        if index != -1:
            cut = min(cut, index)
    return completion[:cut]
|
|
|
|
def main():
    """
    Demo: simulate renaming ``fibonacci`` to ``fib`` and ask the model to
    predict the follow-up edit at the call site the user has not updated yet.
    """
    import difflib  # only the demo output needs a unified diff

    file_path = "example.py"

    # The file as it looked before the user's edit.
    file_contents = """\
def fibonacci(n):
    if n <= 0:
        return 0
    elif n == 1:
        return 1
    else:
        return fibonacci(n - 1) + fibonacci(n - 2)


def main():
    for i in range(10):
        print(fibonacci(i))
"""

    # The user renamed the definition and the recursive calls, but not the
    # call site inside main().
    edited_contents = file_contents.replace(
        "return fibonacci(n - 1) + fibonacci(n - 2)",
        "return fib(n - 1) + fib(n - 2)",
    ).replace(
        "def fibonacci(n):",
        "def fib(n):",
    )

    # Cursor at the start of the stale call site.
    cursor_line_text = "        print(fibonacci(i))"
    cursor_position = edited_contents.index(cursor_line_text)

    # Describe the recent edit as one hunk. Its text is taken from both file
    # versions so the content matches the line range in the hunk header.
    hunk_start_line, hunk_end_line = 1, 7
    hunk = slice(hunk_start_line - 1, hunk_end_line)
    recent_changes = DIFF_FORMAT.format(
        file_path=file_path,
        start_line=hunk_start_line,
        end_line=hunk_end_line,
        old_code="".join(file_contents.splitlines(True)[hunk]).rstrip("\n"),
        new_code="".join(edited_contents.splitlines(True)[hunk]).rstrip("\n"),
    )

    # Cross-file context: a related helper from another file.
    retrieval_chunks = [
        FileChunk(
            file_path="utils.py",
            content=(
                "def fib_memo(n, memo={}):\n"
                "    if n in memo:\n"
                "        return memo[n]\n"
                "    memo[n] = fib_memo(n-1) + fib_memo(n-2)\n"
                "    return memo[n]"
            ),
        )
    ]

    # The cursor was moved, not placed by an insertion, so every complete line
    # above it is prefilled.
    changes_above_cursor = False
    prompt, code_block, _block_start, relative_cursor = build_prompt(
        file_path=file_path,
        file_contents=edited_contents,
        cursor_position=cursor_position,
        recent_changes=recent_changes,
        retrieval_chunks=retrieval_chunks,
        changes_above_cursor=changes_above_cursor,
    )

    print("=" * 60)
    print("PROMPT")
    print("=" * 60)
    print(prompt)
    print()

    # Prefer CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(f"Loading model {MODEL_ID} on {device}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        dtype=torch.bfloat16,
        device_map=device,
        trust_remote_code=True,
    )

    print("Running inference...")
    completion = generate(model, tokenizer, prompt, device=device)

    # The prompt ends with the prefill, so generate() returns only the text
    # after it. Rebuild the full predicted block before comparing it with
    # code_block.
    prefill = compute_prefill(code_block, relative_cursor, changes_above_cursor)
    updated_block = prefill + completion

    if is_pure_insertion_above_cursor(code_block, updated_block, relative_cursor):
        print("Rejected: model only inserted above cursor without editing cursor line.")
        return

    print("=" * 60)
    print("MODEL OUTPUT (predicted updated code block)")
    print("=" * 60)
    print(updated_block)
    print()

    print("=" * 60)
    print("DIFF")
    print("=" * 60)
    diff = "\n".join(
        difflib.unified_diff(
            code_block.splitlines(),
            updated_block.splitlines(),
            fromfile=f"current/{file_path}",
            tofile=f"updated/{file_path}",
            lineterm="",
        )
    )
    print(diff if diff else "(no changes)")
|
|
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()
|
|