sharath6900 committed on
Commit
471ecfa
·
verified ·
1 Parent(s): cdc817f

streamlit.app

Browse files
Files changed (1) hide show
  1. app.py +60 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Streamlit demo that prompts a 4-bit quantized Llama-2-7b-chat model.

Run with ``streamlit run app.py``. The script builds a small UI (title,
prompt box, "Generate" button), loads the model and tokenizer once, and
writes the generated text back to the page.
"""
# NOTE: the notebook cell magic ``%%writefile app.py`` that used to be the
# first line has been dropped: it is only valid inside a Jupyter cell and is
# a syntax error when this file is executed as a plain Python script.

import streamlit as st
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,  # previously used without being imported (NameError)
    pipeline,
)
from trl import SFTTrainer

# Define Streamlit interface
st.title("Llama-2-7b-Chat Fine-Tuned Model")
st.write("This app demonstrates a fine-tuned Llama-2-7b model using QLoRA.")

# Input text prompt
prompt = st.text_input("Enter your prompt:", value="What is opensource llm?")

# Model settings
st.write("Loading the model...")

# Model and dataset identifiers.
# NOTE(review): only the base checkpoint below is loaded; no LoRA adapter is
# applied anywhere in this script, and ``dataset_name`` is not used here.
model_name = "NousResearch/Llama-2-7b-chat-hf"
dataset_name = "mlabonne/guanaco-llama2-1k"

# QLoRA parameters (the ``lora_*`` values are not referenced by this script;
# they are kept so the configuration stays in one place).
lora_r = 64
lora_alpha = 16
lora_dropout = 0.1
use_4bit = True
bnb_4bit_compute_dtype = "float16"
bnb_4bit_quant_type = "nf4"
use_nested_quant = False

# Resolve the dtype name ("float16") to the actual torch dtype object.
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)

# Place the whole model on device 0.
device_map = {"": 0}


@st.cache_resource
def load_model_and_tokenizer():
    """Load the quantized model and its tokenizer exactly once.

    Streamlit re-executes the whole script on every widget interaction;
    caching the loaded objects avoids re-loading the checkpoint each time.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map=device_map,
    )
    loaded_tokenizer = AutoTokenizer.from_pretrained(
        model_name, trust_remote_code=True
    )
    # Reuse the end-of-sequence token for padding and pad on the right.
    loaded_tokenizer.pad_token = loaded_tokenizer.eos_token
    loaded_tokenizer.padding_side = "right"
    return loaded_model, loaded_tokenizer


model, tokenizer = load_model_and_tokenizer()


def generate_text(prompt: str, max_length: int = 200) -> str:
    """Generate a completion for ``prompt`` with the loaded model.

    Args:
        prompt: The user instruction; it is wrapped in ``[INST] ... [/INST]``
            markers before being sent to the model.
        max_length: ``max_length`` value forwarded to the text-generation
            pipeline.

    Returns:
        The ``generated_text`` field of the first pipeline result.
    """
    pipe = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=max_length,
    )
    result = pipe(f"<s>[INST] {prompt} [/INST]")
    return result[0]['generated_text']


# Run inference
if st.button("Generate"):
    if prompt.strip():
        st.write(generate_text(prompt))
    else:
        # Nothing useful to send to the model for an empty prompt.
        st.warning("Please enter a prompt.")

# NOTE: the notebook leftover ``print(generate_text(prompt))`` that followed
# here was removed. ``generate_text`` was never defined (NameError), and the
# call would have run a full generation on every Streamlit rerun.