Text-to-3D
Transformers
Safetensors
English
jjohnson5253 committed on
Commit
828d794
·
1 Parent(s): 30c172b

add custom handler

Browse files
Files changed (2) hide show
  1. handler.py +152 -0
  2. requirements.txt +4 -0
handler.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any, Union
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ from peft import PeftModel
5
+ import json
6
+ import logging
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
class EndpointHandler:
    """Custom Inference Endpoints handler for a PEFT (LoRA) adapter on a causal LM.

    Reads ``adapter_config.json`` at ``path`` to discover the base checkpoint,
    loads the base model plus the adapter weights, and serves text -> LDR
    (LEGO build instruction) generation requests.
    """

    # LDR lines begin with a single line-type digit 0-5 followed by a space.
    # NOTE(review): per-type semantics (meta/part/line/tri/quad/optional)
    # assumed from the LDraw spec — confirm against training data.
    _LDR_PREFIXES = ("0 ", "1 ", "2 ", "3 ", "4 ", "5 ")

    def __init__(self, path: str = ""):
        """Load tokenizer, base model, and PEFT adapter from *path*.

        Args:
            path: Directory containing ``adapter_config.json`` and the
                adapter weights (the repository root on the endpoint).

        Raises:
            Exception: any loading failure is logged and re-raised so the
                endpoint reports an unhealthy start instead of serving a
                half-initialized model.
        """
        try:
            # The adapter config records which base checkpoint it was trained on.
            with open(f"{path}/adapter_config.json", "r") as f:
                adapter_config = json.load(f)

            base_model_name = adapter_config.get(
                "base_model_name_or_path", "meta-llama/Llama-2-7b-chat-hf"
            )
            logger.info(f"Loading base model: {base_model_name}")

            # Load tokenizer; Llama-style tokenizers ship without a pad
            # token, so reuse EOS for padding.
            self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Base model first, then the adapter is applied on top of it.
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True,
            )

            self.model = PeftModel.from_pretrained(
                base_model,
                path,
                torch_dtype=torch.bfloat16,
                device_map="auto",
            )

            self.model.eval()
            logger.info("Model loaded successfully")

        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Handle one inference request.

        Expected input format::

            {
                "inputs": "build a red car",
                "parameters": {
                    "max_new_tokens": 512,
                    "temperature": 0.7,
                    "do_sample": true
                }
            }

        Returns:
            A one-element list with ``generated_text`` (raw model output)
            and ``ldr_instructions`` (LDR-looking lines extracted from it),
            or ``[{"error": ...}]`` on failure.
        """
        try:
            # Use .get (not .pop): the request dict belongs to the caller
            # and must not be mutated by the handler.
            inputs = data.get("inputs", "")
            if isinstance(inputs, list):
                inputs = inputs[0] if inputs else ""

            # Generation parameters with sensible defaults.
            parameters = data.get("parameters", {})
            max_new_tokens = parameters.get("max_new_tokens", 512)
            temperature = parameters.get("temperature", 0.7)
            do_sample = parameters.get("do_sample", True)
            top_p = parameters.get("top_p", 0.9)
            top_k = parameters.get("top_k", 50)

            # Format prompt for BrickGPT (based on the training format).
            formatted_prompt = self._format_prompt(inputs)

            input_ids = self.tokenizer.encode(
                formatted_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=2048,
            ).to(self.model.device)

            with torch.no_grad():
                output_ids = self.model.generate(
                    input_ids,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    do_sample=do_sample,
                    top_p=top_p,
                    top_k=top_k,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.1,
                )

            # Decode only the newly generated tokens, not the echoed prompt.
            generated_ids = output_ids[0][input_ids.shape[1]:]
            generated_text = self.tokenizer.decode(
                generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

            ldr_instructions = self._parse_ldr_output(generated_text)

            return [{
                "generated_text": generated_text,
                "ldr_instructions": ldr_instructions,
            }]

        except Exception as e:
            logger.error(f"Error during inference: {e}")
            return [{"error": str(e)}]

    def _format_prompt(self, user_input: str) -> str:
        """Wrap *user_input* in the instruction template the model was trained on.

        Empty input falls back to a generic creative prompt so generation
        never runs on a blank instruction.
        """
        if not user_input:
            user_input = "build something creative"

        # Template must match the fine-tuning data exactly.
        prompt = f"### Instruction:\nGenerate LEGO building instructions for: {user_input}\n\n### Response:\n"
        return prompt

    def _parse_ldr_output(self, generated_text: str) -> List[str]:
        """Extract LDR-format lines (type digit 0-5 + space) from *generated_text*.

        Non-matching lines (chatter, blank lines) are dropped; order is
        preserved.
        """
        return [
            stripped
            for stripped in (line.strip() for line in generated_text.strip().split('\n'))
            if stripped.startswith(self._LDR_PREFIXES)
        ]
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch>=2.0.0
2
+ transformers>=4.35.0
3
+ peft>=0.6.0
4
+ accelerate>=0.20.0