Update app.py
app.py CHANGED
@@ -7,21 +7,27 @@ from langchain.vectorstores import FAISS
 from langchain.chains import RetrievalQA
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 import os
-
+import torch
 # Load FLAN-T5 model
 @st.cache_resource
 def load_llm():
-    model_name = "
+    model_name = "tiiuae/falcon-7b"  # Replace with the model name of your choice
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model =
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype=torch.bfloat16,  # Reduces memory if the GPU supports it
+        trust_remote_code=True,
+        device_map="auto"  # Automatically allocates across GPU/CPU
+    )
     pipe = pipeline(
-        "
+        "text-generation",
         model=model,
         tokenizer=tokenizer,
-        max_length=512,
-        temperature=0.7, #
+        max_length=512,  # Maximum output length
+        temperature=0.7,  # Creativity (sampling temperature)
         top_p=0.95,
-        repetition_penalty=1.15
+        repetition_penalty=1.15,
+        do_sample=True
     )
     return HuggingFacePipeline(pipeline=pipe)

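One issue with the committed hunk: the new code calls `AutoModelForCausalLM`, but the `transformers` import line is left untouched and still brings in only `AutoModelForSeq2SeqLM`, so `load_llm()` will raise a `NameError` as written (the `# Load FLAN-T5 model` comment is likewise stale now that the code loads Falcon-7B). A minimal sketch of the imports this version appears to need; the `HuggingFacePipeline` import location is assumed, since it is not shown in the diff:

```python
# Imports the new load_llm() actually uses. AutoModelForCausalLM replaces the
# now-unused AutoModelForSeq2SeqLM from the FLAN-T5 version of the app.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Assumed import path for HuggingFacePipeline (not visible in this diff);
# this is where it lives in the legacy langchain package.
from langchain.llms import HuggingFacePipeline
```

Note also that `device_map="auto"` depends on the `accelerate` package being installed: it shards the Falcon-7B weights across available GPUs and CPU RAM, which is why `load_llm()` needs no explicit `.to(device)` call.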
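For context, here is a hedged sketch of how the `HuggingFacePipeline` returned by `load_llm()` would typically be wired into the `RetrievalQA` chain and `FAISS` store imported at the top of app.py. The embedding model, index path, and query are illustrative placeholders, not values from this commit:

```python
from langchain.embeddings import HuggingFaceEmbeddings

llm = load_llm()

# Placeholder embeddings and index path -- the real ones used by app.py are
# not shown in this diff.
embeddings = HuggingFaceEmbeddings()
db = FAISS.load_local("faiss_index", embeddings)

# "stuff" concatenates the retrieved chunks into a single prompt, which the
# Falcon text-generation pipeline then answers from.
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=db.as_retriever(),
)
print(qa_chain.run("Summarize the indexed documents."))
```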