{
"cells": [
{
"cell_type": "code",
"source": [
"import json\n",
"\n",
"from enum import StrEnum\n",
"\n",
"import numpy as np\n",
"from PIL import Image\n",
"from scipy.special import softmax\n",
"from torch import Tensor\n",
"from torch.jit import RecursiveScriptModule\n",
"from torchvision import transforms\n",
"\n",
"from IPython.display import display, HTML\n",
"\n",
"import torch"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2025-11-10T21:15:01.207611Z",
"start_time": "2025-11-10T21:15:01.205436Z"
}
},
"id": "15c69cabd27e8626",
"outputs": [],
"execution_count": 10
},
{
"cell_type": "code",
"source": [
"class PHOTO_FRAMING(StrEnum):\n",
" FRONT = \"front\"\n",
" BACK = \"back\"\n",
" SIDE_VIEW = \"side_view\"\n",
" THREE_FOURTH = \"three_fourth\"\n",
" INSIDE = \"inside\"\n",
" CLOSEUP = \"closeup\"\n",
" LABEL = \"label\"\n",
" OTHERS = \"others\"\n",
"\n",
"\n",
"AiPredictionByFraming = list[dict[PHOTO_FRAMING, float]]"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2025-11-10T21:15:01.265847Z",
"start_time": "2025-11-10T21:15:01.263362Z"
}
},
"id": "f1cffed5763919d5",
"outputs": [],
"execution_count": 12
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-11-10T21:15:01.312581Z",
"start_time": "2025-11-10T21:15:01.308237Z"
}
},
"cell_type": "code",
"source": [
"def normalize_neural_network_output(neural_network_output: np.ndarray) -> list[float]:\n",
" return softmax(neural_network_output).tolist()\n",
"\n",
"\n",
"def predict_neural_network(neural_network_input: Tensor,\n",
"                           neural_network_model: RecursiveScriptModule) -> list[list[float]]:\n",
"    \"\"\"Run the model on a batch and softmax-normalize each per-image output.\"\"\"\n",
"    # Inference only: disable autograd bookkeeping around the forward pass.\n",
"    with torch.no_grad():\n",
"        raw_outputs = neural_network_model(neural_network_input).tolist()\n",
"\n",
"    return [normalize_neural_network_output(one_output) for one_output in raw_outputs]\n",
"\n",
"\n",
"def format_output(neural_network_output: list, labels_to_output_index: list) -> list[dict]:\n",
" results = []\n",
" for i, a_nn_output in enumerate(neural_network_output):\n",
" results.append(\n",
" {\n",
" 'probabilities_neural_network': {\n",
" labels_to_output_index[j]: p for j, p in enumerate(a_nn_output)\n",
" },\n",
" })\n",
" return results\n",
"\n",
"\n",
"def format_pils_images(pil_images: list[Image.Image]) -> Tensor:\n",
"    \"\"\"Preprocess PIL images into one batched tensor for the model.\n",
"\n",
"    Each image is resized to 256x256, converted to a tensor, and normalized\n",
"    with the ImageNet mean/std; results are stacked along a new batch axis.\n",
"    \"\"\"\n",
"    # NOTE(review): the 3-channel Normalize assumes RGB input; a grayscale or\n",
"    # RGBA image would fail here -- convert upstream if that can occur.\n",
"    transform = transforms.Compose([\n",
"        transforms.Resize((256, 256)),\n",
"        transforms.ToTensor(),\n",
"        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
"    ])\n",
"\n",
"    # torch.stack adds the batch dimension, replacing the original manual\n",
"    # unsqueeze(0) + cat(dim=0) (identical result; both raise on empty input).\n",
"    return torch.stack([transform(a_pil_image) for a_pil_image in pil_images], dim=0)\n",
"\n",
"\n",
"def predict(pil_images: list[Image.Image],\n",
"            labels_to_output_index: list,\n",
"            neural_network_model: RecursiveScriptModule\n",
"            ) -> list[dict]:\n",
"    \"\"\"End-to-end inference: preprocess, run the model, attach label names.\"\"\"\n",
"    batch = format_pils_images(pil_images)\n",
"    probabilities = predict_neural_network(batch, neural_network_model)\n",
"    return format_output(probabilities, labels_to_output_index)\n"
],
"id": "7a05d1015a88e18c",
"outputs": [],
"execution_count": 13
},