uyen13 commited on
Commit
436b7d4
·
verified ·
1 Parent(s): 711e09a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -14
app.py CHANGED
@@ -11,25 +11,24 @@ import torch
11
  # Load FLAN-T5 model
12
  @st.cache_resource
13
  def load_llm():
14
- model_name = "tiiuae/falcon-7b" # Thay bằng tên hình bạn chọn
15
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
16
  model = AutoModelForCausalLM.from_pretrained(
17
  model_name,
18
- torch_dtype=torch.bfloat16, # Giảm bộ nhớ nếu GPU hỗ trợ
19
- trust_remote_code=True,
20
- device_map="auto" # Tự động phân bổ lên GPU/CPU
21
  )
 
22
  pipe = pipeline(
23
- "text-generation",
24
- model=model,
25
- tokenizer=tokenizer,
26
- max_new_tokens=256, # Số token mới tối đa được sinh ra
27
- temperature=0.7,
28
- top_p=0.95,
29
- repetition_penalty=1.15,
30
- do_sample=True,
31
- eos_token_id=tokenizer.eos_token_id, # Dừng sinh văn bản khi gặp end-of-sentence
32
- truncation=True # Cho phép cắt bớt nếu đầu vào quá dài
33
  )
34
  return HuggingFacePipeline(pipeline=pipe)
35
 
 
11
  # Load FLAN-T5 model
12
  @st.cache_resource
13
  def load_llm():
14
+ model_name = "google/flan-t5-xl" # <-- Đã thay bằng FLAN-T5
15
  tokenizer = AutoTokenizer.from_pretrained(model_name)
16
+
17
  model = AutoModelForCausalLM.from_pretrained(
18
  model_name,
19
+ torch_dtype=torch.float32, # flan-t5 không hỗ trợ bfloat16 trên CPU
20
+ device_map="auto"
 
21
  )
22
+
23
  pipe = pipeline(
24
+ "text2text-generation", # <-- Chú ý loại pipeline này dành cho T5
25
+ model=model,
26
+ tokenizer=tokenizer,
27
+ max_new_tokens=256,
28
+ temperature=0.7,
29
+ top_p=0.95,
30
+ repetition_penalty=1.15,
31
+ do_sample=True
 
 
32
  )
33
  return HuggingFacePipeline(pipeline=pipe)
34