File size: 4,365 Bytes
d0deb09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from typing import Dict, Generator, List, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from system_prompt import build_default_system_prompt


class IlographChatModel(object):
    """
    Thin OOP wrapper around the Qwen2.5-Coder Ilograph model.

    Responsibility:
    - Load tokenizer/model once at startup
    - Expose a simple streaming text interface for chat-style messages
    """

    def __init__(
        self,
        model_id: str = "Brigham-Young-University/Qwen2.5-Coder-3B-Ilograph-Instruct",
        device_map=None,
        dtype=None,
    ):
        """
        Load the tokenizer and model for *model_id*.

        Args:
            model_id: Hugging Face repo id of the causal LM to load.
            device_map: Passed to ``from_pretrained``; ``None`` selects a
                hardware-appropriate default (``"auto"`` on GPU, pinned CPU
                otherwise).
            dtype: Torch dtype for the weights; ``None`` selects bfloat16 on
                GPU and float32 on CPU.
        """
        self.model_id = model_id

        # Choose sensible defaults based on available hardware.
        if device_map is None or dtype is None:
            if torch.cuda.is_available():
                # On GPU (e.g. HF Space with GPU), let transformers/accelerate
                # decide how to place weights and use bfloat16 for speed.
                if device_map is None:
                    device_map = "auto"
                if dtype is None:
                    dtype = torch.bfloat16
            else:
                # On CPU-only (local machine or CPU Space), force everything
                # onto CPU with full precision for correctness.
                if device_map is None:
                    device_map = {"": "cpu"}
                if dtype is None:
                    dtype = torch.float32

        self.device_map = device_map
        self.dtype = dtype

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            torch_dtype=self.dtype,
            device_map=self.device_map,
        )

        if self.tokenizer.pad_token_id is None:
            # Many causal LMs do not define a pad token, but generate() expects one
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Cache the default system prompt so we only load the schema once
        self._default_system_prompt = build_default_system_prompt()

    @property
    def default_system_prompt(self) -> str:
        """The cached system prompt built from the Ilograph schema."""
        return self._default_system_prompt

    def build_messages(
        self,
        system_prompt: Optional[str],
        history: Optional[List[Dict[str, str]]],
        user_message: str,
    ) -> List[Dict[str, str]]:
        """
        Assemble a chat-template message list.

        Args:
            system_prompt: Override for the system prompt; falsy values fall
                back to :attr:`default_system_prompt`. Whitespace is stripped.
            history: Prior turns as ``{"role", "content"}`` dicts (Gradio's
                native format), or ``None``/empty for a fresh conversation.
            user_message: The new user turn, appended last.

        Returns:
            ``[system, *history, user]`` message dicts ready for
            ``apply_chat_template``.
        """
        messages: List[Dict[str, str]] = []

        system = system_prompt.strip() if system_prompt else self.default_system_prompt
        messages.append({"role": "system", "content": system})

        # Gradio already provides history as {role, content} dicts
        if history:
            messages.extend(history)

        messages.append({"role": "user", "content": user_message})

        return messages

    def generate_stream(
        self,
        messages: List[Dict[str, str]],
        max_tokens: int,
        temperature: float,
        top_p: float,
    ) -> Generator[str, None, None]:
        """
        Synchronous "streaming" generator for Gradio.

        For simplicity we generate the full response once and then
        yield it (currently as a single chunk) so the UI can update.

        Args:
            max_tokens: Maximum number of NEW tokens to generate.
            temperature: Sampling temperature; values <= 0 switch to
                greedy decoding.
            top_p: Nucleus-sampling cutoff (used only when sampling).

        Yields:
            The generated text wrapped in a Markdown ```yaml code block.
        """
        # Qwen's apply_chat_template can return either a tensor or a BatchEncoding.
        encoded = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
        )

        # Normalise to a plain tensor so generate() receives the right type.
        if hasattr(encoded, "input_ids"):
            input_ids = encoded["input_ids"]
        else:
            input_ids = encoded

        input_ids = input_ids.to(self.model.device)

        # BUG FIX: the caller-supplied generation parameters were previously
        # ignored (max_new_tokens hardcoded to 256, temperature to 0.5, and
        # top_p never passed). Thread them through, with a greedy fallback
        # when temperature is non-positive (sampling would be invalid).
        do_sample = temperature is not None and temperature > 0

        with torch.no_grad():
            output_ids = self.model.generate(
                input_ids=input_ids,
                max_new_tokens=max_tokens,
                temperature=temperature if do_sample else None,
                top_p=top_p if do_sample else None,
                do_sample=do_sample,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
            )

        # Only keep newly generated tokens
        generated_ids = output_ids[0, input_ids.shape[-1] :]
        full_text = self.tokenizer.decode(
            generated_ids,
            skip_special_tokens=True,
        )

        # Wrap in a Markdown code block so the chat UI preserves
        # spaces and indentation (critical for YAML / IDL output).
        formatted = "```yaml\n" + full_text.strip("\n") + "\n```"
        yield formatted