File size: 8,522 Bytes
541346a
 
 
578513b
 
 
 
 
 
 
 
 
 
 
541346a
 
5f01640
541346a
 
 
578513b
541346a
5f01640
541346a
 
 
578513b
541346a
 
51d5f26
 
 
 
 
 
541346a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
578513b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
541346a
578513b
541346a
51d5f26
 
 
578513b
541346a
 
51d5f26
541346a
51d5f26
 
 
5f01640
578513b
51d5f26
578513b
541346a
ac7d3aa
541346a
 
 
 
51d5f26
578513b
 
51d5f26
578513b
 
51d5f26
578513b
51d5f26
578513b
 
51d5f26
 
578513b
 
51d5f26
541346a
 
 
 
 
 
 
 
51d5f26
541346a
578513b
51d5f26
541346a
 
 
 
51d5f26
541346a
578513b
51d5f26
 
541346a
 
 
 
51d5f26
541346a
 
 
 
 
 
 
578513b
541346a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
578513b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
"""
HuggingFace Inference Endpoint custom handler for SinLlama-MCQ.

Model stack: Llama base → SinLlama LoRA (merged) → MCQ LoRA → inference

Environment variables (set in the HF Endpoint dashboard):
  SINLLAMA_REPO   SinLlama HF repo ID        (default: polyglots/SinLlama_v01)
  BASE_MODEL      Llama base HF repo ID       (auto-read from SinLlama config if unset)
  HF_TOKEN        HF token for gated/private repos
  TEMPERATURE     Generation temperature      (default: 0.7)
  MAX_NEW_TOKENS  Max tokens to generate      (default: 300)

Request:   {"inputs": "<Sinhala passage>"}
Response:  [{"generated_text": "...", "mcq": "...", "valid": true}]
"""

import json
import os
import re
import logging
import traceback
import torch
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig

# Module logger; the root logger is configured for INFO so step messages show.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Use print() in addition to logger β€” HF toolkit may suppress the logger
# before the handler module is imported, but print() always reaches stdout.
def _log(msg: str) -> None:
    print(msg, flush=True)
    logger.info(msg)

# Prompt sent to the model for every request; ``{passage}`` is filled in with
# str.format by EndpointHandler.__call__, and the decoded output after the
# trailing "MCQ:" cue is what _normalize_mcq parses.
# Roughly (approx. translation): "Read the following history passage and
# create a multiple-choice question about it. / Passage: {passage} / MCQ:".
# NOTE(review): the Sinhala text in this copy appears mojibake-encoded
# (UTF-8 bytes shown through a legacy codepage) — confirm the file on disk is
# real UTF-8 Sinhala matching the fine-tuning prompt before editing it.
_PROMPT_TEMPLATE = (
    "ΰΆ΄ΰ·„ΰΆ­ ࢉࢭිහාස ࢑ේࢯࢺ ΰΆšΰ·’ΰΆΊΰ·€ΰ·, ΰΆ’ ࢜ැࢱ ΰΆΆΰ·„ΰ·”-ΰ·€ΰ·’ΰΆšΰΆ½ΰ·ŠΰΆ΄ ΰΆ΄ΰ·Šβ€ΰΆ»ΰ·ΰ·ŠΰΆ±ΰΆΊΰΆšΰ·Š ΰ·ƒΰ·ΰΆ―ΰΆ±ΰ·ŠΰΆ±.\n\n"
    "࢑ේࢯࢺ: {passage}\n\n"
    "MCQ:"
)


def _normalize_mcq(text: str) -> str | None:
    text = re.sub(r"\b([ABCD])[.:]", r"\1)", text)
    for tag in ["A)", "B)", "C)", "D)", "ࢱිවැࢻࢯි ΰΆ΄ΰ·’ΰ·…ΰ·’ΰΆ­ΰ·”ΰΆ»:"]:
        text = re.sub(rf"(?<!\n)({re.escape(tag)})", r"\n\1", text)
    text = text.strip()
    if "ΰΆ΄ΰ·Šβ€ΰΆ»ΰ·ΰ·ŠΰΆ±ΰΆΊ:" not in text:
        lines = text.splitlines()
        if lines:
            text = "ΰΆ΄ΰ·Šβ€ΰΆ»ΰ·ΰ·ŠΰΆ±ΰΆΊ: " + lines[0].strip() + "\n" + "\n".join(lines[1:])
    if not all(f"{l})" in text for l in "ABCD"):
        return None
    if not re.search(r"ࢱිවැࢻࢯි ΰΆ΄ΰ·’ΰ·…ΰ·’ΰΆ­ΰ·”ΰΆ»:\s*[ABCD]", text):
        return None
    return text.strip()


