import logging
import re
from threading import Thread
from typing import Any, Dict, Generator, List, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

logger = logging.getLogger("plutus.model")
logging.basicConfig(level=logging.INFO)

MODEL_NAME = "Remostart/Plutus_Advanced_model"


class SharedLLM:
    """Loads the tokenizer and model once and shares them across all callers."""

    _tokenizer = None
    _model = None
    _device = "cuda" if torch.cuda.is_available() else "cpu"

    @classmethod
    def load(cls):
        if cls._model is not None:
            return cls._tokenizer, cls._model, cls._device

        logger.info(f"[LOAD] Loading tokenizer: {MODEL_NAME}")
        cls._tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

        logger.info(f"[LOAD] Loading model on {cls._device}")
        cls._model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float16 if cls._device == "cuda" else None,
            low_cpu_mem_usage=True,
        )
        cls._model.to(cls._device)
        cls._model.eval()

        logger.info("[READY] Shared LLM loaded once.")
        return cls._tokenizer, cls._model, cls._device


_SENTENCE_END_RE = re.compile(r"([.!?])\s+$")
_LIST_ITEM_RE = re.compile(r"^\s*(\d+\.|\-|\*)\s+$")
_CODE_FENCE = "```"


def should_flush(buffer: str) -> bool:
    """Decide whether the pending streamed text forms a chunk worth yielding."""
    stripped = buffer.strip()
    if len(stripped) < 25:
        return False
    if _LIST_ITEM_RE.match(stripped):
        return False
    if "\n\n" in buffer:
        return True
    if _SENTENCE_END_RE.search(buffer):
        return True
    if len(buffer) > 180:
        return True
    return False


class PlutusModel:
    def __init__(self):
        self.tokenizer, self.model, self.device = SharedLLM.load()

    def create_prompt(
        self,
        personality: str,
        level: str,
        topic: str,
        extra_context: Optional[str] = None,
    ) -> str:
        prompt = (
            "You are PlutusTutor — the best expert in Cardano's Plutus ecosystem.\n\n"
            f"User Info:\n"
            f"- Personality: {personality}\n"
            f"- Level: {level}\n"
            f"- Topic: {topic}\n\n"
            "Your task:\n"
            "- Teach with extreme clarity.\n"
            "- Use structured explanations.\n"
            "- Include examples where helpful.\n"
            "- Avoid filler.\n"
            "- Adapt tone to personality.\n\n"
        )
        if extra_context:
            prompt += f"Additional Context:\n{extra_context}\n\n"
        return prompt + "Begin teaching now.\n\nAssistant:"

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 600,
        temperature: float = 0.6,
        top_p: float = 0.9,
    ) -> Generator[str, None, None]:
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )

        def _run():
            with torch.inference_mode():
                self.model.generate(
                    **inputs,
                    streamer=streamer,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    eos_token_id=self.tokenizer.eos_token_id,
                    # Fall back to EOS when the tokenizer defines no pad token.
                    pad_token_id=self.tokenizer.pad_token_id
                    if self.tokenizer.pad_token_id is not None
                    else self.tokenizer.eos_token_id,
                )

        Thread(target=_run, daemon=True).start()

        buffer = ""
        in_code_block = False
        for token in streamer:
            buffer += token
            # An odd number of fences in the pending buffer means a code block
            # is still open, so hold off flushing until it closes.
            in_code_block = buffer.count(_CODE_FENCE) % 2 == 1
            if not in_code_block and should_flush(buffer):
                yield buffer.strip()
                buffer = ""

        if buffer.strip():
            yield buffer.strip()


class SummaryModel:
    def __init__(self):
        self.tokenizer, self.model, self.device = SharedLLM.load()

    def summarize_text(
        self,
        full_teaching: str,
        topic: str,
        level: str,
        recommended: List[Dict[str, Any]],
        max_new_tokens: int = 400,
    ) -> Generator[str, None, None]:
        prompt = (
            "You are a world-class summarization assistant.\n\n"
            f"TOPIC: {topic}\n"
            f"LEVEL: {level}\n\n"
            "CONTENT:\n"
            f"{full_teaching}\n\n"
            "Produce a clear, structured summary.\n\n"
            "Assistant:"
        )

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )

        def _run():
            with torch.inference_mode():
                self.model.generate(
                    **inputs,
                    streamer=streamer,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=0.6,
                    top_p=0.9,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

        Thread(target=_run, daemon=True).start()

        buffer = ""
        in_code_block = False
        for token in streamer:
            buffer += token
            in_code_block = buffer.count(_CODE_FENCE) % 2 == 1
            if not in_code_block and should_flush(buffer):
                yield buffer.strip()
                buffer = ""

        if buffer.strip():
            yield buffer.strip()

        if recommended:
            yield "\n\n### Recommended Resources\n"
            for item in recommended:
                # .get() avoids a KeyError when an item has no URL.
                yield f"- **{item['type'].upper()}**: {item.get('url')}"
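

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the service): stream a lesson,
# then feed the collected chunks to SummaryModel. Assumes the
# Remostart/Plutus_Advanced_model weights are reachable from this machine;
# the topic, level, and resource URL below are placeholder values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tutor = PlutusModel()
    lesson_prompt = tutor.create_prompt(
        personality="patient and practical",
        level="beginner",
        topic="Plutus validator scripts",
    )

    chunks: List[str] = []
    for chunk in tutor.generate(lesson_prompt, max_new_tokens=300):
        chunks.append(chunk)
        print(chunk, flush=True)

    summarizer = SummaryModel()
    for part in summarizer.summarize_text(
        full_teaching="\n\n".join(chunks),
        topic="Plutus validator scripts",
        level="beginner",
        recommended=[{"type": "doc", "url": "https://example.com/plutus-docs"}],
    ):
        print(part, flush=True)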