Spaces:
Sleeping
Sleeping
File size: 1,625 Bytes
91394e0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
from abc import ABC, abstractmethod
from transformers import pipeline
# Registry mapping a model-id prefix (e.g. "meta-llama") to the
# AbstractLLMModel subclass that wraps matching models. Entries are added by
# the @register_llm_model decorator and looked up by get_llm_model.
LLM_MODEL_REGISTRY: dict = dict()
class AbstractLLMModel(ABC):
    """Interface that every registered LLM wrapper must implement.

    Both the constructor and :meth:`generate` are abstract, so this base class
    cannot be instantiated directly; concrete wrappers must override both.
    """

    @abstractmethod
    def __init__(
        self, model_id: str, device: str = "cpu", cache_dir: str = "cache", **kwargs
    ):
        """Prepare the model identified by ``model_id``.

        Args:
            model_id: Model identifier that was matched against a registry prefix.
            device: Target device name; defaults to ``"cpu"``.
            cache_dir: Cache directory for the backend; defaults to ``"cache"``.
            **kwargs: Backend-specific options.
        """

    @abstractmethod
    def generate(self, prompt: str, **kwargs) -> str:
        """Produce the model's text output for ``prompt``."""
def register_llm_model(prefix: str):
    """Return a class decorator that registers a wrapper for a model-id prefix.

    The decorated class is stored in ``LLM_MODEL_REGISTRY`` under ``prefix``
    and returned unchanged, so the decorator can be stacked to register one
    class under several prefixes. Registering an existing prefix again
    replaces the previous entry.

    Args:
        prefix: Leading portion of a model id (e.g. ``"meta-llama"``) that the
            decorated class should handle.

    Raises:
        TypeError: At decoration time, if the decorated object is not a
            subclass of ``AbstractLLMModel``.
    """

    def wrapper(cls):
        # Explicit check instead of ``assert``: assertions are stripped under
        # ``python -O``, which would let a non-conforming class into the
        # registry and only fail later when get_llm_model instantiates it.
        # The isinstance guard keeps issubclass from raising its own, less
        # helpful TypeError when a non-class (e.g. a function) is decorated.
        if not (isinstance(cls, type) and issubclass(cls, AbstractLLMModel)):
            raise TypeError(f"{cls} must inherit AbstractLLMModel")
        LLM_MODEL_REGISTRY[prefix] = cls
        return cls

    return wrapper
def get_llm_model(model_id: str, device: str = "cpu", **kwargs) -> AbstractLLMModel:
    """Instantiate the registered wrapper whose prefix matches ``model_id``.

    When several registered prefixes match, the longest (most specific) one
    wins. A wrapper registered for e.g. ``"google/gemma-2"`` therefore takes
    precedence over a generic ``"google/gemma"`` wrapper, regardless of the
    order in which the two were registered.

    Args:
        model_id: Full model identifier, e.g. ``"meta-llama/Llama-2-7b-hf"``.
        device: Device name forwarded to the wrapper's constructor.
        **kwargs: Extra keyword arguments forwarded to the wrapper's
            constructor (e.g. ``cache_dir``).

    Returns:
        A new instance of the matching ``AbstractLLMModel`` subclass.

    Raises:
        ValueError: If no registered prefix matches ``model_id``.
    """
    matches = [
        prefix for prefix in LLM_MODEL_REGISTRY if model_id.startswith(prefix)
    ]
    if not matches:
        raise ValueError(f"No LLM wrapper found for model: {model_id}")
    # Longest prefix = most specific registration. Ties cannot occur: two
    # distinct prefixes of the same string must differ in length.
    best = max(matches, key=len)
    return LLM_MODEL_REGISTRY[best](model_id, device=device, **kwargs)
@register_llm_model("google/gemma")
@register_llm_model("tiiuae/")  # Falcon models are published under the "tiiuae" org
@register_llm_model("tii/")  # original prefix, kept for backward compatibility
@register_llm_model("meta-llama")
class HFTextGenerationLLM(AbstractLLMModel):
    """LLM wrapper around a Hugging Face ``transformers`` text-generation pipeline."""

    def __init__(
        self, model_id: str, device: str = "cpu", cache_dir: str = "cache", **kwargs
    ):
        """Build the text-generation pipeline for ``model_id``.

        Args:
            model_id: Model id passed to ``pipeline(model=...)``.
            device: ``"cuda"`` selects GPU 0 and ``"cuda:N"`` selects GPU N;
                anything else runs on CPU.
            cache_dir: Forwarded to the pipeline as ``model_kwargs["cache_dir"]``.
            **kwargs: Extra keyword arguments forwarded to ``pipeline``. A
                ``model_kwargs`` dict, if given, is copied rather than mutated.
        """
        # Copy so the caller's dict is not modified in place, and tolerate an
        # explicit ``model_kwargs=None`` (which previously raised TypeError).
        model_kwargs = dict(kwargs.pop("model_kwargs", None) or {})
        model_kwargs["cache_dir"] = cache_dir
        self.pipe = pipeline(
            "text-generation",
            model=model_id,
            device=self._device_index(device),
            return_full_text=False,
            model_kwargs=model_kwargs,
            **kwargs,
        )

    @staticmethod
    def _device_index(device: str) -> int:
        """Map a device name to the integer index ``pipeline`` expects.

        ``"cuda"`` -> 0, ``"cuda:N"`` -> N, anything else -> -1 (CPU).
        Previously ``"cuda:N"`` silently fell back to CPU.
        """
        name = str(device)
        if name == "cuda":
            return 0
        if name.startswith("cuda:"):
            index = name[len("cuda:"):]
            if index.isdigit():
                return int(index)
        return -1

    def generate(self, prompt: str, **kwargs) -> str:
        """Run the pipeline on ``prompt`` and return the first generated text.

        The pipeline is built with ``return_full_text=False``, so per the
        transformers docs only the newly generated continuation is returned.
        An empty pipeline result yields ``""``.

        Args:
            prompt: Input text.
            **kwargs: Generation options forwarded to the pipeline call.
        """
        outputs = self.pipe(prompt, **kwargs)
        return outputs[0]["generated_text"] if outputs else ""
|