# CodeInsight/codeInsight/inference/code_assistant.py
# GitHub Actions
# Sync from GitHub Actions
# 7dab830
import torch
import os
import sys
from transformers import AutoModelForCausalLM, AutoTokenizer
from codeInsight.utils.config import load_config
from codeInsight.exception import ExceptionHandle
from codeInsight.logger import logging
class CodeAssistant:
    """Inference wrapper around a fine-tuned causal language model.

    The model and tokenizer are loaded from the repository named by
    ``paths.final_model_repo`` in the YAML config. :meth:`generate` wraps a
    user prompt in the chat markers found in the config's ``dataset`` section
    and returns the text the model produces in reply.
    """

    def __init__(self, config_path="config/model.yaml"):
        """Load the configuration, the model and its tokenizer.

        Args:
            config_path: Path handed to ``load_config``. The resulting mapping
                must provide a ``dataset`` section and
                ``paths.final_model_repo``.

        Raises:
            ExceptionHandle: Wraps any error raised while loading.
        """
        try:
            self.config = load_config(config_path)
            self.dataset_config = self.config['dataset']
            model_repo = self.config['paths']['final_model_repo']
            logging.info(f"Initializing CodeAssistant with model from: {model_repo}")
            self.model = AutoModelForCausalLM.from_pretrained(
                model_repo,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=False,
            )
            self.tokenizer = AutoTokenizer.from_pretrained(model_repo)
            # Inference only: switch to eval mode and turn `use_cache` on.
            self.model.eval()
            self.model.config.use_cache = True
            logging.info("CodeAssistant initialized successfully.")
        except Exception as e:
            logging.error("Failed to initialize CodeAssistant")
            raise ExceptionHandle(e, sys) from e

    def _format_prompt(self, prompt: str) -> str:
        """Wrap ``prompt`` in the configured chat markers.

        The layout is: system prompt, user marker, the prompt, end marker,
        a blank line, then the assistant marker the model continues from.
        """
        cfg = self.dataset_config
        return (
            f"{cfg['SYSTEM_PROMPT']}{cfg['USER_TOKEN']}{prompt}{cfg['END_TOKEN']}"
            f"\n\n{cfg['ASSISTANT_TOKEN']}"
        )

    # Backward-compatible alias for the original (misspelled) method name.
    _formet_prompt = _format_prompt

    def _resolve_eos_token_id(self, end_token: str):
        """Return the token id generation should stop on.

        Uses the id of ``end_token``; when the tokenizer does not know that
        token (it answers ``None`` or its unknown-token id), falls back to the
        tokenizer's own ``eos_token_id``.
        """
        token_id = self.tokenizer.convert_tokens_to_ids(end_token)
        unk_id = getattr(self.tokenizer, "unk_token_id", None)
        if token_id is None or (unk_id is not None and token_id == unk_id):
            return self.tokenizer.eos_token_id
        return token_id

    def generate(self, prompt: str, max_length: int = 512, temperature: float = 0.1, top_p: float = 0.80) -> str:
        """Generate the assistant's reply to ``prompt``.

        Args:
            prompt: User request; it is wrapped by :meth:`_format_prompt`.
            max_length: Upper bound on *newly generated* tokens (passed to the
                model as ``max_new_tokens``).
            temperature: Sampling temperature. A value ``<= 0`` selects greedy
                decoding instead of sampling.
            top_p: Nucleus-sampling threshold; only used when sampling.

        Returns:
            The generated text with the prompt removed, cut at the first
            assistant/end marker if one appears, and stripped of surrounding
            whitespace.

        Raises:
            ExceptionHandle: Wraps any error raised during generation.
        """
        try:
            end_token = self.dataset_config['END_TOKEN']
            assistant_token = self.dataset_config['ASSISTANT_TOKEN']

            input_text = self._format_prompt(prompt)
            inputs = self.tokenizer(
                input_text,
                return_tensors="pt",
            ).to(self.model.device)
            prompt_length = inputs["input_ids"].shape[-1]

            generation_kwargs = {
                "max_new_tokens": max_length,
                "eos_token_id": self._resolve_eos_token_id(end_token),
                "pad_token_id": self.tokenizer.eos_token_id,
            }
            if temperature > 0:
                generation_kwargs.update(
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                )
            else:
                # Sampling requires a strictly positive temperature; decode
                # greedily and leave the sampling-only arguments out.
                generation_kwargs["do_sample"] = False

            with torch.no_grad():
                outputs = self.model.generate(**inputs, **generation_kwargs)

            # The returned sequence starts with the prompt tokens. Drop them by
            # position rather than searching the decoded text for the assistant
            # marker: that marker disappears when it is a special token (because
            # of skip_special_tokens=True) and may also occur inside the user's
            # own prompt.
            completion_ids = outputs[0][prompt_length:]
            generated_code = self.tokenizer.decode(completion_ids, skip_special_tokens=True)

            # Keep only the first assistant turn if a marker survived decoding.
            for marker in (assistant_token, end_token):
                if marker:
                    generated_code = generated_code.split(marker)[0]
            generated_code = generated_code.strip()

            logging.info("Response generated successfully.")
            return generated_code
        except Exception as e:
            logging.error("Failed during code generation")
            raise ExceptionHandle(e, sys) from e