import torch
from fastapi import FastAPI
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

app = FastAPI()

# Load the base model and tokenizer
base_model_name = "armaniii/mistral-argument-classification"
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
base_model = AutoModelForCausalLM.from_pretrained(base_model_name)

# Load the LoRA adapter on top of the base model
# (transformers has no LoraFromPretrained; LoRA adapters are loaded via peft)
lora_model_name = "mistral_lora"
lora_model = PeftModel.from_pretrained(base_model, lora_model_name)

# Merge the LoRA weights into the base model so inference
# no longer goes through the adapter indirection
merged_model = lora_model.merge_and_unload()

# Define your API endpoint
# (annotating the parameter as dict makes FastAPI read it from the JSON body)
@app.post("/generate")
def generate(request_body: dict):
    input_text = request_body["input_text"]
    inputs = tokenizer(input_text, return_tensors="pt").to(merged_model.device)
    # Use the merged model to generate output
    # (generation settings here are illustrative; tune them for your use case)
    with torch.no_grad():
        output_ids = merged_model.generate(**inputs, max_new_tokens=256)
    output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return {"output": output}
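
# A minimal client sketch for trying the endpoint locally. It assumes the
# code above is saved as server.py and served with:
#   uvicorn server:app --port 8000
# The module name, port, and prompt text are assumptions for illustration;
# the /generate path and the input_text field come from the endpoint above.
import requests

resp = requests.post(
    "http://localhost:8000/generate",
    json={"input_text": "Summarize the main argument of this passage."},
)
print(resp.json()["output"])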