import logging
import re
from threading import Thread
from typing import Any, Dict, Generator, List, Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

# TF32 matmuls trade a negligible amount of precision for faster inference on
# Ampere-class and newer GPUs.
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

logger = logging.getLogger("plutus.model")
logging.basicConfig(level=logging.INFO)

MODEL_NAME = "Remostart/Plutus_Advanced_model"

class SharedLLM:
    """Process-wide singleton: one tokenizer/model pair shared by every caller."""

    _tokenizer = None
    _model = None
    _device = "cuda" if torch.cuda.is_available() else "cpu"

    @classmethod
    def load(cls):
        # After the first call, return the cached instances immediately.
        if cls._model is not None:
            return cls._tokenizer, cls._model, cls._device

        logger.info(f"[LOAD] Loading tokenizer: {MODEL_NAME}")
        cls._tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

        logger.info(f"[LOAD] Loading model on {cls._device}")
        cls._model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float16 if cls._device == "cuda" else None,
            low_cpu_mem_usage=True,
        )
        cls._model.to(cls._device)
        cls._model.eval()

        logger.info("[READY] Shared LLM loaded once.")
        return cls._tokenizer, cls._model, cls._device
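
# Illustrative check (assumed usage, not part of the original module): repeated
# calls to load() hand back the very same objects, so PlutusModel and
# SummaryModel below share a single copy of the weights.
#
#   tok_a, model_a, _ = SharedLLM.load()
#   tok_b, model_b, _ = SharedLLM.load()
#   assert tok_a is tok_b and model_a is model_b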

# A sentence boundary: terminal punctuation followed by trailing whitespace.
_SENTENCE_END_RE = re.compile(r"([.!?])\s+$")
# A buffer that ends in a bare list marker ("1.", "-", "*") still awaiting its text.
_LIST_ITEM_RE = re.compile(r"(^|\n)\s*(\d+\.|-|\*)\s*$")
_CODE_FENCE = "```"

def should_flush(buffer: str) -> bool:
    """Decide whether the accumulated text forms a chunk worth emitting."""
    stripped = buffer.strip()
    if len(stripped) < 25:
        return False
    # Never flush right after a bare list marker; wait for the item's text.
    if _LIST_ITEM_RE.search(buffer):
        return False
    # A paragraph break, a sentence end, or a long-enough run all justify a flush.
    if "\n\n" in buffer:
        return True
    if _SENTENCE_END_RE.search(buffer):
        return True
    if len(buffer) > 180:
        return True
    return False
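
# A quick sanity sketch of the heuristic (illustrative inputs, assumed for
# demonstration):
#
#   should_flush("Short.")                               -> False  (under 25 chars)
#   should_flush("Here is the first point:\n1. ")        -> False  (bare list marker)
#   should_flush("First paragraph done.\n\nNext idea")   -> True   (paragraph break)
#   should_flush("This sentence is long enough. Done. ") -> True   (sentence boundary)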

class PlutusModel:
    """Streams tutoring output, chunked at natural text boundaries."""

    def __init__(self):
        self.tokenizer, self.model, self.device = SharedLLM.load()

    def create_prompt(self, personality: str, level: str, topic: str, extra_context: Optional[str] = None) -> str:
        prompt = (
            "You are PlutusTutor — the best expert in Cardano's Plutus ecosystem.\n\n"
            "User Info:\n"
            f"- Personality: {personality}\n"
            f"- Level: {level}\n"
            f"- Topic: {topic}\n\n"
            "Your task:\n"
            "- Teach with extreme clarity.\n"
            "- Use structured explanations.\n"
            "- Include examples where helpful.\n"
            "- Avoid filler.\n"
            "- Adapt tone to personality.\n\n"
        )
        if extra_context:
            prompt += f"Additional Context:\n{extra_context}\n\n"
        return prompt + "Begin teaching now.\n\nAssistant:"

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 600,
        temperature: float = 0.6,
        top_p: float = 0.9,
    ) -> Generator[str, None, None]:
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)

        def _run():
            with torch.inference_mode():
                self.model.generate(
                    **inputs,
                    streamer=streamer,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    eos_token_id=self.tokenizer.eos_token_id,
                    # Fall back to EOS when the tokenizer defines no pad token.
                    pad_token_id=(
                        self.tokenizer.pad_token_id
                        if self.tokenizer.pad_token_id is not None
                        else self.tokenizer.eos_token_id
                    ),
                )

        # Generation runs on a worker thread; this thread drains the streamer.
        Thread(target=_run, daemon=True).start()

        buffer = ""
        for token in streamer:
            buffer += token
            # An odd number of ``` fences in the pending buffer means we are inside
            # a code block and must not flush mid-snippet. The buffer is only ever
            # cleared when the count is even, so per-buffer counting stays correct.
            in_code_block = buffer.count(_CODE_FENCE) % 2 == 1
            if not in_code_block and should_flush(buffer):
                yield buffer.strip()
                buffer = ""

        if buffer.strip():
            yield buffer.strip()
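
# Illustrative call pattern for PlutusModel (assumed, not part of the original
# module): chunks arrive as whole sentences or paragraphs rather than raw tokens.
#
#   tutor = PlutusModel()
#   prompt = tutor.create_prompt("patient", "beginner", "Plutus validators")
#   for chunk in tutor.generate(prompt, max_new_tokens=200):
#       print(chunk)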

class SummaryModel:
    """Streams a structured summary of a completed teaching session."""

    def __init__(self):
        self.tokenizer, self.model, self.device = SharedLLM.load()

    def summarize_text(
        self,
        full_teaching: str,
        topic: str,
        level: str,
        recommended: List[Dict[str, Any]],
        max_new_tokens: int = 400,
    ) -> Generator[str, None, None]:
        prompt = (
            "You are a world-class summarization assistant.\n\n"
            f"TOPIC: {topic}\n"
            f"LEVEL: {level}\n\n"
            "CONTENT:\n"
            f"{full_teaching}\n\n"
            "Produce a clear, structured summary.\n\n"
            "Assistant:"
        )
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)

        def _run():
            with torch.inference_mode():
                self.model.generate(
                    **inputs,
                    streamer=streamer,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=0.6,
                    top_p=0.9,
                    eos_token_id=self.tokenizer.eos_token_id,
                    # Fall back to EOS when the tokenizer defines no pad token.
                    pad_token_id=(
                        self.tokenizer.pad_token_id
                        if self.tokenizer.pad_token_id is not None
                        else self.tokenizer.eos_token_id
                    ),
                )

        Thread(target=_run, daemon=True).start()

        buffer = ""
        for token in streamer:
            buffer += token
            # Same fence-parity guard as PlutusModel.generate: never flush while
            # an odd number of ``` fences is pending.
            in_code_block = buffer.count(_CODE_FENCE) % 2 == 1
            if not in_code_block and should_flush(buffer):
                yield buffer.strip()
                buffer = ""

        if buffer.strip():
            yield buffer.strip()

        if recommended:
            yield "\n\n### Recommended Resources\n"
            for item in recommended:
                # .get() guards against resources missing a "type" or "url" key.
                line = f"- **{item.get('type', 'resource').upper()}**: {item.get('url')}"
                yield line
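
# Minimal end-to-end sketch (assumed usage, not part of the original module): run
# the tutor, collect its chunks, then stream a summary. The personality, level,
# topic, and resource URL below are placeholders.
if __name__ == "__main__":
    tutor = PlutusModel()
    prompt = tutor.create_prompt(
        personality="patient",
        level="beginner",
        topic="Plutus validator scripts",
    )
    lesson_chunks = []
    for chunk in tutor.generate(prompt, max_new_tokens=300):
        lesson_chunks.append(chunk)
        print(chunk, flush=True)

    summarizer = SummaryModel()
    resources = [{"type": "docs", "url": "https://example.com/plutus-intro"}]
    for chunk in summarizer.summarize_text(
        "\n\n".join(lesson_chunks),
        topic="Plutus validator scripts",
        level="beginner",
        recommended=resources,
    ):
        print(chunk, flush=True)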