# app.py — Mistral argument-classification LoRA inference service
# (Hugging Face Space by armaniii; commit a7e3e36, 793 bytes — page header
# converted to a comment so the file is valid Python)
# --- Model loading -----------------------------------------------------------
# NOTE(review): the original imported `LoraFromPretrained` from transformers,
# which does not exist. LoRA adapters are loaded with `peft.PeftModel` and
# merged via `merge_and_unload()` (there is no `base_model.merge_lora(...)`).
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

# Base model and tokenizer.
# NOTE(review): trailing "/" removed — a Hub repo id must not end with a slash.
base_model_name = "armaniii/mistral-argument-classification"
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
base_model = AutoModelForCausalLM.from_pretrained(base_model_name)

# Attach the LoRA adapter weights to the base model.
# `lora_model_name` is presumably a local adapter directory — TODO confirm.
lora_model_name = "mistral_lora"
lora_model = PeftModel.from_pretrained(base_model, lora_model_name)

# Fold the adapter weights into the base weights so inference runs on a
# single plain transformers model (no separate adapter forward pass).
merged_model = lora_model.merge_and_unload()
# --- API endpoint ------------------------------------------------------------
# NOTE(review): `app` is never defined anywhere in this file — it must be a
# web-framework application object (e.g. `from fastapi import FastAPI;
# app = FastAPI()`) created before this decorator runs. TODO: confirm the
# intended framework and add the missing definition.
@app.post("/generate")
def generate(request_body):
    """Generate text for the prompt in ``request_body["input_text"]``.

    Parameters: request_body — mapping containing an "input_text" string.
    Returns: ``{"output": <generated text>}``.
    Raises: KeyError if "input_text" is missing from the request body.
    """
    input_text = request_body["input_text"]
    # The original left tokenization and decoding as `...` placeholders and
    # would have returned raw tensors; encode, generate, then decode to text.
    inputs = tokenizer(input_text, return_tensors="pt").to(merged_model.device)
    with torch.no_grad():  # inference only — no gradient bookkeeping
        output_ids = merged_model.generate(**inputs, max_new_tokens=256)
    output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return {"output": output}