from typing import Dict, Generator, List, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from system_prompt import build_default_system_prompt


class IlographChatModel(object):
    """
    Thin OOP wrapper around the Qwen3 Ilograph model.

    Responsibility:
    - Load tokenizer/model once at startup
    - Expose a simple streaming text interface for chat-style messages
    """

    def __init__(
        self,
        model_id: str = "Brigham-Young-University/Qwen2.5-Coder-3B-Ilograph-Instruct",
        device_map=None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Load tokenizer and model once; pick hardware-appropriate defaults.

        Args:
            model_id: Hugging Face model identifier to load.
            device_map: transformers/accelerate device placement. Defaults to
                "auto" on GPU, or an explicit CPU map otherwise.
            dtype: model weight dtype. Defaults to bfloat16 on GPU,
                float32 on CPU.
        """
        self.model_id = model_id

        # Choose sensible defaults based on available hardware.
        if device_map is None or dtype is None:
            if torch.cuda.is_available():
                # On GPU (e.g. HF Space with GPU), let transformers/accelerate
                # decide how to place weights and use bfloat16 for speed.
                if device_map is None:
                    device_map = "auto"
                if dtype is None:
                    dtype = torch.bfloat16
            else:
                # On CPU-only (local machine or CPU Space), force everything
                # onto CPU with full precision for correctness.
                if device_map is None:
                    device_map = {"": "cpu"}
                if dtype is None:
                    dtype = torch.float32

        self.device_map = device_map
        self.dtype = dtype

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            torch_dtype=self.dtype,
            device_map=self.device_map,
        )

        if self.tokenizer.pad_token_id is None:
            # Many causal LMs do not define a pad token, but generate() expects one
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Cache the default system prompt so we only load the schema once
        self._default_system_prompt = build_default_system_prompt()

    @property
    def default_system_prompt(self) -> str:
        """The cached default system prompt built at construction time."""
        return self._default_system_prompt

    def build_messages(
        self,
        system_prompt: Optional[str],
        history: Optional[List[Dict[str, str]]],
        user_message: str,
    ) -> List[Dict[str, str]]:
        """
        Assemble a chat-template message list: system, history, then the new turn.

        Args:
            system_prompt: Optional override; falls back to the cached default
                when empty/whitespace-only or None.
            history: Prior turns as {role, content} dicts (Gradio's native
                format), or None/empty for a fresh conversation.
            user_message: The latest user input.

        Returns:
            Messages ready for ``tokenizer.apply_chat_template``.
        """
        messages: List[Dict[str, str]] = []

        system = system_prompt.strip() if system_prompt else self.default_system_prompt
        messages.append({"role": "system", "content": system})

        # Gradio already provides history as {role, content} dicts
        if history:
            messages.extend(history)

        messages.append({"role": "user", "content": user_message})
        return messages

    def generate_stream(
        self,
        messages: List[Dict[str, str]],
        max_tokens: int,
        temperature: float,
        top_p: float,
    ) -> Generator[str, None, None]:
        """
        Synchronous "streaming" generator for Gradio.

        For simplicity we generate the full response in one call and yield a
        single formatted chunk (true token-level streaming would require a
        TextIteratorStreamer).

        Args:
            messages: Chat messages (e.g. from ``build_messages``).
            max_tokens: Maximum number of new tokens to generate.
            temperature: Sampling temperature; values <= 0 fall back to
                greedy decoding.
            top_p: Nucleus-sampling probability mass (used only when sampling).

        Yields:
            The generated text wrapped in a Markdown ```yaml code block.
        """
        # Qwen's apply_chat_template can return either a tensor or a BatchEncoding.
        encoded = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
        )
        # Normalise to a plain tensor so generate() receives the right type.
        if hasattr(encoded, "input_ids"):
            input_ids = encoded["input_ids"]
        else:
            input_ids = encoded
        input_ids = input_ids.to(self.model.device)

        # BUG FIX: the original hard-coded max_new_tokens=256 and
        # temperature=0.5 and dropped top_p, silently ignoring the caller's
        # settings. Honour the parameters, and fall back to greedy decoding
        # when temperature is non-positive (HF sampling requires temperature > 0).
        do_sample = temperature > 0
        gen_kwargs = dict(
            input_ids=input_ids,
            max_new_tokens=max_tokens,
            do_sample=do_sample,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        if do_sample:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = top_p

        with torch.no_grad():
            output_ids = self.model.generate(**gen_kwargs)

        # Only keep newly generated tokens
        generated_ids = output_ids[0, input_ids.shape[-1] :]
        full_text = self.tokenizer.decode(
            generated_ids,
            skip_special_tokens=True,
        )

        # Wrap in a Markdown code block so the chat UI preserves
        # spaces and indentation (critical for YAML / IDL output).
        formatted = "```yaml\n" + full_text.strip("\n") + "\n```"
        yield formatted