Spaces:
Sleeping
Sleeping
| import torch | |
| from datetime import datetime | |
| from threading import Thread | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer | |
| from typing import List | |
| from app.config import config | |
# System prompt placed at the start of every conversation.
# The fragments below are joined by implicit string-literal concatenation, so
# every fragment except the last must end with a space; otherwise adjacent
# sentences fuse together (e.g. "farmers.do not answer ...").
SYSTEM_PROMPT = (
    "You are a helpful multilingual assistant supporting Yoruba, Igbo, Hausa, and English. "
    "Respond clearly and directly. Do not repeat the user's question. "
    "Do not include any prompt, preamble, or meta-commentary in your response. "
    "You are called FarmLingua AI that is you name. "
    "You are an agricultural expert on all kinds of farming. "
    "You are here to help Nigerian farmers. "
    "do not answer question on any other field if it not about farming , farm business. "
    "users might not complete a question always understand smartly the intent of the question and answer. "
    "so always tell the users after answering their questions that kawafarm LTD built you for farmers. "
    "make sure you answer the users questions very well ."
)
class LLMEngine:
    """Load a causal language model once and stream chat completions from it.

    ``__init__`` loads the tokenizer and model named by ``config.LLM_MODEL``.
    ``stream`` formats a conversation with the tokenizer's chat template,
    starts generation on a background thread, and returns a streamer the
    caller iterates to receive decoded text as it is produced.
    """

    def __init__(self):
        """Load the tokenizer and the float16 model, then switch to eval mode."""
        # A falsy HF_TOKEN (e.g. an empty string) is passed on as None.
        token = config.HF_TOKEN or None
        self.tokenizer = AutoTokenizer.from_pretrained(
            config.LLM_MODEL,
            token=token,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            config.LLM_MODEL,
            dtype=torch.float16,
            device_map="auto",
            token=token,
            # N-ATLaS ships sharded safetensors — do not set use_safetensors=False
        )
        self.model.eval()

    def _build_messages(self, history: List[dict], user_message: str) -> List[dict]:
        """Return a new message list: system prompt, prior turns, new user turn.

        Args:
            history: Earlier chat messages, appended unchanged after the
                system prompt. ``None`` or an empty list means no prior turns.
            user_message: Text of the new user turn, appended last.

        Returns:
            A fresh list of message dicts; ``history`` itself is not modified.
        """
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]
        # Tolerate a missing history instead of failing inside list.extend.
        messages.extend(history or [])
        messages.append({"role": "user", "content": user_message})
        return messages

    def _format_prompt(self, messages: List[dict]) -> str:
        """Render ``messages`` to prompt text with the tokenizer's chat template.

        The template is asked for untokenized text with the generation prompt
        appended. Today's local date (``%d %b %Y``) is forwarded as
        ``date_string``; presumably the model's template consumes it —
        confirm against the template shipped with ``config.LLM_MODEL``.
        """
        return self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
            date_string=datetime.now().strftime("%d %b %Y"),
        )

    def _run_generation(self, streamer, generation_kwargs: dict) -> None:
        """Run ``model.generate`` and unblock the consumer if it fails.

        A streamer consumer waits until the stream is ended. When generation
        raises, nothing else ends the stream, so it is ended here before the
        exception is re-raised; the error still surfaces from the thread.
        """
        try:
            self.model.generate(**generation_kwargs)
        except Exception:
            streamer.end()
            raise

    def stream(self, history: List[dict], user_message: str):
        """Start generating a reply in the background and return its streamer.

        Args:
            history: Earlier chat messages (may be empty or ``None``).
            user_message: The new user turn to answer.

        Returns:
            The ``TextIteratorStreamer`` that receives the generated text.
            Generation runs on a separate thread, so this method returns
            immediately after starting it.
        """
        messages = self._build_messages(history, user_message)
        prompt = self._format_prompt(messages)
        # The rendered prompt is tokenized as-is; no special tokens are added
        # at this step.
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=False,
        ).to(self.model.device)
        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=config.LLM_MAX_NEW_TOKENS,
            temperature=config.LLM_TEMPERATURE,
            repetition_penalty=config.LLM_REPETITION_PENALTY,
            use_cache=True,
            do_sample=True,
        )
        # Run through the wrapper so a failed generation ends the stream
        # instead of leaving the consumer waiting forever.
        thread = Thread(
            target=self._run_generation,
            args=(streamer, generation_kwargs),
        )
        thread.start()
        return streamer
# Module-level engine instance, created when this module is imported.
# Constructing it runs LLMEngine.__init__, which loads the tokenizer and model.
llm_engine = LLMEngine()