uyen13 commited on
Commit
590475f
·
verified ·
1 Parent(s): c3f97cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -7
app.py CHANGED
@@ -7,21 +7,27 @@ from langchain.vectorstores import FAISS
7
  from langchain.chains import RetrievalQA
8
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
9
  import os
10
-
11
# Load FLAN-T5 model
@st.cache_resource
def load_llm():
    """Build a LangChain LLM backed by a local FLAN-T5 text2text pipeline.

    Cached by Streamlit so the model is loaded once per session.

    Returns:
        HuggingFacePipeline: a LangChain-compatible wrapper around the
        Hugging Face ``text2text-generation`` pipeline.
    """
    checkpoint = "google/flan-t5-base"  # Adjust model size if needed
    tok = AutoTokenizer.from_pretrained(checkpoint)
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    text2text_pipe = pipeline(
        "text2text-generation",
        model=seq2seq_model,
        tokenizer=tok,
        max_length=512,
        temperature=0.7,  # Adjust for creativity
        top_p=0.95,
        repetition_penalty=1.15,
    )
    return HuggingFacePipeline(pipeline=text2text_pipe)
 
 
from langchain.chains import RetrievalQA
from transformers import (
    AutoModelForCausalLM,  # BUG FIX: was called below but never imported (NameError)
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    pipeline,
)
import os
import torch


# Load the causal language model used for question answering
@st.cache_resource
def load_llm():
    """Load Falcon-7B and wrap it in a LangChain-compatible pipeline.

    Cached by Streamlit so the (large) model is loaded once per session.

    Returns:
        HuggingFacePipeline: a LangChain wrapper around a Hugging Face
        ``text-generation`` pipeline running Falcon-7B.
    """
    model_name = "tiiuae/falcon-7b"  # Replace with the model you chose
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,  # reduces memory when the GPU supports bf16
        trust_remote_code=True,      # Falcon ships custom modeling code
        device_map="auto",           # automatically spread layers over GPU/CPU
    )
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=512,          # maximum total length (prompt + generation)
        temperature=0.7,         # sampling creativity
        top_p=0.95,
        repetition_penalty=1.15,
        do_sample=True,          # required for temperature/top_p to take effect
    )
    return HuggingFacePipeline(pipeline=pipe)
33