# NOTE(review): the original paste carried "Spaces:" / "Runtime error" status
# markers — the code below previously crashed at runtime (see fixes in AirLLM).
from typing import Any, List, Mapping, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM

from airllm import AirLLMLlama2
class AirLLM(LLM):
    """LangChain custom LLM wrapper around an AirLLM-loaded Llama 2 model.

    Loads the model layer-by-layer via :class:`airllm.AirLLMLlama2` (which
    allows running large models on limited VRAM) and exposes it through the
    standard LangChain ``LLM`` interface.
    """

    # Maximum token length used when tokenizing the prompt (truncation bound).
    max_len: int
    # The underlying AirLLM model; not a pydantic type, hence the Config below.
    model: AirLLMLlama2

    class Config:
        # AirLLMLlama2 is not a pydantic model; without this, declaring it as
        # a field raises at class-definition/validation time.
        arbitrary_types_allowed = True

    def __init__(self, llama2_model_id: str, max_len: int, compression: str = "", **kwargs: Any):
        """Load the model and initialize the pydantic base class.

        Args:
            llama2_model_id: Local path or Hugging Face repo id of the model.
            max_len: Max token length for prompt tokenization (truncation).
            compression: Reserved; currently not forwarded to AirLLM
                (the original code deliberately left it disabled).
        """
        # LangChain's LLM is a pydantic BaseModel: fields MUST be set through
        # super().__init__ — direct attribute assignment before init raises.
        # TODO(review): wire `compression` through once verified, e.g.
        # AirLLMLlama2(llama2_model_id, compression=compression)
        model = AirLLMLlama2(llama2_model_id)
        super().__init__(model=model, max_len=max_len, **kwargs)

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for serialization/telemetry."""
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion for ``prompt`` with the wrapped model.

        Raises:
            ValueError: if ``stop`` sequences are passed (unsupported here).
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        # Fix: original referenced undefined `model` / `input_text`;
        # the tokenizer input is the prompt and the model lives on self.
        input_tokens = self.model.tokenizer(
            prompt,
            return_tensors="pt",
            return_attention_mask=False,
            truncation=True,
            max_length=self.max_len,
            padding=True,
        )
        generation_output = self.model.generate(
            input_tokens["input_ids"].cuda(),  # assumes a CUDA device — TODO confirm
            max_new_tokens=20,
            use_cache=True,
            return_dict_in_generate=True,
        )
        return self.model.tokenizer.decode(generation_output.sequences[0])

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"max_len": self.max_len}