def _patch_tokenizer_config(path: str) -> None:
    """Remove unknown tokenizer_class entries that crash AutoTokenizer."""
    tc_path = Path(path) / "tokenizer_config.json"
    if not tc_path.exists():
        return
    with open(tc_path, encoding="utf-8") as f:
        tc = json.load(f)
    known = {None, "LlamaTokenizer", "LlamaTokenizerFast",
             "PreTrainedTokenizer", "PreTrainedTokenizerFast"}
    if tc.get("tokenizer_class") not in known:
        logger.info("Patching tokenizer_config.json: dropping tokenizer_class=%r",
                    tc.pop("tokenizer_class"))
        with open(tc_path, "w", encoding="utf-8") as f:
            json.dump(tc, f, ensure_ascii=False, indent=2)


def _read_adapter_config(path: str, hf_token: str | None) -> "PeftConfig":
    """Load the PEFT adapter config at *path* and log its key fields.

    Only ``base_model_name_or_path`` and ``peft_type`` are logged. These are
    the fields needed to debug base-model / adapter mismatches.

    Args:
        path: Local directory or HF Hub repo ID holding the adapter config.
        hf_token: Token for gated/private repos, or ``None``.

    Returns:
        The loaded ``PeftConfig`` object, not a plain ``dict``. Callers read
        attributes such as ``base_model_name_or_path`` from it.

    Raises:
        Exception: anything ``PeftConfig.from_pretrained`` raises is logged
            together with its traceback and re-raised unchanged.
    """
    try:
        cfg = PeftConfig.from_pretrained(path, token=hf_token)
    except Exception:
        # logger.exception attaches the active traceback at ERROR level.
        logger.exception("Failed to read adapter config from %s", path)
        raise
    logger.info("adapter_config from %s: base_model=%r, peft_type=%r",
                path, cfg.base_model_name_or_path, cfg.peft_type)
    return cfg


