Add wardrobe_description support for better AI understanding
Browse files- models.py +1 -0
- routes.py +34 -3
- test_api.sh +85 -0
- test_local.sh +80 -0
- wardrobe.py +2 -2
models.py
CHANGED
|
@@ -13,6 +13,7 @@ class ChatRequest(BaseModel):
|
|
| 13 |
message: str
|
| 14 |
session_id: Optional[str] = "default"
|
| 15 |
wardrobe: Optional[List[WardrobeItem]] = None
|
|
|
|
| 16 |
images: Optional[List[str]] = None
|
| 17 |
|
| 18 |
class ChatResponse(BaseModel):
|
|
|
|
| 13 |
message: str
|
| 14 |
session_id: Optional[str] = "default"
|
| 15 |
wardrobe: Optional[List[WardrobeItem]] = None
|
| 16 |
+
wardrobe_description: Optional[str] = None
|
| 17 |
images: Optional[List[str]] = None
|
| 18 |
|
| 19 |
class ChatResponse(BaseModel):
|
routes.py
CHANGED
|
@@ -219,7 +219,9 @@ def setup_routes(app):
|
|
| 219 |
|
| 220 |
if request.wardrobe and len(request.wardrobe) > 0:
|
| 221 |
print(f"[WARDROBE CHAT] ===== WARDROBE REQUEST DETECTED =====")
|
| 222 |
-
|
|
|
|
|
|
|
| 223 |
|
| 224 |
conv_context = get_conversation_context(session_id)
|
| 225 |
|
|
@@ -385,6 +387,7 @@ def setup_routes(app):
|
|
| 385 |
message: str = Form(...),
|
| 386 |
session_id: str = Form(default="default"),
|
| 387 |
wardrobe: Optional[str] = Form(default=None),
|
|
|
|
| 388 |
images: List[UploadFile] = File(default=[])
|
| 389 |
):
|
| 390 |
try:
|
|
@@ -411,6 +414,7 @@ def setup_routes(app):
|
|
| 411 |
message=message,
|
| 412 |
session_id=session_id,
|
| 413 |
wardrobe=wardrobe_items if wardrobe_items else None,
|
|
|
|
| 414 |
images=image_data_urls if image_data_urls else None
|
| 415 |
)
|
| 416 |
|
|
@@ -428,6 +432,7 @@ def setup_routes(app):
|
|
| 428 |
message: str = Form(...),
|
| 429 |
session_id: str = Form(default="default"),
|
| 430 |
wardrobe: Optional[str] = Form(default=None),
|
|
|
|
| 431 |
images: List[UploadFile] = File(default=[])
|
| 432 |
):
|
| 433 |
image_data_urls = []
|
|
@@ -440,9 +445,36 @@ def setup_routes(app):
|
|
| 440 |
image_data_urls.append(data_url)
|
| 441 |
print(f"[STREAM UPLOAD] Processed image: {img_file.filename} ({len(content)} bytes)")
|
| 442 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 443 |
rag_chunks = retrieve_relevant_context(message, top_k=3)
|
| 444 |
rag_context = format_rag_context(rag_chunks)
|
| 445 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 446 |
print(f"[STREAM UPLOAD] Starting streaming response for: {message[:50]}...")
|
| 447 |
|
| 448 |
async def generate():
|
|
@@ -450,7 +482,7 @@ def setup_routes(app):
|
|
| 450 |
|
| 451 |
full_response = ""
|
| 452 |
async for chunk in generate_chat_response_streaming(
|
| 453 |
-
prompt=
|
| 454 |
max_length=512,
|
| 455 |
temperature=0.7,
|
| 456 |
rag_context=rag_context,
|
|
@@ -461,7 +493,6 @@ def setup_routes(app):
|
|
| 461 |
|
| 462 |
yield f"data: {json.dumps({'type': 'end', 'full_response': full_response, 'session_id': session_id})}\n\n"
|
| 463 |
print(f"[STREAM UPLOAD] Streaming complete: {len(full_response)} chars")
|
| 464 |
-
print(f"[STREAM RESPONSE] {full_response}")
|
| 465 |
|
| 466 |
return StreamingResponse(
|
| 467 |
generate(),
|
|
|
|
| 219 |
|
| 220 |
if request.wardrobe and len(request.wardrobe) > 0:
|
| 221 |
print(f"[WARDROBE CHAT] ===== WARDROBE REQUEST DETECTED =====")
|
| 222 |
+
if request.wardrobe_description:
|
| 223 |
+
print(f"[WARDROBE CHAT] Using provided wardrobe description ({len(request.wardrobe_description)} chars)")
|
| 224 |
+
return await handle_wardrobe_chat(message, request.wardrobe, session_id, images=request.images, wardrobe_description=request.wardrobe_description)
|
| 225 |
|
| 226 |
conv_context = get_conversation_context(session_id)
|
| 227 |
|
|
|
|
| 387 |
message: str = Form(...),
|
| 388 |
session_id: str = Form(default="default"),
|
| 389 |
wardrobe: Optional[str] = Form(default=None),
|
| 390 |
+
wardrobe_description: Optional[str] = Form(default=None),
|
| 391 |
images: List[UploadFile] = File(default=[])
|
| 392 |
):
|
| 393 |
try:
|
|
|
|
| 414 |
message=message,
|
| 415 |
session_id=session_id,
|
| 416 |
wardrobe=wardrobe_items if wardrobe_items else None,
|
| 417 |
+
wardrobe_description=wardrobe_description if wardrobe_description and wardrobe_description.strip() else None,
|
| 418 |
images=image_data_urls if image_data_urls else None
|
| 419 |
)
|
| 420 |
|
|
|
|
| 432 |
message: str = Form(...),
|
| 433 |
session_id: str = Form(default="default"),
|
| 434 |
wardrobe: Optional[str] = Form(default=None),
|
| 435 |
+
wardrobe_description: Optional[str] = Form(default=None),
|
| 436 |
images: List[UploadFile] = File(default=[])
|
| 437 |
):
|
| 438 |
image_data_urls = []
|
|
|
|
| 445 |
image_data_urls.append(data_url)
|
| 446 |
print(f"[STREAM UPLOAD] Processed image: {img_file.filename} ({len(content)} bytes)")
|
| 447 |
|
| 448 |
+
wardrobe_items = []
|
| 449 |
+
if wardrobe and wardrobe.strip() and wardrobe.strip() not in ["[]", "", "string"]:
|
| 450 |
+
try:
|
| 451 |
+
wardrobe_data = json.loads(wardrobe)
|
| 452 |
+
if isinstance(wardrobe_data, list):
|
| 453 |
+
wardrobe_items = [WardrobeItem(**item) for item in wardrobe_data]
|
| 454 |
+
except json.JSONDecodeError:
|
| 455 |
+
print(f"[STREAM UPLOAD] Ignoring invalid wardrobe value: {wardrobe[:50]}")
|
| 456 |
+
|
| 457 |
rag_chunks = retrieve_relevant_context(message, top_k=3)
|
| 458 |
rag_context = format_rag_context(rag_chunks)
|
| 459 |
|
| 460 |
+
wardrobe_context = ""
|
| 461 |
+
if wardrobe_description and wardrobe_description.strip():
|
| 462 |
+
wardrobe_context = wardrobe_description
|
| 463 |
+
print(f"[STREAM UPLOAD] Using provided wardrobe description ({len(wardrobe_context)} chars)")
|
| 464 |
+
elif wardrobe_items:
|
| 465 |
+
from wardrobe import format_wardrobe_for_prompt
|
| 466 |
+
wardrobe_context = format_wardrobe_for_prompt(wardrobe_items)
|
| 467 |
+
print(f"[STREAM UPLOAD] Generated wardrobe context ({len(wardrobe_context)} chars)")
|
| 468 |
+
|
| 469 |
+
if wardrobe_context:
|
| 470 |
+
prompt = f"""{wardrobe_context}
|
| 471 |
+
|
| 472 |
+
User request: {message}
|
| 473 |
+
|
| 474 |
+
Suggest a complete outfit using ONLY the items listed above. Reference items by their exact names. Include accessories if available. Be friendly and conversational."""
|
| 475 |
+
else:
|
| 476 |
+
prompt = message
|
| 477 |
+
|
| 478 |
print(f"[STREAM UPLOAD] Starting streaming response for: {message[:50]}...")
|
| 479 |
|
| 480 |
async def generate():
|
|
|
|
| 482 |
|
| 483 |
full_response = ""
|
| 484 |
async for chunk in generate_chat_response_streaming(
|
| 485 |
+
prompt=prompt,
|
| 486 |
max_length=512,
|
| 487 |
temperature=0.7,
|
| 488 |
rag_context=rag_context,
|
|
|
|
| 493 |
|
| 494 |
yield f"data: {json.dumps({'type': 'end', 'full_response': full_response, 'session_id': session_id})}\n\n"
|
| 495 |
print(f"[STREAM UPLOAD] Streaming complete: {len(full_response)} chars")
|
|
|
|
| 496 |
|
| 497 |
return StreamingResponse(
|
| 498 |
generate(),
|
test_api.sh
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
if [ -z "$1" ]; then
|
| 4 |
+
BASE_URL="https://nexusbert-style.hf.space"
|
| 5 |
+
else
|
| 6 |
+
BASE_URL="$1"
|
| 7 |
+
fi
|
| 8 |
+
|
| 9 |
+
echo "=== Testing Style GPT API ==="
|
| 10 |
+
echo "Base URL: $BASE_URL"
|
| 11 |
+
echo "Usage: $0 [base_url]"
|
| 12 |
+
echo "Example: $0 http://localhost:7860"
|
| 13 |
+
echo ""
|
| 14 |
+
|
| 15 |
+
echo "1. Testing GET / (Root endpoint)"
|
| 16 |
+
curl -X GET "$BASE_URL/" \
|
| 17 |
+
-H "Content-Type: application/json" \
|
| 18 |
+
-w "\nHTTP Status: %{http_code}\n\n"
|
| 19 |
+
|
| 20 |
+
echo "2. Testing GET /health"
|
| 21 |
+
curl -X GET "$BASE_URL/health" \
|
| 22 |
+
-H "Content-Type: application/json" \
|
| 23 |
+
-w "\nHTTP Status: %{http_code}\n\n"
|
| 24 |
+
|
| 25 |
+
echo "3. Testing POST /text (Text-only chat)"
|
| 26 |
+
curl -X POST "$BASE_URL/text" \
|
| 27 |
+
-H "Content-Type: application/json" \
|
| 28 |
+
-d '{
|
| 29 |
+
"message": "Hello, what colors go well with blue?",
|
| 30 |
+
"session_id": "test-session-1"
|
| 31 |
+
}' \
|
| 32 |
+
-w "\nHTTP Status: %{http_code}\n\n"
|
| 33 |
+
|
| 34 |
+
echo "4. Testing POST /chat (Chat with optional images - text only)"
|
| 35 |
+
curl -X POST "$BASE_URL/chat" \
|
| 36 |
+
-H "Content-Type: application/json" \
|
| 37 |
+
-d '{
|
| 38 |
+
"message": "What should I wear with a black jacket?",
|
| 39 |
+
"session_id": "test-session-2",
|
| 40 |
+
"images": null
|
| 41 |
+
}' \
|
| 42 |
+
-w "\nHTTP Status: %{http_code}\n\n"
|
| 43 |
+
|
| 44 |
+
echo "5. Testing POST /chat (Chat with wardrobe)"
|
| 45 |
+
curl -X POST "$BASE_URL/chat" \
|
| 46 |
+
-H "Content-Type: application/json" \
|
| 47 |
+
-d '{
|
| 48 |
+
"message": "Suggest an outfit for a casual meeting",
|
| 49 |
+
"session_id": "test-session-3",
|
| 50 |
+
"wardrobe": [
|
| 51 |
+
{
|
| 52 |
+
"category": "shirt",
|
| 53 |
+
"style": "casual",
|
| 54 |
+
"color": "white",
|
| 55 |
+
"brand": "Zara",
|
| 56 |
+
"name": "White casual shirt"
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"category": "pants",
|
| 60 |
+
"style": "formal",
|
| 61 |
+
"color": "navy",
|
| 62 |
+
"brand": "H&M",
|
| 63 |
+
"name": "Navy trousers"
|
| 64 |
+
}
|
| 65 |
+
]
|
| 66 |
+
}' \
|
| 67 |
+
-w "\nHTTP Status: %{http_code}\n\n"
|
| 68 |
+
|
| 69 |
+
echo "6. Testing POST /chat/upload (File upload - text only)"
|
| 70 |
+
curl -X POST "$BASE_URL/chat/upload" \
|
| 71 |
+
-F "message=What colors match with red?" \
|
| 72 |
+
-F "session_id=test-session-4" \
|
| 73 |
+
-F "wardrobe=[]" \
|
| 74 |
+
-w "\nHTTP Status: %{http_code}\n\n"
|
| 75 |
+
|
| 76 |
+
echo "7. Testing POST /chat/upload/stream (Streaming - text only)"
|
| 77 |
+
curl -X POST "$BASE_URL/chat/upload/stream" \
|
| 78 |
+
-F "message=Tell me about fashion trends" \
|
| 79 |
+
-F "session_id=test-session-5" \
|
| 80 |
+
-F "wardrobe=[]" \
|
| 81 |
+
--no-buffer \
|
| 82 |
+
-w "\nHTTP Status: %{http_code}\n\n"
|
| 83 |
+
|
| 84 |
+
echo "=== All tests completed ==="
|
| 85 |
+
|
test_local.sh
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
BASE_URL="http://localhost:7860"
|
| 4 |
+
|
| 5 |
+
echo "=== Testing Style GPT API (Local) ==="
|
| 6 |
+
echo "Base URL: $BASE_URL"
|
| 7 |
+
echo "Make sure the server is running: uvicorn app:app --host 0.0.0.0 --port 7860"
|
| 8 |
+
echo ""
|
| 9 |
+
|
| 10 |
+
echo "1. Testing GET / (Root endpoint)"
|
| 11 |
+
curl -X GET "$BASE_URL/" \
|
| 12 |
+
-H "Content-Type: application/json" \
|
| 13 |
+
-w "\nHTTP Status: %{http_code}\n\n"
|
| 14 |
+
|
| 15 |
+
echo "2. Testing GET /health"
|
| 16 |
+
curl -X GET "$BASE_URL/health" \
|
| 17 |
+
-H "Content-Type: application/json" \
|
| 18 |
+
-w "\nHTTP Status: %{http_code}\n\n"
|
| 19 |
+
|
| 20 |
+
echo "3. Testing POST /text (Text-only chat)"
|
| 21 |
+
curl -X POST "$BASE_URL/text" \
|
| 22 |
+
-H "Content-Type: application/json" \
|
| 23 |
+
-d '{
|
| 24 |
+
"message": "Hello, what colors go well with blue?",
|
| 25 |
+
"session_id": "test-session-1"
|
| 26 |
+
}' \
|
| 27 |
+
-w "\nHTTP Status: %{http_code}\n\n"
|
| 28 |
+
|
| 29 |
+
echo "4. Testing POST /chat (Chat with optional images - text only)"
|
| 30 |
+
curl -X POST "$BASE_URL/chat" \
|
| 31 |
+
-H "Content-Type: application/json" \
|
| 32 |
+
-d '{
|
| 33 |
+
"message": "What should I wear with a black jacket?",
|
| 34 |
+
"session_id": "test-session-2",
|
| 35 |
+
"images": null
|
| 36 |
+
}' \
|
| 37 |
+
-w "\nHTTP Status: %{http_code}\n\n"
|
| 38 |
+
|
| 39 |
+
echo "5. Testing POST /chat (Chat with wardrobe)"
|
| 40 |
+
curl -X POST "$BASE_URL/chat" \
|
| 41 |
+
-H "Content-Type: application/json" \
|
| 42 |
+
-d '{
|
| 43 |
+
"message": "Suggest an outfit for a casual meeting",
|
| 44 |
+
"session_id": "test-session-3",
|
| 45 |
+
"wardrobe": [
|
| 46 |
+
{
|
| 47 |
+
"category": "shirt",
|
| 48 |
+
"style": "casual",
|
| 49 |
+
"color": "white",
|
| 50 |
+
"brand": "Zara",
|
| 51 |
+
"name": "White casual shirt"
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
"category": "pants",
|
| 55 |
+
"style": "formal",
|
| 56 |
+
"color": "navy",
|
| 57 |
+
"brand": "H&M",
|
| 58 |
+
"name": "Navy trousers"
|
| 59 |
+
}
|
| 60 |
+
]
|
| 61 |
+
}' \
|
| 62 |
+
-w "\nHTTP Status: %{http_code}\n\n"
|
| 63 |
+
|
| 64 |
+
echo "6. Testing POST /chat/upload (File upload - text only)"
|
| 65 |
+
curl -X POST "$BASE_URL/chat/upload" \
|
| 66 |
+
-F "message=What colors match with red?" \
|
| 67 |
+
-F "session_id=test-session-4" \
|
| 68 |
+
-F "wardrobe=[]" \
|
| 69 |
+
-w "\nHTTP Status: %{http_code}\n\n"
|
| 70 |
+
|
| 71 |
+
echo "7. Testing POST /chat/upload/stream (Streaming - text only)"
|
| 72 |
+
curl -X POST "$BASE_URL/chat/upload/stream" \
|
| 73 |
+
-F "message=Tell me about fashion trends" \
|
| 74 |
+
-F "session_id=test-session-5" \
|
| 75 |
+
-F "wardrobe=[]" \
|
| 76 |
+
--no-buffer \
|
| 77 |
+
-w "\nHTTP Status: %{http_code}\n\n"
|
| 78 |
+
|
| 79 |
+
echo "=== All tests completed ==="
|
| 80 |
+
|
wardrobe.py
CHANGED
|
@@ -62,11 +62,11 @@ def format_wardrobe_for_prompt(wardrobe: List[WardrobeItem]) -> str:
|
|
| 62 |
|
| 63 |
Categories: {categories_list}"""
|
| 64 |
|
| 65 |
-
async def handle_wardrobe_chat(message: str, wardrobe: List[WardrobeItem], session_id: str, images: Optional[List[str]] = None) -> ChatResponse:
|
| 66 |
conv_context = get_conversation_context(session_id)
|
| 67 |
enhanced_message = enhance_message_with_context(message, conv_context["context"])
|
| 68 |
|
| 69 |
-
wardrobe_context = format_wardrobe_for_prompt(wardrobe)
|
| 70 |
|
| 71 |
wardrobe_by_category = {}
|
| 72 |
for item in wardrobe:
|
|
|
|
| 62 |
|
| 63 |
Categories: {categories_list}"""
|
| 64 |
|
| 65 |
+
async def handle_wardrobe_chat(message: str, wardrobe: List[WardrobeItem], session_id: str, images: Optional[List[str]] = None, wardrobe_description: Optional[str] = None) -> ChatResponse:
|
| 66 |
conv_context = get_conversation_context(session_id)
|
| 67 |
enhanced_message = enhance_message_with_context(message, conv_context["context"])
|
| 68 |
|
| 69 |
+
wardrobe_context = wardrobe_description if wardrobe_description else format_wardrobe_for_prompt(wardrobe)
|
| 70 |
|
| 71 |
wardrobe_by_category = {}
|
| 72 |
for item in wardrobe:
|