File size: 3,894 Bytes
982f6ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60a854a
 
 
 
 
 
 
 
 
 
 
 
 
 
982f6ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6619c0f
982f6ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""
Custom handler for Hugging Face Inference Endpoints
Model: ongilLabs/IB-Math-Instruct-7B
"""

from typing import Dict, List, Any
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


class EndpointHandler:
    """Chat-style text-generation handler for a Hugging Face Inference Endpoint.

    ``__init__`` loads the tokenizer and causal LM once; each call renders the
    request through the tokenizer's chat template, generates a continuation and
    returns only the newly generated text.
    """

    def __init__(self, path: str = ""):
        """Initialize the model and tokenizer.

        Args:
            path: Model location passed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True
        )
        # Inference only: disable dropout and other training-mode behavior.
        self.model.eval()

        # Default system prompt - IB Math AA Tutor style. Used whenever the
        # request does not supply its own system message / system_prompt.
        self.default_system = """You are an expert IB Mathematics AA tutor. Your role is to explain mathematical concepts and solve problems using pure mathematical reasoning, NOT programming code.

CRITICAL RULES:
1. NEVER write Python, SymPy, or any programming code
2. Use ONLY mathematical notation and LaTeX ($...$ for inline, $$...$$ for display)
3. Show step-by-step solutions with clear mathematical reasoning
4. Use IB command terms appropriately (Find, Show, Hence, Prove, etc.)
5. Include common pitfall warnings when relevant
6. End with IB exam tips about marking schemes (M marks, A marks)
7. Write in a teacher-like, encouraging tone
8. Use <think> tags to show your reasoning process before the solution

Your responses should be like a professional IB teacher explaining to students, using mathematical notation and clear explanations."""

    @staticmethod
    def _get_param(parameters: Dict[str, Any], key: str, default: Any) -> Any:
        """Return ``parameters[key]``, or ``default`` when it is missing or null."""
        value = parameters.get(key)
        return default if value is None else value

    @staticmethod
    def _build_messages(inputs: Any, system_prompt: str) -> List[Dict[str, Any]]:
        """Normalize request ``inputs`` into a chat message list.

        A string becomes a two-message conversation (system prompt + user
        turn). A non-empty list of message dicts is used as given, with the
        system prompt prepended unless the first message already has role
        ``"system"``. The caller's list is never mutated.

        Raises:
            ValueError: if ``inputs`` is neither a string nor a non-empty list
                of dicts.
        """
        if isinstance(inputs, str):
            return [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": inputs}
            ]
        if (
            isinstance(inputs, list)
            and inputs
            and all(isinstance(message, dict) for message in inputs)
        ):
            if inputs[0].get("role") == "system":
                return list(inputs)
            return [{"role": "system", "content": system_prompt}] + inputs
        raise ValueError("Invalid input format. Expected string or list of messages.")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Handle inference request.

        Args:
            data: Dictionary with 'inputs' (a string, or a non-empty list of
                chat message dicts) and optional 'parameters' with keys
                'max_new_tokens', 'temperature', 'top_p' and 'system_prompt'.
                A missing or null 'parameters', or a null value for any of
                those keys, falls back to the defaults.

        Returns:
            Dictionary with 'generated_text' (only the newly generated tokens,
            decoded without special tokens), or with 'error' when 'inputs' has
            an unsupported format.
        """
        # `or {}` also covers an explicit JSON null for "parameters".
        parameters = data.get("parameters") or {}

        # Extract parameters with defaults
        max_new_tokens = self._get_param(parameters, "max_new_tokens", 1024)  # Increased to prevent truncation
        temperature = self._get_param(parameters, "temperature", 0.7)
        top_p = self._get_param(parameters, "top_p", 0.9)
        system_prompt = self._get_param(parameters, "system_prompt", self.default_system)

        # Handle both string and message list inputs
        try:
            messages = self._build_messages(data.get("inputs", ""), system_prompt)
        except ValueError as err:
            return {"error": str(err)}

        # Apply chat template
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # The rendered template already contains whatever special tokens the
        # model expects; letting the tokenizer add them again (e.g. a second
        # BOS) would corrupt the prompt.
        encoded = self.tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=False
        ).to(self.model.device)

        # temperature <= 0 selects greedy decoding.
        do_sample = temperature > 0
        with torch.no_grad():
            outputs = self.model.generate(
                **encoded,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                # Sampling knobs are meaningless under greedy decoding; pass
                # None so they don't trigger unused-flag warnings.
                temperature=temperature if do_sample else None,
                top_p=top_p if do_sample else None,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only new tokens (exclude prompt)
        prompt_length = encoded["input_ids"].shape[1]
        response = self.tokenizer.decode(
            outputs[0][prompt_length:],
            skip_special_tokens=True
        )

        return {"generated_text": response}