from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM
)
import difflib
import re

# Models and their types
MODEL_CONFIGS = {
    "Salesforce/codet5-base": "seq2seq",        # CodeT5
    "EleutherAI/gpt-neo-1.3B": "causal",        # GPT-Neo
    "microsoft/CodeGPT-small-py": "causal",     # CodeGPT-small (Python)
}

# Load tokenizers and models
tokenizers, models = {}, {}

for name, mtype in MODEL_CONFIGS.items():
    tokenizers[name] = AutoTokenizer.from_pretrained(name)
    if mtype == "seq2seq":
        models[name] = AutoModelForSeq2SeqLM.from_pretrained(name)
    else:
        models[name] = AutoModelForCausalLM.from_pretrained(name)
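    # GPT-style tokenizers (GPT-Neo, CodeGPT) ship without a pad token, which
    # makes generate() warn; reusing EOS as the pad token is a safe default here.
    if tokenizers[name].pad_token is None:
        tokenizers[name].pad_token = tokenizers[name].eos_token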

# Rule-based fixes: naive substring swaps, applied before any model runs.
SECURE_REPLACEMENTS = {
    "hashlib.md5": ("hashlib.sha256", "MD5 is weak, replaced with SHA-256."),
    "hashlib.sha1": ("hashlib.sha256", "SHA1 is weak, replaced with SHA-256."),
    "eval(": ("ast.literal_eval(", "Unsafe eval removed, replaced with safe literal_eval."),
    # Note: this swap leaves the call site syntactically broken on purpose, so
    # the code cannot run unreviewed; there is no safe drop-in for pickle.load.
    "pickle.load(": ("# pickle.load removed", "pickle.load is unsafe, consider json/safe loaders."),
}

def rule_based_patch(code: str):
    """Apply the deterministic security swaps above; return (patched_code, explanations)."""
    explanations = []
    patched = code
    for bad, (good, reason) in SECURE_REPLACEMENTS.items():
        if bad in patched:
            patched = patched.replace(bad, good)
            explanations.append({"change": f"{bad} -> {good}", "reason": reason})
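    # If the eval -> literal_eval swap fired, the patched code now needs the
    # ast module; inject the import when it is missing (assumes a top-level
    # "import ast" is an acceptable placement).
    if "ast.literal_eval(" in patched and "import ast" not in patched:
        patched = "import ast\n" + patched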
    return patched, explanations

def preserve_structure(original: str, enhanced: str):
    """Ensure imports and function signatures remain if model drops them."""
    final_code = enhanced
    original_imports = [l for l in original.splitlines()
                        if l.strip().startswith(("import ", "from "))]
    for imp in original_imports:
        if imp not in final_code:
            final_code = imp + "\n" + final_code
    original_defs = [l for l in original.splitlines() if l.strip().startswith("def ")]
    for d in original_defs:
        if d.split("(")[0] not in final_code:
            final_code = d + "\n    pass  # [!] Function body missing, please review\n" + final_code
    return final_code

def create_diff(original: str, enhanced: str):
    """Return structured diff for frontend rendering."""
    diff_lines = difflib.unified_diff(
        original.splitlines(), enhanced.splitlines(),
        fromfile="Original", tofile="Enhanced", lineterm=""
    )
    formatted = []
    for line in diff_lines:
        if line.startswith("+") and not line.startswith("+++"):
            formatted.append({"type": "add", "content": line[1:]})
        elif line.startswith("-") and not line.startswith("---"):
            formatted.append({"type": "remove", "content": line[1:]})
        elif not line.startswith("@@"):
            formatted.append({"type": "context", "content": line})
    return formatted
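
# Illustrative shape of create_diff output (example entries, not real data):
#   {"type": "remove", "content": "hashed = hashlib.md5(data)"}
#   {"type": "add",    "content": "hashed = hashlib.sha256(data)"}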

def postprocess_code(code: str):
    """Strip stray wrapping triple quotes, expand tabs, trim trailing whitespace."""
    code = re.sub(r'^"""|"""$', '', code.strip())
    lines = code.splitlines()
    return "\n".join(l.replace("\t", "    ").rstrip() for l in lines)

def run_model(model_name, code, language):
    """Generate an enhanced version of `code` with one model from the ensemble."""
    tokenizer = tokenizers[model_name]
    model = models[model_name]
    mtype = MODEL_CONFIGS[model_name]

    prompt = f"fix {language} code: {code}"

    if mtype == "seq2seq":
        inputs = tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=512)
        outputs = model.generate(inputs, max_length=512, num_beams=4, early_stopping=True)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    else:
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,  # greedy decoding; temperature/top_p only apply when sampling
            pad_token_id=tokenizer.pad_token_id,
        )
        # Causal LMs echo the prompt, so decode only the newly generated tokens
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True)

def enhance_code(code: str, language: str):
    try:
        patched_code, rule_explanations = rule_based_patch(code)

        candidates = []
        for m in MODEL_CONFIGS.keys():
            try:
                enhanced = run_model(m, patched_code, language)
                enhanced = postprocess_code(enhanced)
                enhanced = preserve_structure(code, enhanced)
                candidates.append({"model": m, "code": enhanced})
            except Exception as e:
                candidates.append({"model": m, "code": f"# [!] Failed: {str(e)}"})

        # Prefer successful outputs; length is a crude proxy for completeness.
        valid = [c for c in candidates if not c["code"].startswith("# [!] Failed")]
        best = max(valid or candidates, key=lambda c: len(c["code"]))
        diff = create_diff(code, best["code"])

        explanations = rule_explanations + [
            {"change": "Model improvements", "reason": "Best candidate chosen among ensemble"}
        ]

        return {
            "enhanced_code": best["code"],
            "diff": diff,
            "candidates": candidates[:3],
            "explanations": explanations
        }

    except Exception as e:
        fallback = code + f"\n# [!] Enhancer crashed: {str(e)}"
        return {
            "enhanced_code": fallback,
            "diff": create_diff(code, fallback),
            "candidates": [],
            "explanations": [{"change": "Error", "reason": str(e)}]
        }
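
# Minimal usage sketch (assumes the three models above can be downloaded and
# fit in memory); the sample snippet and its output are illustrative only.
if __name__ == "__main__":
    sample = (
        "import hashlib\n"
        "def fingerprint(data):\n"
        "    return hashlib.md5(data).hexdigest()\n"
    )
    result = enhance_code(sample, "python")
    print(result["enhanced_code"])
    for note in result["explanations"]:
        print(f"- {note['change']}: {note['reason']}")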