import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline


# Vietnamese chat model hosted on the Hugging Face Hub
base_model = "minhtt/vistral-7b-chat"
# Quantization config: store weights in 4-bit NF4, compute in bfloat16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)
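# With 4-bit weights a 7B model needs roughly 4-5 GB of GPU memory
# (a rough estimate, not a measured figure) versus ~14 GB in bfloat16.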


model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,  # load_in_4bit is already set inside the config
    torch_dtype=torch.bfloat16,
    device_map="auto",  # place layers across available devices automatically
    trust_remote_code=True,
)
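# Streamlit reruns this whole script on every widget interaction, so in
# practice the expensive load above is usually wrapped in a cached loader.
# A minimal sketch, assuming that pattern applies here (`load_model` is an
# illustrative name, not part of the original code):
#
#     @st.cache_resource
#     def load_model():
#         return AutoModelForCausalLM.from_pretrained(
#             base_model,
#             quantization_config=bnb_config,
#             torch_dtype=torch.bfloat16,
#             device_map="auto",
#             trust_remote_code=True,
#         )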


# The KV-cache speeds up autoregressive decoding; the training-time settings
# this script inherited (use_cache=False, gradient checkpointing) are
# unnecessary for an inference-only app
model.config.use_cache = True
model.config.pretraining_tp = 1


tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token  # the model ships no dedicated pad token
# add_eos_token stays unset: appending the EOS token to every prompt would
# tell the chat model the conversation is already finished


# max_new_tokens bounds the length of the generated reply, excluding the prompt
pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_new_tokens=200)
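# Vistral is a chat-tuned model, so wrapping the raw question in the model's
# chat template usually yields better answers. A minimal sketch, assuming the
# checkpoint ships a chat template; `build_prompt` is a helper name introduced
# here for illustration:
def build_prompt(question: str) -> str:
    messages = [{"role": "user", "content": question}]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# usage: out = pipe(build_prompt(text))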
text = st.text_area("Đặt câu hỏi")  # "Ask a question"


if text:
    out = pipe(text)
    # the pipeline returns a list of dicts: [{"generated_text": "..."}]
    st.write(out[0]["generated_text"])
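# To launch the app (assuming this file is saved as app.py):
#     streamlit run app.py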