| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig |
| from peft import PeftModel, PeftConfig |
| import gradio as gr |
|
|
# Hub repository that holds the fine-tuned PEFT adapter (and the tokenizer files).
model_repo = "nambn0321/LLM_model"

# The adapter config records which base checkpoint the adapter was trained on;
# it is read first so the matching base model can be loaded below.
peft_config = PeftConfig.from_pretrained(model_repo)

# 4-bit NF4 quantization with double quantization; matmuls are computed in
# float32.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float32,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

# Load the quantized base model named by the adapter config.
# NOTE(review): trust_remote_code=True executes Python shipped in the base
# model's repository -- confirm that repository is trusted.
base_model = AutoModelForCausalLM.from_pretrained(
    peft_config.base_model_name_or_path,
    device_map="auto",
    offload_folder="./offload",
    quantization_config=bnb_config,
    trust_remote_code=True,
)

# Wrap the base model with the fine-tuned adapter weights from the same repo.
model = PeftModel.from_pretrained(model=base_model, model_id=model_repo)

# The tokenizer is taken from the adapter repo (slow/Python implementation).
tokenizer = AutoTokenizer.from_pretrained(model_repo, use_fast=False)
|
|
def generate_response(prompt, max_tokens=128, temperature=0.7, top_p=0.9):
    """Generate a single-turn chat reply to ``prompt`` with the fine-tuned model.

    The prompt is wrapped as one user message, rendered through the
    tokenizer's chat template, and sampled from with the module-level
    ``model``. Only the newly generated continuation is returned; the
    prompt itself is not echoed back.

    Args:
        prompt: User text to send to the model.
        max_tokens: Upper bound on newly generated tokens. Coerced to ``int``
            because UI sliders may deliver a float.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.

    Returns:
        The generated reply as a string, or ``"Error: <message>"`` if anything
        raised while formatting, generating, or decoding.
    """
    try:
        chat = [{"role": "user", "content": prompt}]
        formatted_prompt = tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )

        # NOTE(review): if the chat template already emits a BOS token, this
        # call may add a second one -- confirm against the model's template.
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
        prompt_length = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=int(max_tokens),
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                eos_token_id=tokenizer.eos_token_id,
                # NOTE(review): disabling the KV cache makes decoding much
                # slower; presumably a workaround for this model -- confirm
                # before re-enabling.
                use_cache=False,
            )

        # The returned sequence starts with the prompt tokens. Drop them here
        # rather than relying on a textual marker, which disappears when
        # special tokens are skipped during decoding.
        generated_ids = outputs[0][prompt_length:]
        decoded = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        # Fallback: if the marker still shows up as plain text, keep only what
        # follows its last occurrence.
        if "<|assistant|>" in decoded:
            decoded = decoded.split("<|assistant|>")[-1].strip()

        return decoded
    except Exception as e:
        # UI boundary: surface the failure as text instead of crashing the app.
        return f"Error: {e}"
|
|
# Web UI: one prompt box plus three sliders, passed positionally to
# generate_response in this order (prompt, max_tokens, temperature, top_p).
iface = gr.Interface(
    fn=generate_response,
    inputs=[
        gr.Textbox(label="Prompt", lines=4),
        gr.Slider(minimum=16, maximum=512, value=128, step=16, label="Max Tokens"),
        gr.Slider(minimum=0.1, maximum=1.5, value=0.7, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top-p"),
    ],
    outputs="text",
    title="Fine-Tuned LLM",
    description="Interact with my fine-tuned LLM.",
)

# Start the Gradio server as soon as the script runs.
iface.launch()
|
|