from typing import Dict, Any, List
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class EndpointHandler:
    def __init__(self, path: str = ""):
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        self.model_id = "askcatalystai/llama-ecommerce"
    
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Handle OpenAI Chat Completions format
        if "messages" in data:
            return self._handle_chat_completions(data)
        
        # Handle direct text input (legacy format)
        else:
            return self._handle_legacy_format(data)
    
    def _handle_chat_completions(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle OpenAI Chat Completions API format"""
        messages = data.get("messages", [])
        model = data.get("model", self.model_id)
        temperature = data.get("temperature", 0.7)
        max_tokens = data.get("max_tokens", 200)
        
        # Convert messages to prompt
        prompt = self._messages_to_prompt(messages)
        
        # Tokenize the prompt and move the encoded inputs to the model's device
        input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        
        with torch.no_grad():
            outputs = self.model.generate(
                **input_ids,
                max_new_tokens=max_tokens,
                do_sample=temperature > 0,
                temperature=temperature,
                pad_token_id=self.tokenizer.eos_token_id
            )
        
        # Decode and extract response
        full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        response_content = self._extract_response(full_response)
        
        # Return OpenAI-compatible format
        return {
            "id": f"cmpl-{int(time.time())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response_content
                    },
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": len(input_ids.input_ids[0]),
                "completion_tokens": len(outputs[0]) - len(input_ids.input_ids[0]),
                "total_tokens": len(outputs[0])
            }
        }
    
    def _handle_legacy_format(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle legacy direct text input format"""
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})
        
        max_new_tokens = parameters.get("max_new_tokens", 200)
        temperature = parameters.get("temperature", 0.7)
        top_p = parameters.get("top_p", 0.9)
        
        # Format prompt if instruction/input provided separately
        if isinstance(inputs, dict):
            instruction = inputs.get("instruction", "")
            product_details = inputs.get("product_details", "")
            prompt = f"***Instruction: {instruction}\n***Input: {product_details}\n***Response:"
        else:
            prompt = inputs
        
        # Tokenize and generate
        input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        
        with torch.no_grad():
            outputs = self.model.generate(
                **input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=self.tokenizer.eos_token_id
            )
        
        # Decode and extract
        full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        response = self._extract_response(full_response)
        
        return {"generated_text": response}
    
    def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Convert OpenAI messages format to LLaMA-E prompt format"""
        system_prompt = "You are a helpful e-commerce assistant that generates product descriptions, advertisements, and marketing content."
        user_content = ""
        
        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", "")
            
            if role == "system":
                system_prompt = content
            elif role == "user":
                user_content = content
        
        # Format for LLaMA-E
        prompt = f"***System: {system_prompt}\n***User: {user_content}\n***Response:"
        return prompt
    
    def _extract_response(self, full_response: str) -> str:
        """Extract the assistant response from generated text"""
        if "***Response:" in full_response:
            return full_response.split("***Response:")[1].strip()
        elif "***User:" in full_response:
            # Take text after last user message
            parts = full_response.split("***User:")
            if len(parts) > 1:
                return parts[-1].strip()
        return full_response
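

# --- Local smoke test (illustrative only) ---
# A minimal sketch of how the handler can be exercised outside of a Hugging Face
# Inference Endpoint. The path passed to EndpointHandler is an assumption: point it
# at a local checkout of the model repository (or the hub id) before running.
# Payload shapes mirror the two branches of __call__ above.
if __name__ == "__main__":
    handler = EndpointHandler(path="askcatalystai/llama-ecommerce")

    # OpenAI Chat Completions style payload
    chat_payload = {
        "messages": [
            {"role": "system", "content": "You are a product copywriter."},
            {"role": "user", "content": "Write a short ad for wireless earbuds."},
        ],
        "temperature": 0.7,
        "max_tokens": 120,
    }
    print(handler(chat_payload))

    # Legacy payload with instruction/input split
    legacy_payload = {
        "inputs": {
            "instruction": "Generate a product description.",
            "product_details": "Stainless steel water bottle, 750 ml, vacuum insulated.",
        },
        "parameters": {"max_new_tokens": 120, "temperature": 0.7, "top_p": 0.9},
    }
    print(handler(legacy_payload))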