import logging
from threading import Lock, Thread
from typing import Any, Dict, Generator, List, Optional

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)

# TF32 matmuls trade a little precision for a large speedup on Ampere+ GPUs.
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

logger = logging.getLogger("plutus.model")
logging.basicConfig(level=logging.INFO)

MODEL_NAME = "Remostart/Plutus_Advanced_model"


class SharedLLM:
    """Loads the tokenizer/model pair once and shares it across all callers."""

    _tokenizer = None
    _model = None
    _device = "cuda" if torch.cuda.is_available() else "cpu"
    _lock = Lock()

    @classmethod
    def load(cls):
        # Fast path: already loaded.
        if cls._model is not None:
            return cls._tokenizer, cls._model, cls._device

        # Serialize loading so concurrent first callers initialize only once.
        with cls._lock:
            if cls._model is not None:
                return cls._tokenizer, cls._model, cls._device

            logger.info(f"[LOAD] Loading tokenizer: {MODEL_NAME}")
            cls._tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

            # Many causal LMs ship without a pad token; fall back to EOS so
            # generate() can pad without warnings.
            if cls._tokenizer.pad_token_id is None:
                cls._tokenizer.pad_token = cls._tokenizer.eos_token

            logger.info(f"[LOAD] Loading model on {cls._device}")
            cls._model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                torch_dtype=torch.float16 if cls._device == "cuda" else None,
                low_cpu_mem_usage=True,
            )
            cls._model.to(cls._device)
            cls._model.eval()

            logger.info("[READY] Shared LLM loaded once and ready.")
            return cls._tokenizer, cls._model, cls._device


class PlutusModel:
    """Tutoring wrapper around the shared LLM."""

    def __init__(self):
        self.tokenizer, self.model, self.device = SharedLLM.load()

    def create_prompt(
        self,
        personality: str,
        level: str,
        topic: str,
        extra_context: Optional[str] = None,
    ) -> str:
        prompt = (
            "You are PlutusTutor — the best expert in Cardano's Plutus smart contract ecosystem.\n\n"
            "User Info:\n"
            f"- Personality: {personality}\n"
            f"- Level: {level}\n"
            f"- Topic: {topic}\n\n"
            "Your task:\n"
            "- Teach with extreme clarity.\n"
            "- Give structured explanations.\n"
            "- Include examples and code when needed.\n"
            "- Avoid useless filler.\n"
            "- Adapt tone slightly to the user's personality.\n\n"
        )

        if extra_context:
            prompt += f"Additional Context:\n{extra_context}\n\n"

        prompt += "Begin teaching now.\n\nAssistant:"
        return prompt

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
    ) -> str:
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

            with torch.inference_mode():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    eos_token_id=self.tokenizer.eos_token_id,
                    pad_token_id=self.tokenizer.pad_token_id,
                )

            # Drop the prompt tokens; return only the newly generated text.
            new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
            return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        except Exception as e:
            logger.exception("Generation failed")
            return f"[Generation Error] {e}"

    def stream_generate(
        self,
        prompt: str,
        max_new_tokens: int = 300,
        temperature: float = 0.5,
        top_p: float = 0.85,
    ) -> Generator[str, None, None]:
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

            streamer = TextIteratorStreamer(
                self.tokenizer,
                skip_prompt=True,
                skip_special_tokens=True,
            )

            # generate() blocks, so run it in a worker thread and let the
            # streamer hand decoded text chunks back to this generator.
            def _run():
                with torch.inference_mode():
                    self.model.generate(
                        **inputs,
                        streamer=streamer,
                        max_new_tokens=max_new_tokens,
                        do_sample=True,
                        temperature=temperature,
                        top_p=top_p,
                        eos_token_id=self.tokenizer.eos_token_id,
                        pad_token_id=self.tokenizer.pad_token_id,
                    )

            Thread(target=_run, daemon=True).start()

            for chunk in streamer:
                yield chunk

        except Exception as e:
            logger.exception("Streaming failed")
            yield f"[Streaming Error] {e}"

    def summarize_recommendations(
        self,
        topic: str,
        items: List[Dict[str, Any]],
        personality: Optional[str] = None,
        level: Optional[str] = None,
        max_new_tokens: int = 120,
    ) -> str:
        # One bullet per resource: "- TYPE: title (url)".
        refs = "\n".join(
            f"- {item['type'].upper()}: {item.get('title') or item.get('url')} ({item['url']})"
            for item in items
        )

        prompt = (
            f"The user is learning: {topic}\n\n"
            "Here are recommended resources:\n\n"
            f"{refs}\n\n"
            "Explain clearly why these are perfect for the user.\n"
            f"Personality: {personality}\n"
            f"Skill Level: {level}\n\nAssistant:"
        )

        return self.generate(prompt, max_new_tokens=max_new_tokens)


class SummaryModel:
    """Summarizes a full teaching session using the same shared LLM."""

    def __init__(self):
        self.tokenizer, self.model, self.device = SharedLLM.load()

    def summarize_text(
        self,
        full_teaching: str,
        topic: str,
        level: str,
        recommended: List[Dict[str, Any]],
        max_new_tokens: int = 350,
    ) -> str:
        refs = "\n".join(
            f"- {item['type'].upper()}: {item.get('title') or item.get('url')} ({item['url']})"
            for item in recommended
        ) if recommended else "None"

        prompt = (
            "You are a world-class summarization assistant.\n\n"
            f"TOPIC: {topic}\n"
            f"LEVEL: {level}\n\n"
            "CONTENT:\n"
            f"{full_teaching}\n\n"
            "Produce a clear, structured summary.\n"
            "Then recommend these resources:\n\n"
            f"{refs}\n\nAssistant:"
        )

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.inference_mode():
            out = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.6,
                top_p=0.85,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
            )

        # Decode only the tokens generated after the prompt.
        new_tokens = out[0][inputs["input_ids"].shape[-1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
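

# ---------------------------------------------------------------------------
# Minimal usage sketch. Assumptions: the MODEL_NAME checkpoint is downloadable
# and fits on the local device, and the personality/level/topic values below
# are illustrative only, not part of the module's API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tutor = PlutusModel()

    demo_prompt = tutor.create_prompt(
        personality="curious",
        level="beginner",
        topic="Plutus validator scripts",
    )

    # Stream the response chunk by chunk, as a chat UI would.
    for chunk in tutor.stream_generate(demo_prompt, max_new_tokens=64):
        print(chunk, end="", flush=True)
    print()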