File size: 3,070 Bytes
3f2babe
 
 
 
 
 
eb84831
3f2babe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
---
license: mit
---
This is a stable version of BSJCode, a model capable of both fixing and optimizing Java code.

------------------
## How to use it:

```python

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

# Load the model and tokenizer
# 4-bit NF4 quantization with float16 compute and double quantization;
# fp32 CPU offload is enabled so layers that do not fit on the GPU can be
# kept on the CPU instead of failing to load.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    llm_int8_enable_fp32_cpu_offload=True,
)

# device_map="auto" already places the quantized weights on the available
# devices. Do NOT chain `.to("cuda")` here: moving a bitsandbytes-quantized
# model after loading is rejected by transformers in many releases and, where
# allowed, would undo the automatic placement / CPU offload requested above.
model = AutoModelForCausalLM.from_pretrained(
    "BSAtlas/BSJCode-1-Stable",
    quantization_config=bnb_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("BSAtlas/BSJCode-1-Stable")


def detect_and_fix_bugs(code_snippet):
    """Ask the model for an optimized, bug-fixed version of a Java snippet.

    Args:
        code_snippet: Java source code to analyze, as a string.

    Returns:
        The text the model generated after the prompt, with surrounding
        whitespace stripped.
    """
    # Prepare the prompt
    prompt = f"""You are an expert Java code optimizer and bug fixer. 
    Analyze the following code, identify any bugs or inefficiencies, 
    and provide an optimized and corrected version:

    ```java
    {code_snippet}
    ```

    Optimized and Fixed Code:"""

    # Tokenize the input and move it to the device the model lives on;
    # generate() fails (or silently runs on the wrong device) when the input
    # tensors and the model weights are on different devices.
    inputs = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=1024
    ).to(model.device)
    prompt_length = inputs["input_ids"].shape[1]

    # Generate code
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            # Budget for *new* tokens only. `max_length` also counts the
            # prompt, so a long snippet would leave no room for an answer.
            max_new_tokens=1024,  # Adjust based on your model's training
            num_return_sequences=1,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            repetition_penalty=1.2,
        )

    # Decode only the tokens produced after the prompt. Slicing by token
    # position is robust where searching the decoded text for the
    # "Optimized and Fixed Code:" marker is not (marker present inside the
    # snippet itself, or cut off by prompt truncation).
    generated_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()


    sample_code = """
public class ThreadSafetyExample {
    private int counter = 0;

    public void increment() {
        // Not thread-safe method
        counter++;
    }

    public int getCounter() {
        return counter;
    }

    public static void main(String[] args) {
        ThreadSafetyExample example = new ThreadSafetyExample();

        Thread t1 = new Thread(() -> {
            for (int i = 0; i < 1000; i++) {
                example.increment();
            }
        });

        Thread t2 = new Thread(() -> {
            for (int i = 0; i < 1000; i++) {
                example.increment();
            }
        });

        t1.start();
        t2.start();

        try {
            t1.join();
            t2.join();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }

        System.out.println("Final Counter: " + example.getCounter());
    }
}
"""
fixed_code = detect_and_fix_bugs(sample_code)
print(fixed_code)
```
--------------------------------------------------------------------------