Update endpoints.py
Browse files- endpoints.py +14 -4
endpoints.py
CHANGED
|
@@ -45,19 +45,29 @@ model = AutoModelForCausalLM.from_pretrained("WizardLM/WizardCoder-1B-V1.0")
|
|
| 45 |
# )
|
| 46 |
# hf_llm = HuggingFacePipeline(pipeline=pipe)
|
| 47 |
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
try:
|
| 51 |
# Prepare input prompt
|
| 52 |
input_prompt = ""
|
| 53 |
-
for message in messages:
|
| 54 |
role = message.get('role', 'user')
|
| 55 |
content = message.get('content', '')
|
| 56 |
input_prompt += f"{role}: {content}\n"
|
| 57 |
|
| 58 |
# Tokenize and generate response
|
| 59 |
input_ids = tokenizer.encode(input_prompt, return_tensors='pt')
|
| 60 |
-
output = model.generate(input_ids, max_length=1024, temperature=temperature, max_tokens=max_tokens)
|
| 61 |
|
| 62 |
# Decode and send response
|
| 63 |
response = tokenizer.decode(output[0], skip_special_tokens=True)
|
|
|
|
| 45 |
# )
|
| 46 |
# hf_llm = HuggingFacePipeline(pipeline=pipe)
|
| 47 |
|
| 48 |
+
|
| 49 |
+
class ChatRequest(BaseModel):
|
| 50 |
+
messages: list
|
| 51 |
+
temperature: float = 1.0
|
| 52 |
+
max_tokens: int = 50
|
| 53 |
+
stream: bool = False
|
| 54 |
+
|
| 55 |
+
class ChatResponse(BaseModel):
|
| 56 |
+
response: str
|
| 57 |
+
|
| 58 |
+
@app.post("/v1/chat/completions", response_model=ChatResponse)
|
| 59 |
+
async def chat_completions(request: ChatRequest):
|
| 60 |
try:
|
| 61 |
# Prepare input prompt
|
| 62 |
input_prompt = ""
|
| 63 |
+
for message in request.messages:
|
| 64 |
role = message.get('role', 'user')
|
| 65 |
content = message.get('content', '')
|
| 66 |
input_prompt += f"{role}: {content}\n"
|
| 67 |
|
| 68 |
# Tokenize and generate response
|
| 69 |
input_ids = tokenizer.encode(input_prompt, return_tensors='pt')
|
| 70 |
+
output = model.generate(input_ids, max_length=1024, temperature=request.temperature, max_tokens=request.max_tokens)
|
| 71 |
|
| 72 |
# Decode and send response
|
| 73 |
response = tokenizer.decode(output[0], skip_special_tokens=True)
|