jslin09 committed on
Commit
925c10d
·
1 Parent(s): 871c8ac

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -163
app.py DELETED
@@ -1,163 +0,0 @@
1
- import torch
2
- from peft import PeftModel, PeftConfig
3
- import transformers
4
- import gradio as gr
5
-
6
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, BloomForCausalLM, GenerationConfig
7
- from transformers.models.opt.modeling_opt import OPTDecoderLayer
8
-
9
# Tokenizer used to encode prompts and decode generated sequences.
# NOTE(review): the tokenizer is loaded from "bigscience/bloom" while the
# weights come from "bigscience/bloom-3b"; presumably both repos ship the same
# tokenizer -- confirm before changing either identifier.
tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom')

# Hub identifier of the base causal language model.
BASE_MODEL = "bigscience/bloom-3b"

# Hub identifier of the LoRA adapter applied on top of BASE_MODEL.
# (A plain string: the literal has no placeholders, so no f-prefix is needed.)
LORA_WEIGHTS = "jslin09/LegalChatbot-bloom-3b"

# Adapter configuration loaded from the same repository as the adapter weights.
config = PeftConfig.from_pretrained(LORA_WEIGHTS)
17
-
18
# Select the device string used for model placement and input tensors.
# CUDA is preferred over CPU; an available MPS backend overrides either choice.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The MPS backend may be missing from the installed torch build, so probe for
# it explicitly instead of hiding every possible error behind a bare `except`.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
28
-
29
# Keyword arguments for the two loading steps, chosen per device:
#   * cuda: 8-bit base weights, float16, automatic device map; the adapter is
#     loaded in float16 without an explicit device map.
#   * mps:  float16 base model and adapter, both pinned to the device.
#   * else: default dtype with low_cpu_mem_usage for the base model; base model
#     and adapter are both pinned to the device.
if device == "cuda":
    _base_kwargs = {
        "load_in_8bit": True,
        "torch_dtype": torch.float16,
        "device_map": "auto",
    }
    _adapter_kwargs = {"torch_dtype": torch.float16}
elif device == "mps":
    _base_kwargs = {
        "device_map": {"": device},
        "torch_dtype": torch.float16,
    }
    _adapter_kwargs = {
        "device_map": {"": device},
        "torch_dtype": torch.float16,
    }
else:
    _base_kwargs = {
        "device_map": {"": device},
        "low_cpu_mem_usage": True,
    }
    _adapter_kwargs = {"device_map": {"": device}}

# Load the base BLOOM model, then wrap it with the LoRA adapter weights.
model = BloomForCausalLM.from_pretrained(BASE_MODEL, **_base_kwargs)
model = PeftModel.from_pretrained(model, LORA_WEIGHTS, **_adapter_kwargs)
59
-
60
-
61
def generate_prompt(instruction, input=None):
    """Build the English instruction prompt.

    Args:
        instruction: Text placed under the "### Instruction:" heading.
        input: Optional context placed under the "### Input:" heading. Any
            falsy value (``None``, ``""``) selects the template without an
            input section.

    Returns:
        The prompt string, ending with the "### Response:" heading.
    """
    # Newlines are spelled out explicitly so the template text does not depend
    # on how the source file happens to be indented.
    if input:
        return (
            "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n"
            "\n"
            "### Instruction:\n"
            f"{instruction}\n"
            "\n"
            "### Input:\n"
            f"{input}\n"
            "\n"
            "### Response:"
        )
    return (
        "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n"
        "\n"
        "### Instruction:\n"
        f"{instruction}\n"
        "\n"
        "### Response:"
    )
79
-
80
def generate_prompt_tw(instruction, input=None):
    """Build the Traditional Chinese instruction prompt.

    Args:
        instruction: Text placed under the "### 指令:" heading.
        input: Optional context placed under the "### 輸入:" heading. Any
            falsy value (``None``, ``""``) selects the template without an
            input section.

    Returns:
        The prompt string, ending with the "### 回應:" heading.
    """
    # Newlines are spelled out explicitly so the template text does not depend
    # on how the source file happens to be indented.
    if input:
        return (
            "以下是描述任務的指令,並與提供進一步上下文的輸入配對。編寫適當完成請求的回應。\n"
            "\n"
            "### 指令:\n"
            f"{instruction}\n"
            "\n"
            "### 輸入:\n"
            f"{input}\n"
            "\n"
            "### 回應:"
        )
    return (
        "以下是描述任務的指令。編寫適當完成請求的回應。\n"
        "\n"
        "### 指令:\n"
        f"{instruction}\n"
        "\n"
        "### 回應:"
    )
