# llm_engine.py — minimal chat engine wrapping a Hugging Face causal LM.
# (Origin: C4HCyberAI-1o repository, commit 7bc8956.)
from transformers import AutoModelForCausalLM, AutoTokenizer
class ChatEngine:
    """Minimal chat wrapper around a Hugging Face causal language model.

    A fixed ``system_prompt`` is prepended to every user turn; each call to
    :meth:`chat` is stateless (no conversation history is kept).
    """

    def __init__(self, model_name: str, system_prompt: str):
        """Load the tokenizer and model weights for ``model_name``.

        Args:
            model_name: Hugging Face hub id or local path of a causal LM.
            system_prompt: Instruction text prepended to every prompt.
        """
        self.model_name = model_name
        self.system_prompt = system_prompt
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)

    def _generate(self, prompt: str) -> str:
        """Generate a continuation of ``prompt`` and return only the new text.

        Fix: ``model.generate`` returns the prompt tokens followed by the
        continuation, so decoding ``outputs[0]`` in full echoed the entire
        prompt (system prompt + "User: ...\\nAI:") back to the caller.
        We slice off the prompt-length prefix before decoding.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt")
        outputs = self.model.generate(**inputs, max_new_tokens=512)
        # Keep only tokens generated beyond the prompt.
        prompt_len = inputs["input_ids"].shape[-1]
        new_tokens = outputs[0][prompt_len:]
        # strip(): continuations after "AI:" commonly begin with whitespace.
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    def chat(self, user_input: str) -> str:
        """Answer a single user turn using the configured system prompt.

        Args:
            user_input: The user's message for this turn.

        Returns:
            The model's reply text (without the prompt scaffold).
        """
        prompt = f"{self.system_prompt}\nUser: {user_input}\nAI:"
        return self._generate(prompt)