|
|
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Base model checkpoint on the Hugging Face Hub.
# Fixed: hub repo ids must not end with "/" — the trailing slash breaks
# hub resolution in from_pretrained.
base_model_name = "armaniii/mistral-argument-classification"

# Tokenizer and base model come from the same checkpoint so vocabularies match.
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
base_model = AutoModelForCausalLM.from_pretrained(base_model_name)

# Local path (or hub id) of the LoRA adapter to apply on top of the base model.
lora_model_name = "mistral_lora"

# Fixed: `LoraFromPretrained` does not exist in transformers (the original
# import would raise ImportError). Adapters are attached through the
# PeftAdapterMixin that AutoModelForCausalLM inherits; load_adapter places
# the adapter weights on the model's device automatically.
base_model.load_adapter(lora_model_name)

# Fixed: `base_model.merge_lora(...)` is not a real API. With the adapter
# active, the base model itself produces the adapted outputs; keep the old
# name so the /generate handler below is unchanged.
# NOTE(review): load_adapter keeps the LoRA weights separate rather than
# folding them into the base weights. If a truly merged model is required,
# use peft: PeftModel.from_pretrained(base_model, lora_model_name).merge_and_unload()
merged_model = base_model
|
|
|
|
|
|
|
|
@app.post("/generate")
def generate(request_body: dict) -> dict:
    """Generate a completion for the text in the request body.

    Expects a JSON body with an "input_text" key; returns the decoded
    generation under the "output" key.

    NOTE(review): `app` is referenced here but not defined in this file
    as shown — confirm a FastAPI instance (`app = FastAPI()`) exists
    elsewhere before deploying.
    """
    input_text = request_body["input_text"]
    # Fixed: the original left tokenization and the generate() arguments as
    # literal "..." placeholders. Tokenize and move tensors to the model's device.
    inputs = tokenizer(input_text, return_tensors="pt").to(merged_model.device)
    # Inference only — no gradients needed.
    with torch.no_grad():
        output_ids = merged_model.generate(**inputs, max_new_tokens=256)
    # Fixed: generate() returns a tensor of token ids, which is not
    # JSON-serializable; decode to a plain string before returning.
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return {"output": output_text}