import time
from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the tokenizer and model from the repository path supplied by
        # the inference endpoint runtime.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.model_id = "askcatalystai/llama-ecommerce"

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Requests in the OpenAI Chat Completions format carry a "messages"
        # list; anything else is treated as the legacy direct-text format.
        if "messages" in data:
            return self._handle_chat_completions(data)
        else:
            return self._handle_legacy_format(data)

    def _handle_chat_completions(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle the OpenAI Chat Completions API format."""
        messages = data.get("messages", [])
        model = data.get("model", self.model_id)
        temperature = data.get("temperature", 0.7)
        max_tokens = data.get("max_tokens", 200)

        prompt = self._messages_to_prompt(messages)
        input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # Sample only when temperature > 0; otherwise decode greedily and
        # omit the temperature so no unused sampling argument is passed.
        do_sample = temperature > 0
        generation_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_tokens,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        if do_sample:
            generation_kwargs["temperature"] = temperature

        with torch.no_grad():
            outputs = self.model.generate(**input_ids, **generation_kwargs)

        full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        response_content = self._extract_response(full_response)

        prompt_tokens = len(input_ids.input_ids[0])
        total_tokens = len(outputs[0])

        return {
            "id": f"cmpl-{int(time.time())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response_content,
                    },
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": total_tokens - prompt_tokens,
                "total_tokens": total_tokens,
            },
        }

    def _handle_legacy_format(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle the legacy direct text input format."""
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})

        max_new_tokens = parameters.get("max_new_tokens", 200)
        temperature = parameters.get("temperature", 0.7)
        top_p = parameters.get("top_p", 0.9)

        # Structured inputs carry an instruction plus product details and are
        # wrapped in the model's prompt template; plain strings are used as-is.
        if isinstance(inputs, dict):
            instruction = inputs.get("instruction", "")
            product_details = inputs.get("product_details", "")
            prompt = f"***Instruction: {instruction}\n***Input: {product_details}\n***Response:"
        else:
            prompt = inputs

        input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        response = self._extract_response(full_response)

        return {"generated_text": response}

    def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Convert the OpenAI messages format into the LLaMA-E prompt format."""
        system_prompt = "You are a helpful e-commerce assistant that generates product descriptions, advertisements, and marketing content."
        user_content = ""

        # The prompt template holds a single system/user pair, so the last
        # system and user messages win; assistant turns are ignored.
        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", "")

            if role == "system":
                system_prompt = content
            elif role == "user":
                user_content = content

        prompt = f"***System: {system_prompt}\n***User: {user_content}\n***Response:"
        return prompt

    def _extract_response(self, full_response: str) -> str:
        """Extract the assistant response from the generated text."""
        if "***Response:" in full_response:
            return full_response.split("***Response:")[1].strip()
        elif "***User:" in full_response:
            # Fall back to whatever follows the last user turn if the model
            # did not emit a "***Response:" marker.
            parts = full_response.split("***User:")
            if len(parts) > 1:
                return parts[-1].strip()
        return full_response
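

# Illustrative local smoke test, not part of the endpoint contract: a minimal
# sketch showing how the handler accepts both request shapes. The payloads and
# the checkpoint path below are assumptions for demonstration only.
if __name__ == "__main__":
    handler = EndpointHandler(path="askcatalystai/llama-ecommerce")

    # OpenAI Chat Completions-style request.
    chat_request = {
        "messages": [
            {"role": "user", "content": "Write a short ad for a stainless steel water bottle."}
        ],
        "temperature": 0.7,
        "max_tokens": 100,
    }
    print(handler(chat_request)["choices"][0]["message"]["content"])

    # Legacy structured request.
    legacy_request = {
        "inputs": {
            "instruction": "Generate a product description.",
            "product_details": "Stainless steel water bottle, 750 ml, vacuum insulated.",
        },
        "parameters": {"max_new_tokens": 100, "temperature": 0.7, "top_p": 0.9},
    }
    print(handler(legacy_request)["generated_text"])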