File size: 5,405 Bytes
dab0caa 7c0f07d dab0caa 7c0f07d dab0caa 7c0f07d dab0caa 7c0f07d dab0caa 7c0f07d dab0caa 7c0f07d dab0caa 7c0f07d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
from typing import Dict, Any
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Dict, Any, List, Generator
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for the LLaMA-E e-commerce model.

    Accepts either the OpenAI Chat Completions request shape (a ``"messages"``
    list) or the legacy Inference API shape (an ``"inputs"`` string or dict),
    and returns the matching response format.
    """

    def __init__(self, path: str = ""):
        # `path` is the local model directory supplied by the endpoint runtime.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.model_id = "askcatalystai/llama-ecommerce"

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Route a request to the chat-completions or legacy handler.

        The presence of a ``"messages"`` key selects the OpenAI-compatible
        path; anything else falls through to the legacy format.
        """
        if "messages" in data:
            return self._handle_chat_completions(data)
        return self._handle_legacy_format(data)

    def _generate(self, prompt: str, max_new_tokens: int, temperature: float, top_p=None):
        """Tokenize `prompt`, run generation, and return (encoding, output ids).

        Sampling kwargs (`temperature`, `top_p`) are forwarded only when
        temperature > 0. The original code passed `temperature` alongside
        `do_sample=False` (chat path) or forced `do_sample=True` even for
        temperature == 0 (legacy path); both trigger warnings or errors in
        recent transformers releases, since `temperature` must be strictly
        positive when sampling is enabled.
        """
        encoding = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        if temperature and temperature > 0:
            gen_kwargs["do_sample"] = True
            gen_kwargs["temperature"] = temperature
            if top_p is not None:
                gen_kwargs["top_p"] = top_p
        else:
            # Greedy decoding: no sampling parameters may be passed.
            gen_kwargs["do_sample"] = False
        with torch.no_grad():
            outputs = self.model.generate(**encoding, **gen_kwargs)
        return encoding, outputs

    def _handle_chat_completions(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle OpenAI Chat Completions API format."""
        messages = data.get("messages", [])
        model = data.get("model", self.model_id)
        temperature = data.get("temperature", 0.7)
        max_tokens = data.get("max_tokens", 200)

        # Convert messages to the LLaMA-E prompt template, then generate.
        prompt = self._messages_to_prompt(messages)
        encoding, outputs = self._generate(prompt, max_tokens, temperature)

        full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        response_content = self._extract_response(full_response)

        # `outputs[0]` contains prompt + completion tokens, so completion
        # count is the difference.
        prompt_tokens = len(encoding.input_ids[0])
        total_tokens = len(outputs[0])

        # OpenAI-compatible response envelope.
        return {
            "id": f"cmpl-{int(time.time())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response_content,
                    },
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": total_tokens - prompt_tokens,
                "total_tokens": total_tokens,
            },
        }

    def _handle_legacy_format(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle legacy direct text input format."""
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})
        max_new_tokens = parameters.get("max_new_tokens", 200)
        temperature = parameters.get("temperature", 0.7)
        top_p = parameters.get("top_p", 0.9)

        # Format prompt if instruction/input are provided separately.
        if isinstance(inputs, dict):
            instruction = inputs.get("instruction", "")
            product_details = inputs.get("product_details", "")
            prompt = f"***Instruction: {instruction}\n***Input: {product_details}\n***Response:"
        else:
            prompt = inputs

        _, outputs = self._generate(prompt, max_new_tokens, temperature, top_p)
        full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"generated_text": self._extract_response(full_response)}

    def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Convert OpenAI messages format to the LLaMA-E prompt format.

        Only the last system and last user message are kept; assistant turns
        are ignored (the template has no slot for conversation history).
        """
        system_prompt = "You are a helpful e-commerce assistant that generates product descriptions, advertisements, and marketing content."
        user_content = ""
        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", "")
            if role == "system":
                system_prompt = content
            elif role == "user":
                user_content = content
        return f"***System: {system_prompt}\n***User: {user_content}\n***Response:"

    def _extract_response(self, full_response: str) -> str:
        """Extract the assistant response from the decoded generation.

        Prefers text after the first ``***Response:`` marker; falls back to
        text after the last ``***User:`` marker, then to the raw text.
        """
        if "***Response:" in full_response:
            return full_response.split("***Response:")[1].strip()
        elif "***User:" in full_response:
            parts = full_response.split("***User:")
            if len(parts) > 1:
                return parts[-1].strip()
        return full_response
|