Spaces:
Runtime error
Runtime error
| import torch | |
| from datasets import Dataset as hfd | |
| from datasets import load_dataset | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| BitsAndBytesConfig, | |
| pipeline, | |
| ) | |
| from config import DATASET_HF_NAME, LLAMA3_CHECKPOINT | |
| # Adapted from HF https://huggingface.co/blog/not-lain/rag-chatbot-using-llama3 | |
def search_topk(
    data: hfd,
    feature_extractor: SentenceTransformer,
    query: str,
    k: int = 3,
    embedding_col: str = "embedding",
) -> tuple:
    """Embed ``query`` and return its ``k`` nearest entries from ``data``.

    Args:
        data: Dataset to search. ``get_nearest_examples`` is called on it, so
            it presumably needs a search index already attached to
            ``embedding_col`` -- TODO confirm where that index is added.
        feature_extractor: Model whose ``encode`` method turns the query text
            into an embedding.
        query: Text to look up.
        k: Number of results to request.
        embedding_col: Name of the column holding the indexed embeddings.

    Returns:
        A ``(scores, retrieved_examples)`` tuple, exactly as produced by
        ``data.get_nearest_examples``.
    """
    # Embed the incoming query with the same extractor interface the caller supplies.
    query_embedding = feature_extractor.encode(query)

    # Compare the query embedding against the dataset embeddings; keep the top k.
    scores, retrieved_examples = data.get_nearest_examples(
        embedding_col,
        query_embedding,
        k=k,
    )
    return scores, retrieved_examples
def format_prompt(
    prompt: str, retrieved_documents: hfd, k: int, text_col: str = "chunk"
) -> str:
    """Build the model prompt from a question and up to ``k`` retrieved chunks.

    The result has the form ``"Question:<prompt>\\nContext:<chunk>\\n<chunk>\\n..."``.

    Args:
        prompt: The user's question.
        retrieved_documents: Retrieval result; ``retrieved_documents[text_col]``
            must be a sequence of text chunks.
        k: Maximum number of chunks to include. If fewer chunks are available,
            all of them are used; a non-positive ``k`` yields no context.
        text_col: Name of the column that holds the chunk text.

    Returns:
        The formatted prompt string.
    """
    # Look the column up once, and slice instead of indexing so that having
    # fewer than k retrieved chunks does not raise IndexError. max() keeps a
    # negative k from being interpreted as "all but the last |k|" by the slice.
    chunks = retrieved_documents[text_col][: max(k, 0)]
    context = "".join(f"{chunk}\n" for chunk in chunks)
    return f"Question:{prompt}\nContext:{context}"
# Quantization config: load the model weights in 8-bit via bitsandbytes.
# NOTE(review): `bnb_4bit_compute_dtype` is a 4-bit option and presumably has
# no effect together with `load_in_8bit=True` -- confirm before removing it.
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Tokenizer & model.
# You must request access to the checkpoints.
TOKENIZER = AutoTokenizer.from_pretrained(LLAMA3_CHECKPOINT)
MODEL = AutoModelForCausalLM.from_pretrained(
    LLAMA3_CHECKPOINT,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config,
)

# Token ids passed as `eos_token_id` in PIPELINE_INFERENCE_ARGS below: the
# tokenizer's own EOS token plus the "<|eot_id|>" token.
TERMINATORS = [TOKENIZER.eos_token_id, TOKENIZER.convert_tokens_to_ids("<|eot_id|>")]

# "train" split of the configured dataset.
DATA = load_dataset(DATASET_HF_NAME)["train"]

# Text-generation pipeline built around the already-loaded model and tokenizer.
TEXT_GENERATION_PIPELINE = pipeline(
    model=MODEL,
    tokenizer=TOKENIZER,
    task="text-generation",
    device_map="auto",
)

# Keyword arguments intended for calls to TEXT_GENERATION_PIPELINE.
PIPELINE_INFERENCE_ARGS = {
    "max_new_tokens": 256,
    "eos_token_id": TERMINATORS,
    "do_sample": True,
    "temperature": 0.1,
    "top_p": 0.9,
}