Spaces:
Sleeping
Sleeping
| import torch | |
| import os | |
| import sys | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from codeInsight.utils.config import load_config | |
| from codeInsight.exception import ExceptionHandle | |
| from codeInsight.logger import logging | |
| class CodeAssistant: | |
| def __init__(self, config_path="config/model.yaml"): | |
| try: | |
| self.config = load_config(config_path) | |
| self.dataset_config = self.config['dataset'] | |
| model_repo = self.config['paths']['final_model_repo'] | |
| logging.info(f"Initializing CodeAssistant with model from: {model_repo}") | |
| self.model = AutoModelForCausalLM.from_pretrained( | |
| model_repo, | |
| device_map="auto", | |
| torch_dtype=torch.bfloat16, | |
| trust_remote_code=False | |
| ) | |
| self.tokenizer = AutoTokenizer.from_pretrained( | |
| model_repo | |
| ) | |
| self.model.eval() | |
| self.model.config.use_cache = True | |
| logging.info("CodeAssistant initialized successfully.") | |
| except Exception as e: | |
| logging.error("Failed to initialize CodeAssistant") | |
| raise ExceptionHandle(e, sys) | |
| def _formet_prompt(self, prompt : str) -> str: | |
| return f"{self.dataset_config['SYSTEM_PROMPT']}{self.dataset_config['USER_TOKEN']}{prompt}{self.dataset_config['END_TOKEN']}\n\n{self.dataset_config['ASSISTANT_TOKEN']}" | |
| def generate(self, prompt : str, max_length : int = 512, temperature: float = 0.1, top_p : float =0.80) -> str: | |
| try: | |
| input_text = self._formet_prompt(prompt) | |
| inputs = self.tokenizer( | |
| input_text, | |
| return_tensors="pt", | |
| ).to(self.model.device) | |
| with torch.no_grad(): | |
| outputs = self.model.generate( | |
| **inputs, | |
| max_new_tokens=max_length, | |
| temperature=temperature, | |
| top_p=top_p, | |
| do_sample=True, | |
| eos_token_id=self.tokenizer.convert_tokens_to_ids(self.dataset_config['END_TOKEN']), | |
| pad_token_id=self.tokenizer.eos_token_id | |
| ) | |
| generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True) | |
| if self.dataset_config['ASSISTANT_TOKEN'] in generated_text: | |
| generated_code = generated_text.split(self.dataset_config['ASSISTANT_TOKEN'])[1].strip() | |
| if self.dataset_config['END_TOKEN'] in generated_code: | |
| generated_code = generated_code.split(self.dataset_config['END_TOKEN'])[0].strip() | |
| else: | |
| generated_code = generated_text | |
| logging.info("Response generated successfully.") | |
| return generated_code | |
| except Exception as e: | |
| logging.error("Failed during code generation") | |
| raise ExceptionHandle(e, sys) |