File size: 2,519 Bytes
45d4c43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""Custom handler for HF Inference Endpoints.

Loads Qwen2.5-0.5B base model, applies the LoRA adapter from this repo,
merges weights for faster inference, and serves predictions.
"""

from typing import Any, Dict, List, Union
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel


class EndpointHandler:
    """Text-generation handler for a LoRA fine-tune of Qwen2.5-0.5B.

    Construction loads the base model, applies the LoRA adapter found at
    ``path`` and merges it into the base weights, then loads the base
    model's tokenizer. Calling the instance generates one completion for the
    request payload.
    """

    def __init__(self, path: str = ""):
        """Load the base model, merge the adapter at *path*, load the tokenizer.

        Args:
            path: Location of the LoRA adapter weights (the repository this
                handler ships in).
        """
        base_model_id = "Qwen/Qwen2.5-0.5B"

        # Half precision is only worthwhile (and reliably supported) on GPU;
        # on CPU many torch builds lack fp16 kernels or run them very slowly.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        # Load base model
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=dtype,
            device_map="auto",
            trust_remote_code=True,
        )

        # Apply the LoRA adapter from this repo, then fold it into the base
        # weights so inference runs on a plain model without adapter overhead.
        model = PeftModel.from_pretrained(base_model, path)
        self.model = model.merge_and_unload()  # merged causal LM used by __call__
        self.model.eval()

        # Tokenizer comes from the base model, not the adapter repo.
        self.tokenizer = AutoTokenizer.from_pretrained(
            base_model_id, trust_remote_code=True
        )
        if self.tokenizer.pad_token is None:
            # generate() needs a pad id; reuse EOS when none is defined.
            self.tokenizer.pad_token = self.tokenizer.eos_token

    @staticmethod
    def _generation_kwargs(params: Union[Dict[str, Any], None]) -> Dict[str, Any]:
        """Translate the request ``parameters`` mapping into ``generate`` kwargs.

        Missing keys, and keys explicitly set to ``None``, fall back to the
        defaults ``max_new_tokens=256``, ``temperature=0.7``, ``top_p=0.9``.
        A ``temperature <= 0`` selects greedy decoding.
        """
        params = params or {}  # tolerate an absent or explicit-null "parameters"

        def _get(key: str, default: Any) -> Any:
            value = params.get(key)
            return default if value is None else value

        max_new_tokens = int(_get("max_new_tokens", 256))
        temperature = float(_get("temperature", 0.7))
        top_p = float(_get("top_p", 0.9))

        kwargs: Dict[str, Any] = {"max_new_tokens": max_new_tokens}
        if temperature > 0:
            kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
        else:
            # Greedy decoding: sampling knobs are left out because transformers
            # warns when temperature/top_p are set without do_sample=True.
            kwargs["do_sample"] = False
        return kwargs

    def _build_prompt(self, inputs: Any) -> str:
        """Render the request ``inputs`` into a single prompt string.

        Strings are used verbatim; lists are treated as chat messages and
        rendered with the tokenizer's chat template (with the assistant
        generation prompt appended); anything else is converted with ``str``.
        """
        if isinstance(inputs, str):
            return inputs
        if isinstance(inputs, list):
            # Chat format: [{"role": "user", "content": "..."}]
            return self.tokenizer.apply_chat_template(
                inputs, tokenize=False, add_generation_prompt=True
            )
        return str(inputs)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Generate a completion for the request payload *data*.

        Args:
            data: Request body. ``data["inputs"]`` is a prompt string or a
                list of chat messages; ``data["parameters"]`` optionally holds
                ``max_new_tokens``, ``temperature`` and ``top_p``.

        Returns:
            ``[{"generated_text": text}]`` where *text* contains only the newly
            generated tokens (the prompt is not echoed back).

        Raises:
            ValueError: If the prompt tokenizes to zero tokens.
        """
        gen_kwargs = self._generation_kwargs(data.get("parameters"))
        prompt = self._build_prompt(data.get("inputs", ""))

        tokenized = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        prompt_len = tokenized["input_ids"].shape[1]
        if prompt_len == 0:
            # generate() cannot start from an empty sequence; fail clearly.
            raise ValueError("'inputs' must produce at least one token")

        with torch.no_grad():
            output_ids = self.model.generate(
                **tokenized,
                pad_token_id=self.tokenizer.pad_token_id,
                **gen_kwargs,
            )

        # Decode only the generated tokens (skip the prompt)
        new_tokens = output_ids[0][prompt_len:]
        generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

        return [{"generated_text": generated_text}]