98
-
99
-
100
# Put the model into evaluation mode before serving requests.
model.eval()
# Wrap the model with torch.compile when the installed torch version compares
# greater than or equal to "2"; otherwise keep the model as it is.
# NOTE(review): whether this comparison is version-aware or lexicographic
# depends on the type of torch.__version__ in the installed release -- confirm.
model = torch.compile(model) if torch.__version__ >= "2" else model
103
-
104
-
105
def evaluate(
    instruction,
    input=None,
    temperature=0.1,
    top_p=0.75,
    top_k=40,
    num_beams=4,
    max_new_tokens=128,
    **kwargs,
):
    """Generate the model's answer for an instruction.

    Args:
        instruction: Instruction text inserted into the prompt.
        input: Optional context inserted into the prompt.
        temperature: Forwarded to ``GenerationConfig``.
        top_p: Forwarded to ``GenerationConfig``.
        top_k: Forwarded to ``GenerationConfig`` after integer coercion.
        num_beams: Forwarded to ``GenerationConfig`` after integer coercion.
        max_new_tokens: Forwarded to ``model.generate`` after integer coercion.
        **kwargs: Extra options forwarded to ``GenerationConfig``.

    Returns:
        The stripped text following the first "### 回應:" marker in the
        decoded sequence (up to the next marker, if the model emits one). If
        the marker is missing, the whole decoded text is returned stripped.
    """
    # Must match the heading that generate_prompt_tw appends to the prompt.
    # For the English template, use generate_prompt and "### Response:" instead.
    response_marker = "### 回應:"

    prompt = generate_prompt_tw(instruction, input)
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"].to(device)

    # UI sliders may deliver numeric values as floats; these options must be
    # integers, so coerce them explicitly.
    # NOTE(review): do_sample is not set here, so temperature/top_p/top_k
    # presumably have no effect under beam search -- confirm intended behavior.
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=int(top_k),
        num_beams=int(num_beams),
        **kwargs,
    )

    with torch.no_grad():
        # Only the generated token ids are consumed below, so per-step scores
        # are not requested.
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            max_new_tokens=int(max_new_tokens),
        )

    # Drop special tokens so they do not leak into the displayed answer.
    decoded = tokenizer.decode(
        generation_output.sequences[0], skip_special_tokens=True
    )

    segments = decoded.split(response_marker)
    if len(segments) < 2:
        # The marker did not survive the decode round trip; return everything
        # rather than raising IndexError.
        return decoded.strip()
    return segments[1].strip()
137
-
138
-
139
# Input widgets, in the positional order expected by evaluate():
# instruction, input, temperature, top_p, top_k, num_beams, max_new_tokens.
_instruction_box = gr.components.Textbox(
    lines=2, label="Instruction", placeholder="Tell me about alpacas."
)
_input_box = gr.components.Textbox(lines=2, label="Input", placeholder="none")
_temperature_slider = gr.components.Slider(
    minimum=0, maximum=1, value=0.1, label="Temperature"
)
_top_p_slider = gr.components.Slider(
    minimum=0, maximum=1, value=0.75, label="Top p"
)
_top_k_slider = gr.components.Slider(
    minimum=0, maximum=100, step=1, value=40, label="Top k"
)
_beams_slider = gr.components.Slider(
    minimum=1, maximum=4, step=1, value=4, label="Beams"
)
_max_tokens_slider = gr.components.Slider(
    minimum=1, maximum=2000, step=1, value=128, label="Max tokens"
)

# Single text area that displays the string returned by evaluate().
_output_box = gr.components.Textbox(lines=5, label="Output")

# Build the web UI around evaluate() and start serving it.
_interface = gr.Interface(
    fn=evaluate,
    inputs=[
        _instruction_box,
        _input_box,
        _temperature_slider,
        _top_p_slider,
        _top_k_slider,
        _beams_slider,
        _max_tokens_slider,
    ],
    outputs=[_output_box],
    title="🌲 🌲 🌲 BLOOM-LoRA-LegalChatbot",
    description="BLOOM-LoRA-LegalChatbot is a 3B-parameter BLOOM model finetuned to follow instructions. It is trained on the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset and my Legal QA dataset, and makes use of the Huggingface BLOOM implementation. For more information, please visit [the project's website](https://github.com/tloen/alpaca-lora).",
)
_interface.launch()
163
-