# chat_model.py — part of the "ilograph-model" Hugging Face Space.
# (Provenance from the HF file viewer: author HeyChriss, commit "model", rev d0deb09.)
from typing import Dict, Generator, List, Optional
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from system_prompt import build_default_system_prompt
class IlographChatModel:
    """
    Thin OOP wrapper around the Qwen2.5-Coder Ilograph model.

    Responsibility:
        - Load tokenizer/model once at startup
        - Expose a simple streaming text interface for chat-style messages
    """

    def __init__(
        self,
        model_id: str = "Brigham-Young-University/Qwen2.5-Coder-3B-Ilograph-Instruct",
        device_map=None,
        dtype=None,
    ):
        """
        Load the tokenizer and model once.

        Args:
            model_id: HF hub id of the fine-tuned model.
            device_map: Placement passed to ``from_pretrained``. Defaults to
                ``"auto"`` on GPU and ``{"": "cpu"}`` on CPU-only hosts.
            dtype: Torch dtype for weights. Defaults to ``bfloat16`` on GPU
                (speed) and ``float32`` on CPU (correctness).
        """
        self.model_id = model_id
        # Choose sensible defaults based on available hardware.
        if device_map is None or dtype is None:
            if torch.cuda.is_available():
                # On GPU (e.g. HF Space with GPU), let transformers/accelerate
                # decide how to place weights and use bfloat16 for speed.
                if device_map is None:
                    device_map = "auto"
                if dtype is None:
                    dtype = torch.bfloat16
            else:
                # On CPU-only (local machine or CPU Space), force everything
                # onto CPU with full precision for correctness.
                if device_map is None:
                    device_map = {"": "cpu"}
                if dtype is None:
                    dtype = torch.float32
        self.device_map = device_map
        self.dtype = dtype
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            torch_dtype=self.dtype,
            device_map=self.device_map,
        )
        if self.tokenizer.pad_token_id is None:
            # Many causal LMs do not define a pad token, but generate() expects one
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Cache the default system prompt so we only load the schema once
        self._default_system_prompt = build_default_system_prompt()

    @property
    def default_system_prompt(self) -> str:
        """Default system prompt built from the Ilograph schema (cached)."""
        return self._default_system_prompt

    def build_messages(
        self,
        system_prompt: Optional[str],
        history: Optional[List[Dict[str, str]]],
        user_message: str,
    ) -> List[Dict[str, str]]:
        """
        Assemble the chat-template message list.

        Args:
            system_prompt: Override for the system prompt; empty/None falls
                back to :attr:`default_system_prompt`.
            history: Prior turns as ``{"role", "content"}`` dicts (Gradio's
                native format), or None/empty for a fresh conversation.
            user_message: The new user turn appended last.

        Returns:
            ``[system, *history, user]`` message dicts.
        """
        messages: List[Dict[str, str]] = []
        system = system_prompt.strip() if system_prompt else self.default_system_prompt
        messages.append({"role": "system", "content": system})
        # Gradio already provides history as {role, content} dicts
        if history:
            messages.extend(history)
        messages.append({"role": "user", "content": user_message})
        return messages

    def generate_stream(
        self,
        messages: List[Dict[str, str]],
        max_tokens: int,
        temperature: float,
        top_p: float,
    ) -> Generator[str, None, None]:
        """
        Synchronous "streaming" generator for Gradio.

        For simplicity we generate the full response once and then yield it
        so the UI can update.

        Args:
            messages: Full chat message list (see :meth:`build_messages`).
            max_tokens: Maximum number of NEW tokens to generate.
            temperature: Sampling temperature; values <= 0 switch to greedy
                decoding (sampling with temperature 0 raises in generate()).
            top_p: Nucleus-sampling cutoff, only used when sampling.

        Yields:
            The full response wrapped in a Markdown ```yaml code block.
        """
        # Qwen's apply_chat_template can return either a tensor or a BatchEncoding.
        encoded = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
        )
        # Normalise to a plain tensor so generate() receives the right type.
        if hasattr(encoded, "input_ids"):
            input_ids = encoded["input_ids"]
        else:
            input_ids = encoded
        input_ids = input_ids.to(self.model.device)
        # BUG FIX: the UI's max_tokens/temperature/top_p were previously
        # ignored in favour of hard-coded values (256 / 0.5 / none).
        do_sample = temperature is not None and temperature > 0
        gen_kwargs = dict(
            input_ids=input_ids,
            # Single unpadded sequence: an all-ones mask silences the
            # transformers attention-mask warning and is always correct here.
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=max_tokens,
            do_sample=do_sample,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        if do_sample:
            # Sampling knobs are only valid when do_sample=True.
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = top_p
        with torch.no_grad():
            output_ids = self.model.generate(**gen_kwargs)
        # Only keep newly generated tokens
        generated_ids = output_ids[0, input_ids.shape[-1] :]
        full_text = self.tokenizer.decode(
            generated_ids,
            skip_special_tokens=True,
        )
        # Wrap in a Markdown code block so the chat UI preserves
        # spaces and indentation (critical for YAML / IDL output).
        formatted = "```yaml\n" + full_text.strip("\n") + "\n```"
        yield formatted