# llama-ecommerce / handler.py
import time
from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the tokenizer and model from the weights path provided by the endpoint.
        # fp16 weights with device_map="auto" let accelerate place layers on the available device(s).
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.model_id = "askcatalystai/llama-ecommerce"

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Handle OpenAI Chat Completions format
        if "messages" in data:
            return self._handle_chat_completions(data)
        # Handle direct text input (legacy format)
        else:
            return self._handle_legacy_format(data)
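
    # Illustrative payload shapes routed by __call__ (inferred from the two
    # handlers below, not from endpoint documentation):
    #   {"messages": [{"role": "user", "content": "..."}], "max_tokens": 200}
    #       -> _handle_chat_completions
    #   {"inputs": "...", "parameters": {"max_new_tokens": 200}}
    #       -> _handle_legacy_format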

    def _handle_chat_completions(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle OpenAI Chat Completions API format"""
        messages = data.get("messages", [])
        model = data.get("model", self.model_id)
        temperature = data.get("temperature", 0.7)
        max_tokens = data.get("max_tokens", 200)

        # Convert messages to prompt
        prompt = self._messages_to_prompt(messages)

        # Generate
        input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **input_ids,
                max_new_tokens=max_tokens,
                do_sample=temperature > 0,
                temperature=temperature,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Decode and extract response
        full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        response_content = self._extract_response(full_response)

        # Return OpenAI-compatible format
        return {
            "id": f"cmpl-{int(time.time())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response_content,
                    },
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": len(input_ids.input_ids[0]),
                "completion_tokens": len(outputs[0]) - len(input_ids.input_ids[0]),
                "total_tokens": len(outputs[0]),
            },
        }
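
    # A sketch of the request this method expects and how the response is read,
    # assuming `handler` is an EndpointHandler instance (values are illustrative,
    # not actual model output):
    #
    #   request = {
    #       "messages": [
    #           {"role": "system", "content": "You write product copy."},
    #           {"role": "user", "content": "Describe a stainless steel water bottle."},
    #       ],
    #       "temperature": 0.7,
    #       "max_tokens": 200,
    #   }
    #   response = handler(request)
    #   text = response["choices"][0]["message"]["content"]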

    def _handle_legacy_format(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle legacy direct text input format"""
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})
        max_new_tokens = parameters.get("max_new_tokens", 200)
        temperature = parameters.get("temperature", 0.7)
        top_p = parameters.get("top_p", 0.9)

        # Format prompt if instruction/input provided separately
        if isinstance(inputs, dict):
            instruction = inputs.get("instruction", "")
            product_details = inputs.get("product_details", "")
            prompt = f"***Instruction: {instruction}\n***Input: {product_details}\n***Response:"
        else:
            prompt = inputs

        # Tokenize and generate
        input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Decode and extract
        full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        response = self._extract_response(full_response)
        return {"generated_text": response}

    def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Convert OpenAI messages format to LLaMA-E prompt format"""
        system_prompt = "You are a helpful e-commerce assistant that generates product descriptions, advertisements, and marketing content."
        user_content = ""
        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", "")
            if role == "system":
                system_prompt = content
            elif role == "user":
                user_content = content

        # Format for LLaMA-E
        prompt = f"***System: {system_prompt}\n***User: {user_content}\n***Response:"
        return prompt
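
    # Example of the conversion above: the messages
    #   [{"role": "system", "content": "Be concise."},
    #    {"role": "user", "content": "Name this product."}]
    # become the prompt
    #   "***System: Be concise.\n***User: Name this product.\n***Response:"
    # Only the last system and last user message are kept; assistant turns are ignored.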

    def _extract_response(self, full_response: str) -> str:
        """Extract the assistant response from generated text"""
        if "***Response:" in full_response:
            return full_response.split("***Response:")[1].strip()
        elif "***User:" in full_response:
            # Take text after last user message
            parts = full_response.split("***User:")
            if len(parts) > 1:
                return parts[-1].strip()
        # Fall back to the full decoded text if no markers are present
        return full_response
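
# Minimal local smoke test (a sketch, not part of the deployed handler). It
# assumes the weights can be loaded from the hub repo id below, which may
# require authentication and enough GPU/CPU memory for fp16 LLaMA weights.
# On a Hugging Face Inference Endpoint this module is imported and __main__
# never runs.
if __name__ == "__main__":
    handler = EndpointHandler("askcatalystai/llama-ecommerce")
    result = handler({
        "messages": [
            {"role": "user", "content": "Write a short ad for a ceramic coffee mug."}
        ],
        "max_tokens": 64,
    })
    print(result["choices"][0]["message"]["content"])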