class EndpointHandler:
    """HF Inference Endpoint handler that generates an MCQ from a passage.

    ``__init__`` builds the model stack once; ``__call__`` serves a single
    request and always returns a one-element list.
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer, Llama base, merged SinLlama and the MCQ adapter.

        Args:
            path: Directory of the deployed repository. It is the source of
                both the tokenizer and the MCQ LoRA adapter.

        Settings come from the environment variables listed in the module
        docstring (SINLLAMA_REPO, BASE_MODEL, HF_TOKEN, TEMPERATURE,
        MAX_NEW_TOKENS).
        """
        _log("=" * 60)
        _log(f"EndpointHandler.__init__  path={path!r}")
        _log("=" * 60)

        sinllama_repo = os.environ.get("SINLLAMA_REPO", "polyglots/SinLlama_v01")
        hf_token      = os.environ.get("HF_TOKEN")
        _log(f"SINLLAMA_REPO={sinllama_repo!r}  HF_TOKEN={'set' if hf_token else 'NOT SET'}")

        # List repo files so we can verify what was actually deployed
        repo_files = sorted(p.name for p in Path(path).iterdir()) if path else []
        _log(f"Repo files: {repo_files}")

        # ── 1. Tokenizer ──────────────────────────────────────────────────────
        _log("STEP 1: patch tokenizer_config.json + load tokenizer")
        _patch_tokenizer_config(path)
        self.tokenizer = AutoTokenizer.from_pretrained(
            path, use_fast=False, token=hf_token
        )
        if self.tokenizer.pad_token is None:
            # generate() needs a pad id; reuse EOS when none is defined.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        sinllama_vocab = len(self.tokenizer)
        _log(f"STEP 1 OK β€” vocab size {sinllama_vocab}")

        # ── 2. Resolve Llama base model ID ────────────────────────────────────
        _log("STEP 2: resolve base model ID")
        base_model_id = os.environ.get("BASE_MODEL")
        if base_model_id:
            _log(f"BASE_MODEL from env: {base_model_id!r}")
        else:
            _log("BASE_MODEL not set β€” reading from SinLlama adapter config")
            sinllama_cfg  = _read_adapter_config(sinllama_repo, hf_token)
            base_model_id = sinllama_cfg.base_model_name_or_path
            _log(f"Base model from SinLlama config: {base_model_id!r}")
        _log(f"STEP 2 OK β€” base_model_id={base_model_id!r}")

        # ── 3. Load Llama base ────────────────────────────────────────────────
        _log(f"STEP 3: load base model {base_model_id!r}")
        base = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            attn_implementation="sdpa",
            token=hf_token,
        )
        # Match the embedding matrix to the deployed tokenizer's vocab before
        # applying adapters (presumably SinLlama's extended Sinhala vocab —
        # confirm against the SinLlama repo).
        base.resize_token_embeddings(sinllama_vocab)
        _log(f"STEP 3 OK β€” embeddings resized to {sinllama_vocab}")

        # ── 4. Load + merge SinLlama adapter ─────────────────────────────────
        _log(f"STEP 4: load SinLlama adapter from {sinllama_repo!r}")
        sinllama = PeftModel.from_pretrained(
            base, sinllama_repo, is_trainable=False, token=hf_token
        )
        # Fold SinLlama into the base weights so the MCQ LoRA below is the
        # only adapter left on the model.
        sinllama = sinllama.merge_and_unload()
        _log("STEP 4 OK β€” SinLlama merged")

        # ── 5. Load MCQ adapter ───────────────────────────────────────────────
        _log(f"STEP 5: load MCQ adapter from {path!r}")
        _read_adapter_config(path, hf_token)   # logs base_model_name_or_path
        self.model = PeftModel.from_pretrained(
            sinllama, path, is_trainable=False, token=hf_token
        )
        self.model.eval()
        _log("STEP 5 OK β€” EndpointHandler ready")

        self._default_temperature    = float(os.environ.get("TEMPERATURE", 0.7))
        self._default_max_new_tokens = int(os.environ.get("MAX_NEW_TOKENS", 300))

    def __call__(self, data: dict) -> list[dict]:
        """Generate one MCQ for the passage carried by *data*.

        Args:
            data: Request payload. The passage is read from ``"inputs"``,
                falling back to ``"passage"``. An optional ``"parameters"``
                object may override ``temperature`` and ``max_new_tokens``.

        Returns:
            A one-element list. On success the element holds
            ``generated_text`` (the normalised MCQ, or the raw text when
            normalisation fails), ``mcq`` (normalised text or ``None``) and
            ``valid``. On any failure it holds a single ``error`` message.
        """
        passage = data.get("inputs") or data.get("passage") or ""
        # Reject non-string payloads (e.g. batched lists) instead of crashing
        # on .strip().
        if not isinstance(passage, str):
            return [{"error": "\"inputs\" must be a string containing the passage"}]
        passage = passage.strip()
        if not passage:
            return [{"error": "No passage provided. Send {\"inputs\": \"<passage>\"}"}]

        params = data.get("parameters") or {}
        if not isinstance(params, dict):
            return [{"error": "\"parameters\" must be an object"}]
        try:
            temperature    = float(params.get("temperature",    self._default_temperature))
            max_new_tokens = int(params.get("max_new_tokens",   self._default_max_new_tokens))
        except (TypeError, ValueError) as e:
            # Bad client input: report it rather than failing the request with a 500.
            return [{"error": f"Invalid parameters: {e}"}]

        prompt = _PROMPT_TEMPLATE.format(passage=passage)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        try:
            with torch.no_grad():
                output = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    # temperature <= 0 means greedy decoding.
                    do_sample=temperature > 0,
                    repetition_penalty=1.1,
                    eos_token_id=self.tokenizer.eos_token_id,
                    pad_token_id=self.tokenizer.pad_token_id,
                )
        except Exception as e:
            logger.exception("Generation failed")
            return [{"error": f"Generation error: {e}"}]

        # Decode only the newly generated tokens, not the echoed prompt.
        new_ids = output[0][inputs.input_ids.shape[1]:]
        raw     = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip()
        mcq     = _normalize_mcq(raw)

        return [{
            "generated_text": mcq if mcq is not None else raw,
            "mcq":            mcq,
            "valid":          mcq is not None,
        }]