BSJCode-1-Stable / README.md
BSAtlas's picture
Update README.md
3f2babe verified
|
raw
history blame
3.06 kB
metadata
license: mit

This is a stable version of BSJCode, a model that can both fix and optimize Java code.


How to use it:

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

# Load the model and tokenizer

# 4-bit NF4 quantization with double quantization and a float16 compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    llm_int8_enable_fp32_cpu_offload=True,
)

# device_map="auto" already places the quantized weights on the available
# device(s). Do not chain .to("cuda") here: moving a bitsandbytes-quantized
# model with .to() is redundant and can raise an error.
model = AutoModelForCausalLM.from_pretrained(
    "BSAtlas/BSJCode-1-Stable",
    quantization_config=bnb_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("BSAtlas/BSJCode-1-Stable")

def detect_and_fix_bugs(code_snippet, max_new_tokens=1024):
    """Ask the model for an optimized, corrected version of a Java snippet.

    Args:
        code_snippet: Java source code to analyze.
        max_new_tokens: Upper bound on the number of tokens generated after
            the prompt. Adjust based on the model's context window.

    Returns:
        The text generated after the prompt, with surrounding whitespace
        stripped.
    """
    # Prepare the prompt
    prompt = f"""You are an expert Java code optimizer and bug fixer. Analyze the following code, identify any bugs or inefficiencies, and provide an optimized and corrected version:

```java
{code_snippet}
```

Optimized and Fixed Code:"""

    # Tokenize the input and move the tensors to the model's device so that
    # generate() does not receive CPU tensors for a GPU-resident model.
    inputs = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=1024
    ).to(model.device)

    # Generate code. max_new_tokens bounds only the continuation; max_length
    # would also count the prompt and could leave no room for any output.
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            repetition_penalty=1.2,
        )

    # The returned sequence starts with the prompt tokens; drop them and
    # decode only the continuation. This avoids searching the decoded text
    # for the prompt marker, which breaks if the marker was truncated away
    # or also appears inside code_snippet.
    prompt_length = inputs["input_ids"].shape[1]
    new_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


# Sample Java input with a deliberate bug: increment() is marked in the source
# as not thread-safe, and two threads call it concurrently.
sample_code = """

public class ThreadSafetyExample { private int counter = 0;

public void increment() {
    // Not thread-safe method
    counter++;
}

public int getCounter() {
    return counter;
}

public static void main(String[] args) {
    ThreadSafetyExample example = new ThreadSafetyExample();

    Thread t1 = new Thread(() -> {
        for (int i = 0; i < 1000; i++) {
            example.increment();
        }
    });

    Thread t2 = new Thread(() -> {
        for (int i = 0; i < 1000; i++) {
            example.increment();
        }
    });

    t1.start();
    t2.start();

    try {
        t1.join();
        t2.join();
    } catch (InterruptedException e) {
        e.printStackTrace();
    }

    System.out.println("Final Counter: " + example.getCounter());
}

} """

# Run the model on the sample and show the proposed fix.
fixed_code = detect_and_fix_bugs(sample_code)
print(fixed_code)