armaniii commited on
Commit
a7e3e36
·
verified ·
1 Parent(s): 0b65c8d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -0
app.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer, LoraFromPretrained
2
+ import torch
3
+
4
+ # Load the base model and tokenizer
5
+ base_model_name = "armaniii/mistral-argument-classification/"
6
+ tokenizer = AutoTokenizer.from_pretrained(base_model_name)
7
+ base_model = AutoModelForCausalLM.from_pretrained(base_model_name)
8
+
9
+ # Load the LoRA adapter
10
+ lora_model_name = "mistral_lora"
11
+ lora_weights = LoraFromPretrained(lora_model_name).to(base_model.device)
12
+
13
+ # Merge the LoRA adapter with the base model
14
+ merged_model = base_model.merge_lora(lora_weights)
15
+
16
+ # Define your API endpoint
17
+ @app.post("/generate")
18
+ def generate(request_body):
19
+ input_text = request_body["input_text"]
20
+ ...
21
+ # Use the merged model to generate output
22
+ output = merged_model.generate(...)
23
+ return {"output": output}