Text-to-3D
Transformers
Safetensors
English
File size: 3,458 Bytes
7daf744
828d794
8e6ec71
7daf744
828d794
 
 
 
7daf744
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
828d794
73a8f69
828d794
7daf744
 
 
73a8f69
837e58c
 
 
 
73a8f69
 
837e58c
 
73a8f69
837e58c
 
0be15f2
 
 
73a8f69
4c4a40c
262acca
4c4a40c
262acca
 
4c4a40c
7daf744
262acca
4c4a40c
262acca
 
4c4a40c
828d794
262acca
 
 
35a2486
73a8f69
7daf744
73a8f69
35a2486
 
 
7daf744
73a8f69
 
35a2486
828d794
 
73a8f69
7daf744
4c4a40c
 
 
 
 
7daf744
73a8f69
4c4a40c
 
262acca
73a8f69
 
828d794
73a8f69
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from typing import Dict, List, Any
import torch
import os
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    """Serve a causal LM through its chat template for RemoteLLM clients.

    ``__init__`` loads the tokenizer and model from ``path``. Each ``__call__``
    renders the request's chat messages with the tokenizer's chat template,
    samples a continuation, and returns only the newly generated text.
    """

    def __init__(self, path=""):
        """Load the tokenizer and fp16 model from ``path``.

        Args:
            path: Model directory or hub id passed to ``from_pretrained``.
                If the ``HF_TOKEN`` environment variable is set it is
                forwarded so gated models can be downloaded.
        """
        hf_token = os.getenv("HF_TOKEN")

        self.tokenizer = AutoTokenizer.from_pretrained(path, token=hf_token)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
            device_map="auto",
            token=hf_token,
        )

        # generate() is given pad_token_id; fall back to EOS when the
        # tokenizer does not define a pad token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a completion for one request, mimicking the local BrickGPT LLM.

        Args:
            data: Request payload. ``data["inputs"]`` may be a dict holding a
                ``"messages"`` list, a list of chat messages (dicts with a
                ``"role"`` key), or any other value, which is stringified and
                sent as a single user turn after a default system prompt.
                If ``"inputs"`` is absent, the rest of the payload is used as
                the input. ``data["parameters"]`` (optional, may be ``None``)
                can override ``max_new_tokens``, ``temperature``, ``top_k``
                and ``top_p``. The caller's dict is not modified.

        Returns:
            ``[{"generated_text": text}]``, where ``text`` is the decoding of
            only the newly generated tokens (special tokens are not skipped).
        """
        # Shallow copy so popping keys never mutates the caller's payload.
        # When "inputs" is missing, ``inputs`` is this copy, and popping
        # "parameters" below also removes it from ``inputs`` (as before).
        data = dict(data)
        inputs = data.pop("inputs", data)
        # A JSON body with "parameters": null would otherwise break .get().
        parameters = data.pop("parameters", None) or {}

        # Handle the different input formats RemoteLLM sends.
        if isinstance(inputs, dict) and "messages" in inputs:
            messages = inputs["messages"]
        elif isinstance(inputs, list):
            messages = inputs
        else:
            # Fallback: treat the input as direct user text.
            messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": str(inputs)},
            ]

        # continue_final_message extends the *final* message, so only a
        # trailing (partial) assistant turn means "continue this reply".
        # An earlier assistant turn in a multi-turn chat must still get a
        # fresh generation prompt.
        continue_final = bool(messages) and messages[-1].get("role") == "assistant"
        if continue_final:
            template_kwargs = {"continue_final_message": True}
        else:
            template_kwargs = {"add_generation_prompt": True}

        # Apply the chat template the same way BrickGPT does locally.
        prompt = self.tokenizer.apply_chat_template(
            messages,
            return_tensors='pt',
            **template_kwargs,
        )

        input_ids = prompt.to(self.model.device)
        # Single, unpadded sequence: every position is attended to.
        attention_mask = torch.ones_like(input_ids)

        # Generation parameters; these defaults are labeled in the original
        # code as the BrickGPT defaults.
        generation_params = {
            "max_new_tokens": parameters.get("max_new_tokens", 10),
            "temperature": parameters.get("temperature", 0.6),
            "top_k": parameters.get("top_k", 20),
            "top_p": parameters.get("top_p", 1.0),
            "pad_token_id": self.tokenizer.pad_token_id,
            "do_sample": True,
            "num_return_sequences": 1,
            "return_dict_in_generate": True,
        }

        with torch.no_grad():
            output_dict = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                **generation_params,
            )

        # Keep only the tokens produced after the prompt.
        input_length = input_ids.shape[1]
        result_ids = output_dict['sequences'][0][input_length:]

        # Decode like the local LLM: no skip_special_tokens argument.
        generated_text = self.tokenizer.decode(result_ids)

        # Response shape RemoteLLM expects.
        return [{"generated_text": generated_text}]