Update main.py
Browse files
main.py
CHANGED
|
@@ -7,8 +7,6 @@ import json
|
|
| 7 |
import logging
|
| 8 |
import asyncio
|
| 9 |
import time
|
| 10 |
-
import base64
|
| 11 |
-
from io import BytesIO
|
| 12 |
from collections import defaultdict
|
| 13 |
from typing import List, Dict, Any, Optional, AsyncGenerator, Union
|
| 14 |
|
|
@@ -18,7 +16,10 @@ from aiohttp import ClientSession, ClientTimeout, ClientError
|
|
| 18 |
from fastapi import FastAPI, HTTPException, Request, Depends, Header
|
| 19 |
from fastapi.responses import StreamingResponse, JSONResponse, RedirectResponse
|
| 20 |
from pydantic import BaseModel
|
|
|
|
| 21 |
from PIL import Image
|
|
|
|
|
|
|
| 22 |
|
| 23 |
# Configure logging
|
| 24 |
logging.basicConfig(
|
|
@@ -108,25 +109,6 @@ class ImageResponse:
|
|
| 108 |
def to_data_uri(image: Any) -> str:
|
| 109 |
return "data:image/png;base64,..." # Replace with actual base64 data
|
| 110 |
|
| 111 |
-
# Utility functions for image processing
|
| 112 |
-
def decode_base64_image(base64_str: str) -> Image.Image:
|
| 113 |
-
try:
|
| 114 |
-
image_data = base64.b64decode(base64_str)
|
| 115 |
-
image = Image.open(BytesIO(image_data))
|
| 116 |
-
return image
|
| 117 |
-
except Exception as e:
|
| 118 |
-
logger.error("Failed to decode base64 image.")
|
| 119 |
-
raise HTTPException(status_code=400, detail="Invalid base64 image data.") from e
|
| 120 |
-
|
| 121 |
-
def analyze_image(image: Image.Image) -> str:
|
| 122 |
-
"""
|
| 123 |
-
Placeholder for image analysis.
|
| 124 |
-
Replace this with actual image analysis logic.
|
| 125 |
-
"""
|
| 126 |
-
# Example: Return image size as analysis
|
| 127 |
-
width, height = image.size
|
| 128 |
-
return f"Image analyzed successfully. Width: {width}px, Height: {height}px."
|
| 129 |
-
|
| 130 |
class Blackbox:
|
| 131 |
url = "https://www.blackbox.ai"
|
| 132 |
api_endpoint = "https://www.blackbox.ai/api/chat"
|
|
@@ -440,7 +422,7 @@ async def security_middleware(request: Request, call_next):
|
|
| 440 |
# Request Models
|
| 441 |
class Message(BaseModel):
|
| 442 |
role: str
|
| 443 |
-
content: Union[str, List[Any]] #
|
| 444 |
|
| 445 |
class ChatRequest(BaseModel):
|
| 446 |
model: str
|
|
@@ -510,31 +492,59 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
|
|
| 510 |
logger.exception("Unexpected error during image analysis.")
|
| 511 |
raise HTTPException(status_code=500, detail="Image analysis failed.") from e
|
| 512 |
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 527 |
|
|
|
|
| 528 |
if request.stream:
|
| 529 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 530 |
try:
|
| 531 |
-
|
| 532 |
-
async for chunk in async_generator:
|
| 533 |
if isinstance(chunk, ImageResponse):
|
| 534 |
# Handle image responses if necessary
|
| 535 |
image_markdown = f"\n"
|
| 536 |
assistant_content += image_markdown
|
| 537 |
-
response_chunk = create_response(image_markdown,
|
| 538 |
else:
|
| 539 |
assistant_content += chunk
|
| 540 |
# Yield the chunk as a partial choice
|
|
@@ -542,7 +552,7 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
|
|
| 542 |
"id": f"chatcmpl-{uuid.uuid4()}",
|
| 543 |
"object": "chat.completion.chunk",
|
| 544 |
"created": int(datetime.now().timestamp()),
|
| 545 |
-
"model":
|
| 546 |
"choices": [
|
| 547 |
{
|
| 548 |
"index": 0,
|
|
@@ -555,7 +565,7 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
|
|
| 555 |
yield f"data: {json.dumps(response_chunk)}\n\n"
|
| 556 |
|
| 557 |
# After all chunks are sent, send the final message with finish_reason
|
| 558 |
-
prompt_tokens = sum(len(msg
|
| 559 |
completion_tokens = len(assistant_content.split())
|
| 560 |
total_tokens = prompt_tokens + completion_tokens
|
| 561 |
estimated_cost = calculate_estimated_cost(prompt_tokens, completion_tokens)
|
|
@@ -564,7 +574,7 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
|
|
| 564 |
"id": f"chatcmpl-{uuid.uuid4()}",
|
| 565 |
"object": "chat.completion",
|
| 566 |
"created": int(datetime.now().timestamp()),
|
| 567 |
-
"model":
|
| 568 |
"choices": [
|
| 569 |
{
|
| 570 |
"message": {
|
|
@@ -596,16 +606,26 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
|
|
| 596 |
error_response = {"error": str(e)}
|
| 597 |
yield f"data: {json.dumps(error_response)}\n\n"
|
| 598 |
|
| 599 |
-
return StreamingResponse(
|
| 600 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 601 |
response_content = ""
|
| 602 |
-
async for chunk in
|
| 603 |
if isinstance(chunk, ImageResponse):
|
| 604 |
response_content += f"\n"
|
| 605 |
else:
|
| 606 |
response_content += chunk
|
| 607 |
|
| 608 |
-
prompt_tokens = sum(len(msg
|
| 609 |
completion_tokens = len(response_content.split())
|
| 610 |
total_tokens = prompt_tokens + completion_tokens
|
| 611 |
estimated_cost = calculate_estimated_cost(prompt_tokens, completion_tokens)
|
|
@@ -616,7 +636,7 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
|
|
| 616 |
"id": f"chatcmpl-{uuid.uuid4()}",
|
| 617 |
"object": "chat.completion",
|
| 618 |
"created": int(datetime.now().timestamp()),
|
| 619 |
-
"model":
|
| 620 |
"choices": [
|
| 621 |
{
|
| 622 |
"message": {
|
|
@@ -710,6 +730,25 @@ async def http_exception_handler(request: Request, exc: HTTPException):
|
|
| 710 |
},
|
| 711 |
)
|
| 712 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 713 |
# Run the application
|
| 714 |
if __name__ == "__main__":
|
| 715 |
import uvicorn
|
|
|
|
| 7 |
import logging
|
| 8 |
import asyncio
|
| 9 |
import time
|
|
|
|
|
|
|
| 10 |
from collections import defaultdict
|
| 11 |
from typing import List, Dict, Any, Optional, AsyncGenerator, Union
|
| 12 |
|
|
|
|
| 16 |
from fastapi import FastAPI, HTTPException, Request, Depends, Header
|
| 17 |
from fastapi.responses import StreamingResponse, JSONResponse, RedirectResponse
|
| 18 |
from pydantic import BaseModel
|
| 19 |
+
|
| 20 |
from PIL import Image
|
| 21 |
+
import base64
|
| 22 |
+
from io import BytesIO
|
| 23 |
|
| 24 |
# Configure logging
|
| 25 |
logging.basicConfig(
|
|
|
|
| 109 |
def to_data_uri(image: Any) -> str:
|
| 110 |
return "data:image/png;base64,..." # Replace with actual base64 data
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
class Blackbox:
|
| 113 |
url = "https://www.blackbox.ai"
|
| 114 |
api_endpoint = "https://www.blackbox.ai/api/chat"
|
|
|
|
| 422 |
# Request Models
|
| 423 |
class Message(BaseModel):
|
| 424 |
role: str
|
| 425 |
+
content: Union[str, List[Any]] # content can be a string or a list (for images)
|
| 426 |
|
| 427 |
class ChatRequest(BaseModel):
|
| 428 |
model: str
|
|
|
|
| 492 |
logger.exception("Unexpected error during image analysis.")
|
| 493 |
raise HTTPException(status_code=500, detail="Image analysis failed.") from e
|
| 494 |
|
| 495 |
+
# Prepare messages to send to the external API, excluding image data
|
| 496 |
+
processed_messages = []
|
| 497 |
+
for msg in request.messages:
|
| 498 |
+
if isinstance(msg.content, list) and len(msg.content) == 2:
|
| 499 |
+
# Assume the second item is image data, skip it
|
| 500 |
+
processed_messages.append({
|
| 501 |
+
"role": msg.role,
|
| 502 |
+
"content": msg.content[0]["text"] # Only include the text part
|
| 503 |
+
})
|
| 504 |
+
else:
|
| 505 |
+
processed_messages.append({
|
| 506 |
+
"role": msg.role,
|
| 507 |
+
"content": msg.content
|
| 508 |
+
})
|
| 509 |
+
|
| 510 |
+
# Create a modified ChatRequest without the image
|
| 511 |
+
modified_request = ChatRequest(
|
| 512 |
+
model=request.model,
|
| 513 |
+
messages=[msg for msg in processed_messages],
|
| 514 |
+
stream=request.stream,
|
| 515 |
+
temperature=request.temperature,
|
| 516 |
+
top_p=request.top_p,
|
| 517 |
+
max_tokens=request.max_tokens,
|
| 518 |
+
presence_penalty=request.presence_penalty,
|
| 519 |
+
frequency_penalty=request.frequency_penalty,
|
| 520 |
+
logit_bias=request.logit_bias,
|
| 521 |
+
user=request.user,
|
| 522 |
+
webSearchMode=request.webSearchMode,
|
| 523 |
+
image=None # Exclude image from external API
|
| 524 |
+
)
|
| 525 |
|
| 526 |
+
try:
|
| 527 |
if request.stream:
|
| 528 |
+
logger.info("Streaming response")
|
| 529 |
+
streaming_response = await Blackbox.create_async_generator(
|
| 530 |
+
model=modified_request.model,
|
| 531 |
+
messages=[{"role": msg["role"], "content": msg["content"]} for msg in modified_request.messages],
|
| 532 |
+
proxy=None,
|
| 533 |
+
image=None,
|
| 534 |
+
image_name=None,
|
| 535 |
+
webSearchMode=modified_request.webSearchMode
|
| 536 |
+
)
|
| 537 |
+
|
| 538 |
+
# Wrap the streaming generator to include image analysis at the end
|
| 539 |
+
async def generate_with_analysis():
|
| 540 |
+
assistant_content = ""
|
| 541 |
try:
|
| 542 |
+
async for chunk in streaming_response:
|
|
|
|
| 543 |
if isinstance(chunk, ImageResponse):
|
| 544 |
# Handle image responses if necessary
|
| 545 |
image_markdown = f"\n"
|
| 546 |
assistant_content += image_markdown
|
| 547 |
+
response_chunk = create_response(image_markdown, modified_request.model, finish_reason=None)
|
| 548 |
else:
|
| 549 |
assistant_content += chunk
|
| 550 |
# Yield the chunk as a partial choice
|
|
|
|
| 552 |
"id": f"chatcmpl-{uuid.uuid4()}",
|
| 553 |
"object": "chat.completion.chunk",
|
| 554 |
"created": int(datetime.now().timestamp()),
|
| 555 |
+
"model": modified_request.model,
|
| 556 |
"choices": [
|
| 557 |
{
|
| 558 |
"index": 0,
|
|
|
|
| 565 |
yield f"data: {json.dumps(response_chunk)}\n\n"
|
| 566 |
|
| 567 |
# After all chunks are sent, send the final message with finish_reason
|
| 568 |
+
prompt_tokens = sum(len(msg["content"].split()) for msg in modified_request.messages)
|
| 569 |
completion_tokens = len(assistant_content.split())
|
| 570 |
total_tokens = prompt_tokens + completion_tokens
|
| 571 |
estimated_cost = calculate_estimated_cost(prompt_tokens, completion_tokens)
|
|
|
|
| 574 |
"id": f"chatcmpl-{uuid.uuid4()}",
|
| 575 |
"object": "chat.completion",
|
| 576 |
"created": int(datetime.now().timestamp()),
|
| 577 |
+
"model": modified_request.model,
|
| 578 |
"choices": [
|
| 579 |
{
|
| 580 |
"message": {
|
|
|
|
| 606 |
error_response = {"error": str(e)}
|
| 607 |
yield f"data: {json.dumps(error_response)}\n\n"
|
| 608 |
|
| 609 |
+
return StreamingResponse(generate_with_analysis(), media_type="text/event-stream")
|
| 610 |
else:
|
| 611 |
+
logger.info("Non-streaming response")
|
| 612 |
+
streaming_response = await Blackbox.create_async_generator(
|
| 613 |
+
model=modified_request.model,
|
| 614 |
+
messages=[{"role": msg["role"], "content": msg["content"]} for msg in modified_request.messages],
|
| 615 |
+
proxy=None,
|
| 616 |
+
image=None,
|
| 617 |
+
image_name=None,
|
| 618 |
+
webSearchMode=modified_request.webSearchMode
|
| 619 |
+
)
|
| 620 |
+
|
| 621 |
response_content = ""
|
| 622 |
+
async for chunk in streaming_response:
|
| 623 |
if isinstance(chunk, ImageResponse):
|
| 624 |
response_content += f"\n"
|
| 625 |
else:
|
| 626 |
response_content += chunk
|
| 627 |
|
| 628 |
+
prompt_tokens = sum(len(msg["content"].split()) for msg in modified_request.messages)
|
| 629 |
completion_tokens = len(response_content.split())
|
| 630 |
total_tokens = prompt_tokens + completion_tokens
|
| 631 |
estimated_cost = calculate_estimated_cost(prompt_tokens, completion_tokens)
|
|
|
|
| 636 |
"id": f"chatcmpl-{uuid.uuid4()}",
|
| 637 |
"object": "chat.completion",
|
| 638 |
"created": int(datetime.now().timestamp()),
|
| 639 |
+
"model": modified_request.model,
|
| 640 |
"choices": [
|
| 641 |
{
|
| 642 |
"message": {
|
|
|
|
| 730 |
},
|
| 731 |
)
|
| 732 |
|
| 733 |
+
# Image Processing Utilities
|
| 734 |
+
def decode_base64_image(base64_str: str) -> Image.Image:
|
| 735 |
+
try:
|
| 736 |
+
image_data = base64.b64decode(base64_str)
|
| 737 |
+
image = Image.open(BytesIO(image_data))
|
| 738 |
+
return image
|
| 739 |
+
except Exception as e:
|
| 740 |
+
logger.error("Failed to decode base64 image.")
|
| 741 |
+
raise HTTPException(status_code=400, detail="Invalid base64 image data.") from e
|
| 742 |
+
|
| 743 |
+
def analyze_image(image: Image.Image) -> str:
|
| 744 |
+
"""
|
| 745 |
+
Placeholder for image analysis.
|
| 746 |
+
Replace this with actual image analysis logic.
|
| 747 |
+
"""
|
| 748 |
+
# Example: Return image size as analysis
|
| 749 |
+
width, height = image.size
|
| 750 |
+
return f"Image analyzed successfully. Width: {width}px, Height: {height}px."
|
| 751 |
+
|
| 752 |
# Run the application
|
| 753 |
if __name__ == "__main__":
|
| 754 |
import uvicorn
|