import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Configuration
BASE_MODEL = "unsloth/Meta-Llama-3.1-8B-Instruct"
LORA_ADAPTERS = "Khalid02/fine_tuned_law_llama3_8b_lora-adapters"

# Global variables for model and tokenizer
model = None
tokenizer = None
def load_components():
    """Load the tokenizer and the merged LoRA model (only once)."""
    global model, tokenizer
    if model is None or tokenizer is None:
        print("Loading model and tokenizer...")
        try:
            # Load tokenizer from the base model
            tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

            # Optional 4-bit loading (needs bitsandbytes and
            # `from transformers import BitsAndBytesConfig`):
            # bnb_config = BitsAndBytesConfig(
            #     load_in_4bit=True,
            #     bnb_4bit_quant_type="nf4",
            #     bnb_4bit_compute_dtype=torch.float16,
            #     bnb_4bit_use_double_quant=False,
            # )

            # Load the base model with automatic device mapping
            base_model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL,
                # quantization_config=bnb_config,
                device_map="auto",
                torch_dtype="auto",
                trust_remote_code=True,
            )

            # Attach the LoRA adapters in inference mode
            model = PeftModel.from_pretrained(
                base_model,
                LORA_ADAPTERS,
                device_map="auto",
                is_trainable=False,  # important for inference
            )

            # Merge the adapters into the base weights and drop the PEFT wrapper
            model = model.merge_and_unload()
            print("Model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
    return model, tokenizer
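
# Note (assumption, not in the original script): Llama 3.1 tokenizers commonly
# ship without a pad token. If generate() warns about padding, a common fix is
#     tokenizer.pad_token = tokenizer.eos_token
# placed immediately after the tokenizer is loaded above.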
def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Handle chat responses using the loaded model."""
    global model, tokenizer
    try:
        # Rebuild the conversation history as chat messages
        messages = [{"role": "system", "content": system_message}]
        for user_input, bot_response in history:
            if user_input:
                messages.append({"role": "user", "content": user_input})
            if bot_response:
                messages.append({"role": "assistant", "content": bot_response})
        messages.append({"role": "user", "content": message})

        # Format the input using the model's chat template
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Generate the response
        outputs = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,  # silences the missing-mask warning
            max_new_tokens=int(max_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            do_sample=temperature > 0.1,  # greedy decoding at very low temperatures
            use_cache=True,
        )

        # Decode only the newly generated tokens
        response = tokenizer.decode(
            outputs[0][inputs.input_ids.shape[1]:],
            skip_special_tokens=True,
        )
        return response
    except Exception as e:
        return f"Error generating response: {e}"
def create_interface():
    """Create the Gradio interface."""
    with gr.Blocks() as demo:
        gr.Markdown("# Fine-tuned Llama 3.1 Legal Assistant")
        with gr.Row():
            reload_btn = gr.Button("Reload Model")
            status = gr.Textbox(label="Load Status", interactive=False)
        gr.ChatInterface(
            respond,
            additional_inputs=[
                gr.Textbox(
                    value="You are a legal expert chatbot. Provide accurate and helpful legal information.",
                    label="System message",
                    lines=2,
                ),
                gr.Slider(1, 2048, value=512, step=1, label="Max new tokens"),
                gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
                gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
            ],
        )

        def reload_model():
            global model, tokenizer
            try:
                model, tokenizer = None, None
                load_components()
                return "Model reloaded successfully!"
            except Exception as e:
                return f"Reload failed: {e}"

        reload_btn.click(reload_model, outputs=status)
    return demo
if __name__ == "__main__":
    # Initial model load
    load_components()
    # Create and launch the interface
    demo = create_interface()
    demo.launch()
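
# Note: a Space running this script also needs a requirements.txt next to it.
# The real file is not shown here; a minimal sketch of plausible contents
# (an assumption, not the Space's actual dependency list) would be:
#
#     torch
#     transformers
#     peft
#     accelerate
#     gradio
#     bitsandbytes  # only if the 4-bit path in load_components() is re-enabled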