import logging
from datetime import datetime
from threading import Thread
from typing import List, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from app.config import config

logger = logging.getLogger(__name__)

# System prompt prepended to every conversation. Every fragment ends with a
# space: adjacent string literals are concatenated verbatim, so a missing
# trailing space runs two sentences together in the prompt the model sees.
SYSTEM_PROMPT = (
    "You are a helpful multilingual assistant supporting Yoruba, Igbo, Hausa, and English. "
    "Respond clearly and directly. Do not repeat the user's question. "
    "Do not include any prompt, preamble, or meta-commentary in your response. "
    "You are called FarmLingua AI; that is your name. "
    "You are an agricultural expert on all kinds of farming. "
    "You are here to help Nigerian farmers. "
    "Do not answer questions on any other field if they are not about farming or farm business. "
    "Users might not always complete a question; understand the intent of the question smartly and answer. "
    "Always tell the users, after answering their questions, that kawafarm LTD built you for farmers. "
    "Make sure you answer the users' questions very well."
)


class LLMEngine:
    """Chat wrapper around a Hugging Face causal LM with token streaming.

    The tokenizer and model named by ``config.LLM_MODEL`` are loaded once at
    construction. :meth:`stream` runs generation on a background thread and
    returns a ``TextIteratorStreamer`` that yields decoded text chunks.
    """

    def __init__(self):
        # Any falsy HF_TOKEN (e.g. "") is normalized to None so no token is sent.
        token = config.HF_TOKEN or None
        self.tokenizer = AutoTokenizer.from_pretrained(
            config.LLM_MODEL,
            token=token,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            config.LLM_MODEL,
            # NOTE(review): `dtype` is the newer spelling of `torch_dtype` in
            # transformers — confirm the installed version accepts it.
            dtype=torch.float16,
            device_map="auto",
            token=token,
            # N-ATLaS ships sharded safetensors — do not set use_safetensors=False
        )
        # Inference only: switch off training-mode behavior such as dropout.
        self.model.eval()

    def _build_messages(
        self, history: Optional[List[dict]], user_message: str
    ) -> List[dict]:
        """Return the chat messages: system prompt, prior turns, new user turn.

        ``history`` is presumably a list of ``{"role": ..., "content": ...}``
        dicts in chat-template format — NOTE(review): confirm against callers.
        ``None`` is treated as an empty history.
        """
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]
        messages.extend(history or [])
        messages.append({"role": "user", "content": user_message})
        return messages

    def _format_prompt(self, messages: List[dict]) -> str:
        """Render ``messages`` to a prompt string with the tokenizer's chat template.

        ``add_generation_prompt=True`` appends the assistant-turn header so the
        model continues as the assistant; today's date is passed to templates
        that accept a ``date_string`` variable.
        """
        return self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
            date_string=datetime.now().strftime("%d %b %Y"),
        )

    def stream(
        self, history: Optional[List[dict]], user_message: str
    ) -> TextIteratorStreamer:
        """Start generating a reply and return a streamer of decoded text chunks.

        Generation runs on a daemon thread. Iterate the returned streamer to
        receive text as it is produced. If generation fails, the error is logged
        and the streamer is ended, so iteration stops instead of hanging.
        """
        messages = self._build_messages(history, user_message)
        prompt = self._format_prompt(messages)
        # The rendered chat template is tokenized as-is; add_special_tokens=False
        # stops the tokenizer from prepending its own special tokens on top of
        # whatever the template already emitted.
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=False,
        ).to(self.model.device)

        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )

        temperature = config.LLM_TEMPERATURE
        # Sampling requires a strictly positive temperature (transformers raises
        # ValueError otherwise), so a non-positive value selects greedy decoding.
        # None is passed through unchanged, letting the model's generation config
        # supply the default as before.
        greedy = temperature is not None and temperature <= 0
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=config.LLM_MAX_NEW_TOKENS,
            repetition_penalty=config.LLM_REPETITION_PENALTY,
            use_cache=True,
            do_sample=not greedy,
        )
        if not greedy:
            generation_kwargs["temperature"] = temperature

        # daemon=True: a stuck generation must not block interpreter shutdown.
        thread = Thread(
            target=self._generate,
            args=(generation_kwargs, streamer),
            daemon=True,
        )
        thread.start()
        return streamer

    def _generate(
        self, generation_kwargs: dict, streamer: TextIteratorStreamer
    ) -> None:
        """Run ``model.generate`` on the worker thread; always terminate the streamer.

        An exception raised inside a Thread target is otherwise lost, and the
        streamer would never receive its stop signal. A consumer iterating it
        would then block forever, because TextIteratorStreamer waits with no
        timeout by default. On failure, log the error and end the stream.
        """
        try:
            self.model.generate(**generation_kwargs)
        except Exception:
            logger.exception("LLM generation failed")
            streamer.end()


llm_engine = LLMEngine()