| | |
| | |
| | |
| | from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI |
| | import torch |
| | |
| | |
| |
|
def setup_llm(model_name: str = "microsoft/phi-3-mini-4k-instruct",
              device: str = None,
              context_window: int = 4096,
              max_new_tokens: int = 512):
    """Set up the language model for the CSV chatbot.

    First tries to load ``model_name`` locally through llama_index's
    ``HuggingFaceLLM`` (transformers backend). If that integration cannot be
    imported, or loading raises ``ImportError``/``AttributeError`` (for
    example an optional dependency such as bitsandbytes is missing), a
    warning is logged and the hosted Hugging Face Inference API client is
    used instead.

    Args:
        model_name: Hugging Face Hub model id, used for both the model and
            its tokenizer.
        device: Target device, e.g. ``"cuda"``, ``"cuda:0"`` or ``"cpu"``.
            ``None`` selects ``"cuda"`` when available, otherwise ``"cpu"``.
        context_window: Maximum prompt + completion length in tokens.
        max_new_tokens: Maximum number of tokens generated per call.

    Returns:
        A llama_index LLM: ``HuggingFaceLLM`` for the local backend, or
        ``HuggingFaceInferenceAPI`` for the remote fallback.

    Raises:
        ImportError: If neither backend could be initialised. The underlying
            error is chained as ``__cause__``.
    """
    import logging

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    try:
        return _load_local_llm(model_name, device, context_window, max_new_tokens)
    except (ImportError, AttributeError) as local_err:
        # Do not degrade to a network-backed model silently: tell the operator why.
        logging.getLogger(__name__).warning(
            "Local HuggingFaceLLM unavailable (%s); falling back to the "
            "Hugging Face Inference API.",
            local_err,
        )

    try:
        return _load_inference_api_llm(model_name, context_window, max_new_tokens)
    except Exception as api_err:
        # Callers have always received ImportError when every backend fails;
        # keep that contract, but preserve the real cause for debugging.
        raise ImportError(
            f"Could not set up an LLM for {model_name!r}: neither the local "
            "HuggingFaceLLM nor the Hugging Face Inference API backend could "
            "be initialised."
        ) from api_err


def _load_local_llm(model_name, device, context_window, max_new_tokens):
    """Build a transformers-backed ``HuggingFaceLLM`` placed on ``device``."""
    from llama_index.llms.huggingface import HuggingFaceLLM

    cache_dir = "./model_cache"
    on_gpu = device.startswith("cuda")

    model_kwargs = {
        # NOTE(review): executes model-repo code; only use with trusted repos.
        "trust_remote_code": True,
        # Half precision is poorly supported / slow on many CPU builds of
        # PyTorch, so only use float16 on the GPU.
        "torch_dtype": torch.float16 if on_gpu else torch.float32,
        # Forwarded to from_pretrained(); HuggingFaceLLM has no cache option.
        "cache_dir": cache_dir,
    }
    if on_gpu:
        from transformers import BitsAndBytesConfig

        # 4-bit weights keep the model small enough for modest GPUs.
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
        )

    return HuggingFaceLLM(
        model_name=model_name,
        tokenizer_name=model_name,
        context_window=context_window,
        max_new_tokens=max_new_tokens,
        # transformers ignores temperature/top_p unless sampling is enabled.
        generate_kwargs={"temperature": 0.7, "top_p": 0.95, "do_sample": True},
        device_map=device,
        tokenizer_kwargs={"trust_remote_code": True, "cache_dir": cache_dir},
        model_kwargs=model_kwargs,
    )


def _load_inference_api_llm(model_name, context_window, max_new_tokens):
    """Build a client for the hosted Hugging Face Inference API."""
    # Uses the module-level import from llama_index.llms.huggingface_api.
    return HuggingFaceInferenceAPI(
        model_name=model_name,
        tokenizer_name=model_name,
        context_window=context_window,
        max_new_tokens=max_new_tokens,
        # NOTE(review): the client appears to read its output length from
        # `num_output`; confirm against the installed llama_index version.
        num_output=max_new_tokens,
        generate_kwargs={"temperature": 0.7, "top_p": 0.95},
    )
| |
|