# NEW_FARMLINGUA_AI / llm /engine.py
# drrobot9's picture
# stt and ttt together commit
# 3c95cb7 verified
# Raw
# History Blame Contribute Delete
# 3 kB
import torch
from datetime import datetime
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from typing import List
from app.config import config
# System prompt sent as the first message of every conversation.
# The value is assembled by implicit concatenation of adjacent string
# literals, so every fragment except the last must end with a space;
# otherwise consecutive sentences fuse together ("farmers.do not ...").
SYSTEM_PROMPT = (
    "You are a helpful multilingual assistant supporting Yoruba, Igbo, Hausa, and English. "
    "Respond clearly and directly. Do not repeat the user's question. "
    "Do not include any prompt, preamble, or meta-commentary in your response. "
    "You are called FarmLingua AI that is you name. "
    "You are an agricultural expert on all kinds of farming. "
    "You are here to help Nigerian farmers. "
    "do not answer question on any other field if it not about farming , farm business. "
    "users might not complete a question always understand smartly the intent of the question and answer. "
    "so always tell the users after answering their questions that kawafarm LTD built you for farmers. "
    "make sure you answer the users questions very well ."
)
class LLMEngine:
    """Streaming text-generation wrapper around a Hugging Face causal LM.

    The tokenizer and model named by ``config.LLM_MODEL`` are loaded once in
    the constructor. :meth:`stream` then runs generation on a background
    thread and returns an iterator of decoded text chunks.
    """

    def __init__(self):
        """Load the tokenizer and model for ``config.LLM_MODEL``."""
        # A falsy HF_TOKEN (e.g. an empty string) is normalised to None.
        token = config.HF_TOKEN or None
        self.tokenizer = AutoTokenizer.from_pretrained(
            config.LLM_MODEL,
            token=token,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            config.LLM_MODEL,
            dtype=torch.float16,
            device_map="auto",
            token=token,
            # N-ATLaS ships sharded safetensors — do not set use_safetensors=False
        )
        # Switch to inference mode (disables dropout and similar layers).
        self.model.eval()

    def _build_messages(self, history: List[dict], user_message: str) -> List[dict]:
        """Return the full chat transcript to feed to the chat template.

        Args:
            history: Earlier conversation turns as chat-message dicts; they are
                inserted unchanged between the system prompt and the new user
                turn. ``None`` is treated as an empty history.
            user_message: The new user turn, appended last.

        Returns:
            A new list: system prompt, then ``history``, then the user turn.
            The ``history`` list itself is not modified.
        """
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]
        # Tolerate a missing history (first turn of a conversation).
        messages.extend(history or [])
        messages.append({"role": "user", "content": user_message})
        return messages

    def _format_prompt(self, messages: List[dict]) -> str:
        """Render ``messages`` to a prompt string via the tokenizer's chat template.

        The generation prompt is appended so the model continues as the
        assistant. Today's date (e.g. ``"05 Mar 2024"``) is forwarded to the
        template as ``date_string``.
        """
        return self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
            date_string=datetime.now().strftime("%d %b %Y"),
        )

    def stream(self, history: List[dict], user_message: str):
        """Start generating a reply and return a streamer over its text.

        Generation runs on a background thread; the caller iterates the
        returned ``TextIteratorStreamer`` to receive decoded text chunks as
        they are produced. The prompt itself and special tokens are not
        emitted.

        Args:
            history: Earlier conversation turns (see :meth:`_build_messages`).
            user_message: The new user turn to answer.

        Returns:
            The ``TextIteratorStreamer`` fed by the generation thread.
        """
        messages = self._build_messages(history, user_message)
        prompt = self._format_prompt(messages)
        # add_special_tokens=False: the prompt text comes from the chat
        # template, so the tokenizer must not add its own special tokens
        # on top (presumably the template already includes them).
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=False,
        ).to(self.model.device)
        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=config.LLM_MAX_NEW_TOKENS,
            temperature=config.LLM_TEMPERATURE,
            repetition_penalty=config.LLM_REPETITION_PENALTY,
            use_cache=True,
            do_sample=True,
        )

        def _generate() -> None:
            # If generate() raises, it never signals end-of-stream, and the
            # streamer has no timeout, so the consumer would block forever.
            # Close the stream explicitly, then re-raise so the failure is
            # still reported by the thread's exception hook.
            try:
                self.model.generate(**generation_kwargs)
            except Exception:
                streamer.end()
                raise

        thread = Thread(target=_generate)
        thread.start()
        return streamer
# Module-level shared engine instance. Constructing it runs
# LLMEngine.__init__, which loads the tokenizer and model, so that load
# happens as a side effect of importing this module.
llm_engine = LLMEngine()