# NOTE(review): removed non-Python paste artifacts ("Spaces:" / "Runtime error")
# that preceded this module; they were extraction residue, not source code.
| from pathlib import Path | |
| from typing import Any, Callable, Dict, Optional | |
| from langchain.callbacks.base import BaseCallbackHandler | |
| from langchain.llms import CTransformers, HuggingFacePipeline | |
| from langchain.llms.base import LLM | |
| from .utils import merge | |
def get_gptq_llm(config: Dict[str, Any]) -> LLM:
    """Build a LangChain LLM backed by a GPTQ-quantized HuggingFace model.

    Reads model settings from the ``gptq`` section of *config*; the top-level
    ``download`` flag decides whether remote downloads are permitted
    (``local_files_only`` is its negation).

    Raises:
        ImportError: if the optional ``auto_gptq`` package is not installed.
    """
    try:
        from auto_gptq import AutoGPTQForCausalLM
    except ImportError:
        raise ImportError(
            "Could not import `auto_gptq` package. "
            "Please install it with `pip install chatdocs[gptq]`"
        )
    from transformers import (
        AutoTokenizer,
        TextGenerationPipeline,
        MODEL_FOR_CAUSAL_LM_MAPPING,
    )

    local_files_only = not config["download"]
    # Copy the gptq sub-config so the pops below don't mutate the caller's dict.
    gptq_config = {**config["gptq"]}
    model_name_or_path = gptq_config.pop("model")
    model_file = gptq_config.pop("model_file", None)
    pipeline_kwargs = gptq_config.pop("pipeline_kwargs", None) or {}

    # Only derive a basename / safetensors flag when an explicit file is given.
    if model_file:
        model_basename = Path(model_file).stem
        use_safetensors = model_file.endswith(".safetensors")
    else:
        model_basename = None
        use_safetensors = False

    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        local_files_only=local_files_only,
    )
    model = AutoGPTQForCausalLM.from_quantized(
        model_name_or_path,
        model_basename=model_basename,
        use_safetensors=use_safetensors,
        local_files_only=local_files_only,
        **gptq_config,
    )
    # Register the quantized model class so transformers' pipeline machinery
    # accepts it as a causal-LM model.
    MODEL_FOR_CAUSAL_LM_MAPPING.register("chatdocs-gptq", model.__class__)
    text_pipeline = TextGenerationPipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        **pipeline_kwargs,
    )
    return HuggingFacePipeline(pipeline=text_pipeline)
def get_llm(
    config: Dict[str, Any],
    *,
    callback: Optional[Callable[[str], None]] = None,
) -> LLM:
    """Create the LLM selected by ``config["llm"]``.

    Supported backends: ``"ctransformers"``, ``"gptq"``, and (for any other
    value) HuggingFace pipelines. When *callback* is given, it receives each
    newly generated token on the ctransformers path; the gptq and huggingface
    paths do not forward streaming callbacks.
    """
    callbacks = None
    if callback:
        class CallbackHandler(BaseCallbackHandler):
            # Forward every streamed token to the user-supplied callback.
            def on_llm_new_token(self, token: str, **kwargs) -> None:
                callback(token)

        callbacks = [CallbackHandler()]

    local_files_only = not config["download"]
    backend = config["llm"]

    if backend == "ctransformers":
        ct_config = merge(
            {**config["ctransformers"]},
            {"config": {"local_files_only": local_files_only}},
        )
        return CTransformers(callbacks=callbacks, **ct_config)

    if backend == "gptq":
        # The full config is passed through: get_gptq_llm reads both the
        # top-level "download" flag and the "gptq" section.
        return get_gptq_llm(config)

    # Default: HuggingFace pipeline backend.
    hf_config = {**config["huggingface"]}
    hf_config["model_id"] = hf_config.pop("model")
    hf_config = merge(
        hf_config,
        {"model_kwargs": {"local_files_only": local_files_only}},
    )
    return HuggingFacePipeline.from_model_id(task="text-generation", **hf_config)