File size: 622 Bytes
ce80d21
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from llm.llm_base import LLMBase
from transformers import pipeline

class LocalLLM(LLMBase):
    def __init__(self, model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0", device="cpu", **kwargs):
        self.pipe = pipeline("text-generation", model=model_name, device=device, **kwargs)
    def generate(self, prompt: str, max_new_tokens: int = 200, **kwargs) -> str:
        result = self.pipe(prompt, max_new_tokens=max_new_tokens)
        return result[0]["generated_text"]

def get_llm(model_name: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0", **kwargs):
    
    return LocalLLM(model_name=model_name, **kwargs)