
{
"cell_type": "code",
"source": [
"neural_network_model = torch.jit.load('shared/model_scripted.pt')\n",
"\n",
"with open('shared/labels_to_output_index.json', 'r') as fp:\n",
" labels_to_output_index = json.load(fp)\n",
"\n",
"\n",
"def get_picture_framing_prediction(images_pil: list[Image.Image]) -> AiPredictionByFraming:\n",
" ai_prediction_scores = predict(images_pil,\n",
" labels_to_output_index,\n",
" neural_network_model)\n",
" ai_prediction_by_framing = [p['probabilities_neural_network'] for p in ai_prediction_scores]\n",
"\n",
" return ai_prediction_by_framing"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2025-11-10T21:15:01.535844Z",
"start_time": "2025-11-10T21:15:01.355615Z"
}
},
"id": "41d6a1305c8f992b",
"outputs": [],
"execution_count": 14
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-11-10T21:15:01.575230Z",
"start_time": "2025-11-10T21:15:01.563426Z"
}
},
"cell_type": "code",
"source": [
"# Sample sets: 4 dress shots, 5 hard-to-read label shots, 5 coat shots.\n",
"# Produces exactly the same path list as the original hand-written literal.\n",
"image_paths = [\n",
"    f\"assets/{folder}/{index}.jpeg\"\n",
"    for folder, count in [(\"dress-drm-free\", 4),\n",
"                          (\"label-difficult\", 5),\n",
"                          (\"saint-james-coat\", 5)]\n",
"    for index in range(1, count + 1)\n",
"]\n",
"input_images = [Image.open(one_image_path) for one_image_path in image_paths]"
],
"id": "3f7877aba0b830c9",
"outputs": [],
"execution_count": 15
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-11-10T21:15:02.393613Z",
"start_time": "2025-11-10T21:15:01.613188Z"
}
},
"cell_type": "code",
"source": [
"predictions = get_picture_framing_prediction(input_images)\n",
"prediction_texts = []\n",
"\n",
"for one_image_prediction in predictions:\n",
"    # Most likely framing for this image; it gets visual emphasis below.\n",
"    max_key = max(one_image_prediction, key=one_image_prediction.get)\n",
"\n",
"    one_image_predictions = []\n",
"    for key, value in one_image_prediction.items():\n",
"        one_image_text = f\"{key}: {value:.3f}\"\n",
"        if key == max_key:\n",
"            # FIXME(review): the original reassignment was a no-op -- the\n",
"            # emphasis markup was evidently lost in extraction; restored as\n",
"            # <b>...</b> for the HTML rendering downstream. Confirm intent.\n",
"            one_image_text = f\"<b>{one_image_text}</b>\"\n",
"        one_image_predictions.append(one_image_text)\n",
"\n",
"    # Join with <br> so each score renders on its own line in the HTML view\n",
"    # (the original join string held a raw newline that broke the notebook).\n",
"    prediction_texts.append(\"<br>\".join(one_image_predictions))"
],
"id": "cda1d624454cda79",
"outputs": [],
"execution_count": 16
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-11-10T21:15:02.400026Z",
"start_time": "2025-11-10T21:15:02.396874Z"
}
},
"cell_type": "code",
"source": [
"html = \"
{text}
\n", "
\n",
" back: 0.000
closeup: 0.000
front: 1.000
inside: 0.000
label: 0.000
others: 0.000
side_view: 0.000
three_fourth: 0.000
\n",
" back: 0.001
closeup: 0.000
front: 0.008
inside: 0.000
label: 0.000
others: 0.000
side_view: 0.991
three_fourth: 0.000
\n",
" back: 0.014
closeup: 0.000
front: 0.000
inside: 0.000
label: 0.000
others: 0.000
side_view: 0.001
three_fourth: 0.985
\n",
" back: 0.000
closeup: 1.000
front: 0.000
inside: 0.000
label: 0.000
others: 0.000
side_view: 0.000
three_fourth: 0.000
\n",
" back: 0.003
closeup: 0.000
front: 0.997
inside: 0.000
label: 0.000
others: 0.000
side_view: 0.000
three_fourth: 0.000
\n",
" back: 0.943
closeup: 0.000
front: 0.057
inside: 0.000
label: 0.000
others: 0.000
side_view: 0.000
three_fourth: 0.000
\n",
" back: 0.192
closeup: 0.315
front: 0.442
inside: 0.000
label: 0.048
others: 0.001
side_view: 0.000
three_fourth: 0.002
\n",
" back: 1.000
closeup: 0.000
front: 0.000
inside: 0.000
label: 0.000
others: 0.000
side_view: 0.000
three_fourth: 0.000
\n",
" back: 0.000
closeup: 0.971
front: 0.000
inside: 0.000
label: 0.028
others: 0.000
side_view: 0.000
three_fourth: 0.000
\n",
" back: 0.000
closeup: 0.000
front: 0.932
inside: 0.000
label: 0.000
others: 0.065
side_view: 0.002
three_fourth: 0.000
\n",
" back: 0.327
closeup: 0.014
front: 0.376
inside: 0.001
label: 0.000
others: 0.000
side_view: 0.271
three_fourth: 0.011
\n",
" back: 0.000
closeup: 1.000
front: 0.000
inside: 0.000
label: 0.000
others: 0.000
side_view: 0.000
three_fourth: 0.000
\n",
" back: 0.000
closeup: 0.266
front: 0.004
inside: 0.000
label: 0.729
others: 0.000
side_view: 0.000
three_fourth: 0.000
\n",
" back: 0.000
closeup: 0.000
front: 0.000
inside: 0.000
label: 1.000
others: 0.000
side_view: 0.000
three_fourth: 0.000