uyen13 commited on
Commit
ca834bb
·
verified ·
1 Parent(s): 436b7d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -11,17 +11,18 @@ import torch
11
  # Load FLAN-T5 model
12
  @st.cache_resource
13
  def load_llm():
14
- model_name = "google/flan-t5-xl" # <-- Đã thay bằng FLAN-T5
 
15
  tokenizer = AutoTokenizer.from_pretrained(model_name)
16
 
17
- model = AutoModelForCausalLM.from_pretrained(
18
  model_name,
19
- torch_dtype=torch.float32, # flan-t5 không hỗ trợ bfloat16 trên CPU
20
  device_map="auto"
21
  )
22
 
23
  pipe = pipeline(
24
- "text2text-generation", # <-- Chú ý loại pipeline này dành cho T5
25
  model=model,
26
  tokenizer=tokenizer,
27
  max_new_tokens=256,
@@ -30,6 +31,7 @@ def load_llm():
30
  repetition_penalty=1.15,
31
  do_sample=True
32
  )
 
33
  return HuggingFacePipeline(pipeline=pipe)
34
 
35
  # Process PDF and create vectorstore
 
11
  # Load FLAN-T5 model
12
  @st.cache_resource
13
  def load_llm():
14
+ model_name = "google/flan-t5-xl"
15
+
16
  tokenizer = AutoTokenizer.from_pretrained(model_name)
17
 
18
+ model = AutoModelForSeq2SeqLM.from_pretrained(
19
  model_name,
20
+ torch_dtype=torch.float32, # T5 thường dùng float32 hoặc bfloat16 nếu có GPU hỗ trợ
21
  device_map="auto"
22
  )
23
 
24
  pipe = pipeline(
25
+ "text2text-generation",
26
  model=model,
27
  tokenizer=tokenizer,
28
  max_new_tokens=256,
 
31
  repetition_penalty=1.15,
32
  do_sample=True
33
  )
34
+
35
  return HuggingFacePipeline(pipeline=pipe)
36
 
37
  # Process PDF and create vectorstore