File size: 4,204 Bytes
2827a15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""
Custom handler for HuggingFace Inference Endpoints.
This handles loading the LoRA adapter on top of the base model.

To deploy:
1. Push this handler.py to your model repo on HuggingFace
2. Create an Inference Endpoint pointing to jimfhahn/bibframe-olmo-1b-v2
3. The endpoint will automatically use this custom handler
"""

from typing import Dict, Any
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel


class EndpointHandler:
    """Inference Endpoints handler that repairs malformed BIBFRAME RDF/XML.

    The base model is loaded from the Hub and the LoRA adapter found at the
    endpoint's repository path is merged into it once at start-up; each
    request is then answered by ``__call__``.
    """

    # Upper bound on the number of prompt tokens handed to ``generate``.
    _MAX_INPUT_TOKENS = 2048
    # Extra tokens dropped when shortening an over-long input, so that small
    # tokenization differences between the bare input and the input embedded
    # in the chat template do not push the prompt back over the limit.
    _TRUNCATION_SLACK = 8

    def __init__(self, path: str = ""):
        """
        Initialize the model and tokenizer.
        
        Args:
            path: Path to the model directory (provided by Inference Endpoints);
                the LoRA adapter is loaded from here.
        """
        # Load base model
        base_model_id = "amd/AMD-OLMo-1B"
        
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_id)
        
        # Load base model with appropriate settings
        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
        
        # Load LoRA adapter from the endpoint path
        self.model = PeftModel.from_pretrained(
            self.model,
            path,  # This is the model repo path
            torch_dtype=torch.float16,
        )
        
        # Merge adapter for faster inference (optional but recommended)
        self.model = self.model.merge_and_unload()
        self.model.eval()
        
        # Set pad token if not set
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process inference request.
        
        Args:
            data: Request payload with an 'inputs' key containing RDF/XML to
                correct, and an optional 'parameters' dict (max_new_tokens,
                temperature, top_p, do_sample).
            
        Returns:
            Dictionary with 'generated_text' containing corrected RDF/XML and
            'prompt_used' holding the exact prompt sent to the model (after any
            shortening of an over-long input).
        """
        inputs = data.get("inputs", "")
        # An explicit JSON null for "parameters" behaves like "none given".
        parameters = data.get("parameters") or {}

        prompt = self._build_prompt(inputs)
        encoded = self.tokenizer(prompt, return_tensors="pt")

        # Truncating the assembled prompt from the right would chop off the
        # closing <|im_end|> and the assistant cue, so the model would never be
        # asked to answer. Shorten only the user content instead.
        overflow = encoded["input_ids"].shape[1] - self._MAX_INPUT_TOKENS
        if overflow > 0:
            inputs = self._drop_input_tail(inputs, overflow + self._TRUNCATION_SLACK)
            prompt = self._build_prompt(inputs)
            encoded = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,  # safety net; should no longer trigger
                max_length=self._MAX_INPUT_TOKENS,
            )
        encoded = encoded.to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(**encoded, **self._generation_kwargs(parameters))

        # Decode only the newly generated tokens, not the echoed prompt.
        prompt_len = encoded["input_ids"].shape[1]
        generated = self.tokenizer.decode(
            outputs[0][prompt_len:],
            skip_special_tokens=True,
        )

        return {
            "generated_text": self._extract_rdf(generated),
            "prompt_used": prompt,
        }

    def _generation_kwargs(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
        """Build ``generate`` keyword arguments from request parameters.

        Values arrive from JSON, so they are coerced to the types transformers
        expects (e.g. an integer temperature would otherwise be rejected). A
        non-positive temperature selects greedy decoding, because sampling with
        temperature 0 raises in transformers. Sampling-only knobs are omitted
        when not sampling.
        """
        def pick(key: str, default: Any, cast: Any) -> Any:
            # Missing or null values fall back to the handler's default.
            value = parameters.get(key)
            return default if value is None else cast(value)

        temperature = pick("temperature", 0.1, float)
        do_sample = pick("do_sample", True, bool) and temperature > 0

        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": pick("max_new_tokens", 1024, int),
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        if do_sample:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = pick("top_p", 0.95, float)
        return gen_kwargs

    def _drop_input_tail(self, rdf_input: str, n_tokens: int) -> str:
        """Return ``rdf_input`` with roughly its last ``n_tokens`` tokens removed.

        The kept tokens are decoded back to text; whitespace may differ slightly
        from the original around the cut.
        """
        ids = self.tokenizer(rdf_input, add_special_tokens=False)["input_ids"]
        keep = max(len(ids) - n_tokens, 0)
        return self.tokenizer.decode(ids[:keep], skip_special_tokens=True)

    def _build_prompt(self, rdf_input: str) -> str:
        """Build the prompt in ChatML format (matching training data)."""
        return (
            "<|im_start|>system\n"
            "You are a BIBFRAME expert. Fix the following malformed RDF/XML "
            "to produce valid BIBFRAME following Library of Congress conventions.<|im_end|>\n"
            f"<|im_start|>user\n{rdf_input}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )

    def _extract_rdf(self, text: str) -> str:
        """Trim generated text to the RDF/XML document.

        Text is cut at the ChatML turn terminator (it can survive decoding as
        literal text if the tokenizer does not treat it as special), then just
        after the first closing ``</rdf:RDF>`` tag if one is present.
        """
        text = text.split("<|im_end|>", 1)[0]
        head, closing, _ = text.partition("</rdf:RDF>")
        return (head + closing).strip()