import warnings

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Suppress the noisy quantization/loading warnings from transformers and bitsandbytes.
warnings.filterwarnings("ignore")
# Base model on the Hugging Face Hub and the local directory with the trained LoRA adapter.
model_name = "google/gemma-2b"
adapters_name = "./lora_weights"
# Load the base model in 4-bit NF4 (QLoRA-style) with bfloat16 compute for the matmuls.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
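# (Optional) setting bnb_4bit_use_double_quant=True in the config above would
# also quantize the quantization constants, saving a bit more memory.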
# Gemma is a gated model: replace with your own Hugging Face access token.
hf_token = "YOUR_TOKEN_HERE"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map={"": 0},  # place the entire model on GPU 0
    token=hf_token,
)
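# Attach the LoRA adapter, then fold it into the base weights so generation
# runs on a plain transformers model. Note that merging into a 4-bit base
# dequantizes the layers and can introduce small rounding differences; it
# also requires a reasonably recent peft release.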
model = PeftModel.from_pretrained(model, adapters_name)
model = model.merge_and_unload()
model.eval()
def generate_response(query: str) -> str:
    # Rebuild the instruction/output prompt format used during fine-tuning.
    text = f"Instruction: {query} \n\n Output: "
    inputs = tokenizer(text, return_tensors="pt").to("cuda:0")
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=120)
    # Keep special tokens so the answer can be cut at <eos>, then return only
    # the text between "Output:" and the first <eos> or follow-up "Instruction:".
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=False)
    answer = decoded.split("Output:")[1]
    return answer.split("<eos>")[0].split("Instruction:")[0]
if __name__ == "__main__":
    # Minimal REPL: type an instruction, print the model's answer (Ctrl+C to quit).
    while True:
        query = input("> ")
        result = generate_response(query)
        print(f"Result: {result}")
        print("=" * 100)