Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +6 -0
- Tipsomaly/model/big_vision/configs/proj/givt/givt_overview.png +3 -0
- Tipsomaly/model/big_vision/configs/proj/jetformer/jetformer_overview.png +3 -0
- Tipsomaly/model/big_vision/configs/proj/paligemma/paligemma.png +3 -0
- Tipsomaly/model/big_vision/datasets/countbenchqa/data/countbench_paired_questions.json +1 -0
- Tipsomaly/model/big_vision/datasets/imagenet/__pycache__/class_names.cpython-311.pyc +3 -0
- Tipsomaly/model/big_vision/datasets/imagenet/__pycache__/class_names.cpython-312.pyc +3 -0
- Tipsomaly/model/big_vision/datasets/imagenet/__pycache__/class_names.cpython-39.pyc +3 -0
- Tipsomaly/model/big_vision/datasets/nocaps/nocaps.py +160 -0
- Tipsomaly/model/big_vision/datasets/refcoco/refcoco.py +448 -0
- Tipsomaly/model/big_vision/datasets/rsvqa_hr/rsvqa_hr.py +193 -0
- Tipsomaly/model/big_vision/datasets/rsvqa_lr/rsvqa_lr.py +198 -0
- Tipsomaly/model/big_vision/datasets/scicap/scicap.py +205 -0
- Tipsomaly/model/big_vision/datasets/science_qa/science_qa.py +156 -0
- Tipsomaly/model/big_vision/datasets/screen2words/screen2words.py +120 -0
- Tipsomaly/model/big_vision/datasets/stvqa/stvqa.py +134 -0
- Tipsomaly/model/big_vision/datasets/tallyqa/tallyqa.py +146 -0
- Tipsomaly/model/big_vision/datasets/textcaps/textcaps.py +152 -0
- Tipsomaly/model/big_vision/datasets/textvqa/textvqa.py +186 -0
- Tipsomaly/model/big_vision/datasets/vizwizvqa/vizwizvqa.py +128 -0
- Tipsomaly/model/big_vision/datasets/vqa/vqa.py +147 -0
- Tipsomaly/model/big_vision/datasets/widgetcap/widgetcap.py +151 -0
- Tipsomaly/model/big_vision/datasets/xgqa/xgqa.py +145 -0
- Tipsomaly/model/big_vision/datasets/xm3600/xm3600.py +136 -0
- Tipsomaly/model/big_vision/evaluators/proj/cappa/perplexity.py +50 -0
- Tipsomaly/model/big_vision/evaluators/proj/cappa/scoring_classifier.py +63 -0
- Tipsomaly/model/big_vision/evaluators/proj/distill/distance.py +151 -0
- Tipsomaly/model/big_vision/evaluators/proj/givt/coco_panoptic.py +401 -0
- Tipsomaly/model/big_vision/evaluators/proj/givt/nyu_depth.py +191 -0
- Tipsomaly/model/big_vision/evaluators/proj/givt/save_predictions.py +118 -0
- Tipsomaly/model/big_vision/evaluators/proj/image_text/contrastive.py +99 -0
- Tipsomaly/model/big_vision/evaluators/proj/image_text/discriminative_classifier.py +440 -0
- Tipsomaly/model/big_vision/evaluators/proj/image_text/discriminative_classifier_test.py +237 -0
- Tipsomaly/model/big_vision/evaluators/proj/image_text/image_text_retrieval.py +85 -0
- Tipsomaly/model/big_vision/evaluators/proj/image_text/image_text_retrieval_test.py +86 -0
- Tipsomaly/model/big_vision/evaluators/proj/image_text/prompt_engineering.py +112 -0
- Tipsomaly/model/big_vision/evaluators/proj/image_text/prompt_engineering_constants.py +110 -0
- Tipsomaly/model/big_vision/evaluators/proj/image_text/prompt_engineering_test.py +48 -0
- Tipsomaly/model/big_vision/evaluators/proj/image_text/retrieval.py +306 -0
- Tipsomaly/model/big_vision/evaluators/proj/image_text/retrieval_test.py +178 -0
- Tipsomaly/model/big_vision/evaluators/proj/paligemma/perplexity.py +63 -0
- Tipsomaly/model/big_vision/evaluators/proj/paligemma/transfers/chartqa.py +139 -0
- Tipsomaly/model/big_vision/evaluators/proj/paligemma/transfers/pope.py +135 -0
- Tipsomaly/model/big_vision/evaluators/proj/uvim/coco_panoptic.py +324 -0
- Tipsomaly/model/big_vision/evaluators/proj/uvim/coltran_fid.py +242 -0
- Tipsomaly/model/big_vision/evaluators/proj/uvim/coltran_fid_data/eval_file_names.txt +0 -0
- Tipsomaly/model/big_vision/evaluators/proj/uvim/coltran_fid_data/reference_file_names.txt +0 -0
- Tipsomaly/model/big_vision/evaluators/proj/uvim/common.py +64 -0
- Tipsomaly/model/big_vision/evaluators/proj/uvim/compute_mean.py +79 -0
- Tipsomaly/model/big_vision/evaluators/proj/uvim/nyu_depth.py +154 -0
.gitattributes
CHANGED
|
@@ -36,3 +36,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 36 |
Tipsomaly/imgs/Models_Architecture_page-0001.jpg filter=lfs diff=lfs merge=lfs -text
|
| 37 |
Tipsomaly/imgs/results-table.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
Tipsomaly/imgs/Qualitative_results_page-0001.jpg filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
Tipsomaly/imgs/Models_Architecture_page-0001.jpg filter=lfs diff=lfs merge=lfs -text
|
| 37 |
Tipsomaly/imgs/results-table.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
Tipsomaly/imgs/Qualitative_results_page-0001.jpg filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
Tipsomaly/model/big_vision/datasets/imagenet/__pycache__/class_names.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
Tipsomaly/model/big_vision/datasets/imagenet/__pycache__/class_names.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
Tipsomaly/model/big_vision/datasets/imagenet/__pycache__/class_names.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
Tipsomaly/model/big_vision/configs/proj/givt/givt_overview.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
Tipsomaly/model/big_vision/configs/proj/jetformer/jetformer_overview.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
Tipsomaly/model/big_vision/configs/proj/paligemma/paligemma.png filter=lfs diff=lfs merge=lfs -text
|
Tipsomaly/model/big_vision/configs/proj/givt/givt_overview.png
ADDED
|
Git LFS Details
|
Tipsomaly/model/big_vision/configs/proj/jetformer/jetformer_overview.png
ADDED
|
Git LFS Details
|
Tipsomaly/model/big_vision/configs/proj/paligemma/paligemma.png
ADDED
|
Git LFS Details
|
Tipsomaly/model/big_vision/datasets/countbenchqa/data/countbench_paired_questions.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
[{"question": "How many headsets are there in the image?"}, {"question": "How many light bulbs are there in the image?"}, {"question": "How many prints are there in the image?"}, {"question": "How many arrows are there in the image?"}, {"question": "How many spoons are there in the image?"}, {"question": "How many girls are there in the image?"}, {"question": "How many parrots are there in the image?"}, {"question": "How many coloring pages are there in the image?"}, {"question": "How many food containers are there in the image?"}, {"question": "How many birdhouse patterns are there in the image?"}, {"question": "How many sofas are there in the image?"}, {"question": "How many waterlilies are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many eggs are there in the image?"}, {"question": "How many golfers are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many outfits are there in the image?"}, {"question": "How many pigs are there in the image?"}, {"question": "How many cars are there in the image?"}, {"question": "How many aum symbols are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many buttons are there in the image?"}, {"question": "How many rackets are there in the image?"}, {"question": "How many pots are there in the image?"}, {"question": "How many stars are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many silhouettes are there in the image?"}, {"question": "How many kids are there in the image?"}, {"question": "How many moth silhouettes are there in the image?"}, {"question": "How many pumpkin candles are there in the image?"}, {"question": "How many essential oils are there in the image?"}, {"question": "How many stencils are there in the 
image?"}, {"question": "How many text boxes are there in the image?"}, {"question": "How many basketball players are there in the image?"}, {"question": "How many photos are there in the image?"}, {"question": "How many forks are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many bags are there in the image?"}, {"question": "How many couples are there in the image?"}, {"question": "How many weights are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many fish are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many clocks are there in the image?"}, {"question": "How many banners are there in the image?"}, {"question": "How many chests are there in the image?"}, {"question": "How many stars are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many globe icons are there in the image?"}, {"question": "How many posters are there in the image?"}, {"question": "How many socks are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many snails are there in the image?"}, {"question": "How many crochet potholders are there in the image?"}, {"question": "How many christmas cards are there in the image?"}, {"question": "How many double beds are there in the image?"}, {"question": "How many baseball players are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many symbols are there in the image?"}, {"question": "How many baseball players are there in the image?"}, {"question": "How many cars are there in the image?"}, {"question": "How many trees are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many tree trunk cuts are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many chairs are there in the 
image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many banners are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many individual earrings are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many tomatoes are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many prints are there in the image?"}, {"question": "How many ice creams are there in the image?"}, {"question": "How many plates are there in the image?"}, {"question": "How many sumo wrestlers are there in the image?"}, {"question": "How many compositions are there in the image?"}, {"question": "How many PVC vinyls are there in the image?"}, {"question": "How many photos of fruit are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many mirrors are there in the image?"}, {"question": "How many groomsmen are there in the image?"}, {"question": "How many posters are there in the image?"}, {"question": "How many pumpkins are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many kittens are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many boots are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many figurines are there in the image?"}, {"question": "How many moais are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many cats are there in the image?"}, {"question": "How many wallpaper variants are there in the image?"}, {"question": "How many nail polishes are there in the image?"}, {"question": "How many bumble bees are there in the image?"}, {"question": "How many tickets are 
there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many wine bottles are there in the image?"}, {"question": "How many silhouettes of couples are there in the image?"}, {"question": "How many owls are there in the image?"}, {"question": "How many candles are there in the image?"}, {"question": "How many flower pots are there in the image?"}, {"question": "How many coins are there in the image?"}, {"question": "How many placemats are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many cars are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many candles are there in the image?"}, {"question": "How many pairs of socks are there in the image?"}, {"question": "How many blinds are there in the image?"}, {"question": "How many floral patterns are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many bicycles are there in the image?"}, {"question": "How many dwarfs are there in the image?"}, {"question": "How many stickers are there in the image?"}, {"question": "How many symbols are there in the image?"}, {"question": "How many stamps are there in the image?"}, {"question": "How many pumpkins are there in the image?"}, {"question": "How many game cartridges are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many glasses are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many coins are there in the image?"}, {"question": "How many gears are there in the image?"}, {"question": "How many flowers are there in the image?"}, {"question": "How many crip packages are there in the image?"}, {"question": "How many bridesmaids are there in the image?"}, {"question": "How many 
apples are there in the image?"}, {"question": "How many bowls are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many photographs are there in the image?"}, {"question": "How many warning signs are there in the image?"}, {"question": "How many children are there in the image?"}, {"question": "How many cyclists are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many giraffes are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many male nurses are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many couples are there in the image?"}, {"question": "How many cars are there in the image?"}, {"question": "How many apples are there in the image?"}, {"question": "How many smartphones are there in the image?"}, {"question": "How many roses are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many greeting cards are there in the image?"}, {"question": "How many guitars are there in the image?"}, {"question": "How many ironman suits are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many CDs are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many symbols are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many pot holders are there in the image?"}, {"question": "How many stamps are there in the image?"}, {"question": "How many bookmarks are there in the image?"}, {"question": "How many portraits are there in the image?"}, {"question": "How many girls are there in the image?"}, {"question": "How many labels are there in the image?"}, {"question": "How many mandalas are there in the image?"}, {"question": "How many silhouettes are 
there in the image?"}, {"question": "How many peacocks are there in the image?"}, {"question": "How many roses are there in the image?"}, {"question": "How many cars are there in the image?"}, {"question": "How many children are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many symbols are there in the image?"}, {"question": "How many canvases are there in the image?"}, {"question": "How many cards are there in the image?"}, {"question": "How many bell pepper halves are there in the image?"}, {"question": "How many pigs are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many newspapers are there in the image?"}, {"question": "How many leggings are there in the image?"}, {"question": "How many medals are there in the image?"}, {"question": "How many patterns are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many doors are there in the image?"}, {"question": "How many pairs of socks are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many backgrounds are there in the image?"}, {"question": "How many images of dogs are there in the image?"}, {"question": "How many broadheads are there in the image?"}, {"question": "How many figurines are there in the image?"}, {"question": "How many candles are there in the image?"}, {"question": "How many petals does each flower have in this image?"}, {"question": "How many crayons are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many caps are there in the image?"}, {"question": "How many children are there in the image?"}, {"question": "How many frames are there in the image?"}, {"question": "How many animals are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many cats are there in the image?"}, {"question": "How many sconces are 
there in the image?"}, {"question": "How many spoons are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many planes are there in the image?"}, {"question": "How many cats are there in the image?"}, {"question": "How many sketches are there in the image?"}, {"question": "How many trees are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many figurines are there in the image?"}, {"question": "How many napkins are there in the image?"}, {"question": "How many dogs are there in the image?"}, {"question": "How many kids are there in the image?"}, {"question": "How many hearts are there in the image?"}, {"question": "How many apples are there in the image?"}, {"question": "How many post-its are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many books are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many sofas are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many girls are there in the image?"}, {"question": "How many sketches are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many moons are there in the image?"}, {"question": "How many labels are there in the image?"}, {"question": "How many cylinders are there in the image?"}, {"question": "How many silhouettes of couples are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many dogs are there in the image?"}, {"question": "How many people on stage are there in the image?"}, {"question": "How many lambs are there in the image?"}, {"question": "How many violins are there in the image?"}, {"question": "How many 
armchairs are there in the image?"}, {"question": "How many silhouettes are there in the image?"}, {"question": "How many eggs are there in the image?"}, {"question": "How many photos are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many figurines are there in the image?"}, {"question": "How many moais are there in the image?"}, {"question": "How many symbols are there in the image?"}, {"question": "How many glasses are there in the image?"}, {"question": "How many sinks are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many buildings are there in the image?"}, {"question": "How many flyers are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many pizza slices are there in the image?"}, {"question": "How many stamps are there in the image?"}, {"question": "How many stickers are there in the image?"}, {"question": "How many gifts are there in the image?"}, {"question": "How many bowls are there in the image?"}, {"question": "How many onesies are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many posters are there in the image?"}, {"question": "How many cards are there in the image?"}, {"question": "How many portraits are there in the image?"}, {"question": "How many poinsettias are there in the image?"}, {"question": "How many chicken thighs are there in the image?"}, {"question": "How many glass windows are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many baskets are there in the image?"}, {"question": "How many tulips are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many stuffed animals are there in the image?"}, {"question": "How many keychains are there in the image?"}, {"question": "How many photos are there in the image?"}, {"question": "How many people 
are there in the image?"}, {"question": "How many adults are there in the image?"}, {"question": "How many tulips are there in the image?"}, {"question": "How many frames are there in the image?"}, {"question": "How many samosas are there in the image?"}, {"question": "How many strawberries are there in the image?"}, {"question": "How many cocktails are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many wineglasses are there in the image?"}, {"question": "How many goblets are there in the image?"}, {"question": "How many prints are there in the image?"}, {"question": "How many flowers are there in the image?"}, {"question": "How many zebras are there in the image?"}, {"question": "How many paint brushes are there in the image?"}, {"question": "How many prints are there in the image?"}, {"question": "How many figurines are there in the image?"}, {"question": "How many glasses are there in the image?"}, {"question": "How many sunglasses are there in the image?"}, {"question": "How many banners are there in the image?"}, {"question": "How many bottle caps are there in the image?"}, {"question": "How many prints are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many eggs are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many books are there in the image?"}, {"question": "How many animals are there in the image?"}, {"question": "How many eyeshadows are there in the image?"}, {"question": "How many keychains are there in the image?"}, {"question": "How many pairs of earrings are there in the image?"}, {"question": "How many canisters are there in the image?"}, {"question": "How many bags are there in the image?"}, {"question": "How many baking trays are there in the image?"}, {"question": "How many diamonds are there in the image?"}, {"question": "How many portraits are there in the image?"}, {"question": "How many 
framed images are there in the image?"}, {"question": "How many flags are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many framed pictures are there in the image?"}, {"question": "How many eggs are there in the image?"}, {"question": "How many croissants are there in the image?"}, {"question": "How many Manikins are there in the image?"}, {"question": "How many symbols are there in the image?"}, {"question": "How many labels are there in the image?"}, {"question": "How many people in the foreground are there in the image?"}, {"question": "How many armchairs are there in the image?"}, {"question": "How many cups are there in the image?"}, {"question": "How many helicopters are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many children are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many buttons are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many Cookies are there in the image?"}, {"question": "How many sculptures are there in the image?"}, {"question": "How many school uniforms are there in the image?"}, {"question": "How many sculptures are there in the image?"}, {"question": "How many eggs are there in the image?"}, {"question": "How many packages are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many stories does this cottage have?"}, {"question": "How many gift cards are there in the image?"}, {"question": "How many eggs are there in the image?"}, {"question": "How many beers are there in the image?"}, {"question": "How many children are there in the image?"}, {"question": "How many planes are there in the image?"}, {"question": "How many cactus pots are there in the image?"}, {"question": "How many smartphones are there in the image?"}, {"question": "How many picture frames are there in the 
image?"}, {"question": "How many elephants are there in the image?"}, {"question": "How many guitars are there in the image?"}, {"question": "How many samurai are there in the image?"}, {"question": "How many ghosts are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many candles are there in the image?"}, {"question": "How many vases are there in the image?"}, {"question": "How many sets of headphones are there in the image?"}, {"question": "How many pandas are there in the image?"}, {"question": "How many books are there in the image?"}, {"question": "How many stickers are there in the image?"}, {"question": "How many rings are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many dragon balls are there in the image?"}, {"question": "How many tanks are there in the image?"}, {"question": "How many students are there in the image?"}, {"question": "How many cups are there in the image?"}, {"question": "How many cubs are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many cushions are there in the image?"}, {"question": "How many shoes are there in the image?"}, {"question": "How many beers are there in the image?"}, {"question": "How many wine glasses are there in the image?"}, {"question": "How many cards are there in the image?"}, {"question": "How many boots are there in the image?"}, {"question": "How many stickers are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many butterflies are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many butterflies are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many hexagons are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How 
many football players are there in the image?"}, {"question": "How many gifts are there in the image?"}, {"question": "How many light bulbs are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many candles are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many silhouettes are there in the image?"}, {"question": "How many pots are there in the image?"}, {"question": "How many pumpkins are there in the image?"}, {"question": "How many owls are there in the image?"}, {"question": "How many doctors are there in the image?"}, {"question": "How many pigs are there in the image?"}, {"question": "How many pillars are there in the image?"}, {"question": "How many symbols are there in the image?"}, {"question": "How many pumpkins are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many eggs are there in the image?"}, {"question": "How many roses are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many illustrations are there in the image?"}, {"question": "How many photos are there in the image?"}, {"question": "How many butterflies are there in the image?"}, {"question": "How many spoons are there in the image?"}, {"question": "How many potato spreads are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many wall arts are there in the image?"}, {"question": "How many covers are there in the image?"}, {"question": "How many people are there 
in the image?"}, {"question": "How many soldiers are there in the image?"}, {"question": "How many posters are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many patterns are there in the image?"}, {"question": "How many dogs are there in the image?"}, {"question": "How many pennies are there in the image?"}, {"question": "How many windows are there in the image?"}, {"question": "How many eggs are there in the image?"}, {"question": "How many geese are there in the image?"}, {"question": "How many tickets are there in the image?"}, {"question": "How many figurines are there in the image?"}, {"question": "How many cases are there in the image?"}, {"question": "How many people are in the foreground of this image?"}, {"question": "How many glasses are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many cups are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many kittens are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many scones are there in the image?"}, {"question": "How many schoolgirls are there in the image?"}, {"question": "How many padlocks are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many windows are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many banners are there in the image?"}, {"question": "How many dogs are there in the image?"}, {"question": "How many groomsmen are there in the image?"}, {"question": "How many children are there in the image?"}, {"question": "How many glasses are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many coins are there in the image?"}, {"question": "How many trees are there in the image?"}, {"question": "How 
many turtles are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many frames are there in the image?"}, {"question": "How many firescreens are there in the image?"}, {"question": "How many bowls are there in the image?"}, {"question": "How many stickers are there in the image?"}, {"question": "How many archaic mirrors are there in the image?"}, {"question": "How many dogs are there in the image?"}, {"question": "How many candles are there in the image?"}, {"question": "How many paint brushes are there in the image?"}, {"question": "How many forks are there in the image?"}, {"question": "How many figurines are there in the image?"}, {"question": "How many owls are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many silhouettes are there in the image?"}, {"question": "How many fried eggs are there in the image?"}, {"question": "How many guitars are there in the image?"}, {"question": "How many signs are there in the image?"}, {"question": "How many watches are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many children are there in the image?"}, {"question": "How many kittens are there in the image?"}, {"question": "How many kids are there in the image?"}, {"question": "How many kittens are there in the image?"}, {"question": "How many toys are there in the image?"}, {"question": "How many pairs of socks are there in the image?"}, {"question": "How many pairs of socks are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many stars are there in the image?"}, {"question": "How many quarters are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many diamonds are there in the image?"}, {"question": "How many moais are there in the image?"}, {"question": "How many banners are 
there in the image?"}, {"question": "How many silhouettes are there in the image?"}, {"question": "How many coins are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many banners are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many flags are there in the image?"}, {"question": "How many frames are there in the image?"}, {"question": "How many contestants are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many globes are there in the image?"}, {"question": "How many animals are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many buttons are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many piles of candy are there in the image?"}, {"question": "How many pictures are there in the image?"}, {"question": "How many plates are there in the image?"}, {"question": "How many calendars are there in the image?"}, {"question": "How many oranges are there in the image?"}, {"question": "How many puppies are there in the image?"}, {"question": "How many buffalos are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many christmas balls are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many armchairs are there in the image?"}, {"question": "How many frames are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many soldiers are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many cards are there in the image?"}, {"question": "How many colored tiles are there in the 
image?"}, {"question": "How many trapezoids are there in the image?"}, {"question": "How many pastries are there in the image?"}, {"question": "How many plants are there in the image?"}, {"question": "How many place mats are there in the image?"}, {"question": "How many symbols are there in the image?"}, {"question": "How many mazes are there in the image?"}, {"question": "How many tea bags are there in the image?"}, {"question": "How many photos are there in the image?"}, {"question": "How many children are there in the image?"}, {"question": "How many starfish are there in the image?"}, {"question": "How many mugs are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many eggs are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many balls are there in the image?"}, {"question": "How many paper bags are there in the image?"}, {"question": "How many garage doors are there in the image?"}, {"question": "How many silhouettes are there in the image?"}, {"question": "How many couples are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many wooden spoons are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many beds are there in the image?"}, {"question": "How many chairs are there in the image?"}]
|
Tipsomaly/model/big_vision/datasets/imagenet/__pycache__/class_names.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:60133d938ba72366456dec978d305ff978b143577b047fe46aee4b50b104eef6
|
| 3 |
+
size 272122
|
Tipsomaly/model/big_vision/datasets/imagenet/__pycache__/class_names.cpython-312.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a565009d6d842a49ccfc7f31d9945d1916fa9669913290d3483a4b9d360acbd7
|
| 3 |
+
size 272065
|
Tipsomaly/model/big_vision/datasets/imagenet/__pycache__/class_names.cpython-39.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fba8a44d65d0f986bf5562ae0290a3cec1db9af239f9c0f91876099c173b760b
|
| 3 |
+
size 272046
|
Tipsomaly/model/big_vision/datasets/nocaps/nocaps.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# pylint: disable=line-too-long
|
| 16 |
+
r"""Implements nocaps val/test set in TFDS structure.
|
| 17 |
+
|
| 18 |
+
It's small data, so simple to run locally. First, copy the data to local disk:
|
| 19 |
+
|
| 20 |
+
mkdir -p /tmp/data/nocaps_data
|
| 21 |
+
cd /tmp/data/nocaps_data
|
| 22 |
+
wget https://s3.amazonaws.com/open-images-dataset/tar/test.tar.gz
|
| 23 |
+
wget https://s3.amazonaws.com/open-images-dataset/tar/validation.tar.gz
|
| 24 |
+
curl -O https://nocaps.s3.amazonaws.com/nocaps_val_4500_captions.json
|
| 25 |
+
curl -O https://s3.amazonaws.com/nocaps/nocaps_test_image_info.json
|
| 26 |
+
|
| 27 |
+
mkdir -p /tmp/data/nocaps_data/Images
|
| 28 |
+
tar -xf validation.tar.gz -C Images
|
| 29 |
+
rm validation.tar.gz
|
| 30 |
+
tar -xf test.tar.gz -C Images
|
| 31 |
+
rm test.tar.gz
|
| 32 |
+
|
| 33 |
+
Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util):
|
| 34 |
+
|
| 35 |
+
cd big_vision/datasets
|
| 36 |
+
env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=nocaps
|
| 37 |
+
|
| 38 |
+
Example to load:
|
| 39 |
+
|
| 40 |
+
import tensorflow_datasets as tfds
|
| 41 |
+
dataset = tfds.load('nocaps', split='val', data_dir='/tmp/tfds')
|
| 42 |
+
"""
|
| 43 |
+
import collections
|
| 44 |
+
import json
|
| 45 |
+
import os
|
| 46 |
+
|
| 47 |
+
from absl import logging
|
| 48 |
+
import numpy as np
|
| 49 |
+
import tensorflow as tf
|
| 50 |
+
import tensorflow_datasets as tfds
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
_DESCRIPTION = """Nocaps dataset."""
|
| 54 |
+
|
| 55 |
+
_CITATION = (
|
| 56 |
+
'@inproceedings{agrawal2019nocaps,'
|
| 57 |
+
'title={nocaps: novel object captioning at scale},'
|
| 58 |
+
'author={Agrawal, Harsh and Desai, Karan and Wang, Yufei and Chen, Xinlei'
|
| 59 |
+
'and Jain, Rishabh and Johnson, Mark and Batra, Dhruv and Parikh, Devi'
|
| 60 |
+
'and Lee, Stefan and Anderson, Peter},'
|
| 61 |
+
'booktitle={ICCV},'
|
| 62 |
+
'pages={8948--8957},'
|
| 63 |
+
'year={2019}}')
|
| 64 |
+
|
| 65 |
+
# When running locally (recommended), copy files as above and use these:
|
| 66 |
+
_FILEPATH = '/tmp/data/nocaps_data/Images/'
|
| 67 |
+
_VAL_FILES = '/tmp/data/nocaps_data/nocaps_val_4500_captions.json'
|
| 68 |
+
_TEST_FILES = '/tmp/data/nocaps_data/nocaps_test_image_info.json'
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class NoCaps(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for the nocaps captioning dataset (val/test splits)."""

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {
      '1.0.0': 'Initial release.',
  }

  def _info(self) -> tfds.core.DatasetInfo:
    """Returns the dataset metadata (features, homepage, citation)."""
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            'image/id': tf.int64,
            'image_filepath': tfds.features.Text(),
            'url': tfds.features.Text(),
            'image': tfds.features.Image(encoding_format='jpeg'),
            'texts': tfds.features.Sequence(tfds.features.Text()),
        }),
        # No (input, target) tuple: disable `as_supervised=True` usage.
        supervised_keys=None,
        homepage='https://nocaps.org/',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns a dict mapping split names to example generators."""

    def group_by_id(data, image_dir):
      """Groups annotation captions by image id and builds per-image examples.

      Args:
        data: parsed nocaps JSON; test data has no 'annotations' key.
        image_dir: subdirectory of _FILEPATH holding the images
          ('validation' or 'test').

      Returns:
        Dict of image id -> example dict matching the FeaturesDict above.
      """
      id2caps = collections.defaultdict(list)
      for ex in data.get('annotations', []):
        id2caps[ex['image_id']].append(ex['caption'])

      id_to_example = {}
      for ex in data['images']:
        # Compute the image path once; it is used both as metadata and as the
        # source for the encoded image feature.
        filepath = os.path.join(_FILEPATH, image_dir, ex['file_name'])
        id_to_example[ex['id']] = {
            'image/id': ex['id'],
            'image_filepath': filepath,
            'url': ex['coco_url'],
            'image': filepath,
            # Test images have no reference captions; use a placeholder.
            'texts': id2caps.get(ex['id'], ['N/A']),
        }
      return id_to_example

    # Returns the Dict[split names, Iterator[Key, Example]]
    with open(_VAL_FILES) as f:
      val_data = group_by_id(json.load(f), 'validation')
    with open(_TEST_FILES) as f:
      test_data = group_by_id(json.load(f), 'test')
    return {
        'val': self._generate_examples(val_data),
        'test': self._generate_examples(test_data),
    }

  def _generate_examples(self, data):
    """Yields (key, example) tuples, skipping images that fail to decode.

    Args:
      data: dict of image id -> example dict (see `group_by_id`).

    Yields:
      (key, example) tuples in the format specified by the DatasetInfo.
    """
    for k, v in data.items():
      try:
        # Jpeg decode test to check early errors. The decoded images are not
        # used, instead we rely on the default tfds.features.Image function.
        unused_image = tf.io.read_file(v['image_filepath'])
        unused_image = np.array(tf.image.decode_jpeg(unused_image))
      except tf.errors.InvalidArgumentError:
        # Unable to decode image; skip it and log the re-download command.
        logging.error('Unable to decode: curl -O %s', v['url'])
        continue
      except tf.errors.NotFoundError:
        # Image file missing on disk; skip it and log the download command.
        logging.error('File not found: curl -O %s', v['url'])
        continue

      yield k, v
|
Tipsomaly/model/big_vision/datasets/refcoco/refcoco.py
ADDED
|
@@ -0,0 +1,448 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# pylint: disable=line-too-long
|
| 16 |
+
r"""Unbatch RefCOCO, RefCOCO+, RefCOCOg datasets in TFDS structure."""
|
| 17 |
+
|
| 18 |
+
# Based on tensorflow_datasets/datasets/ref_coco
|
| 19 |
+
|
| 20 |
+
import io
|
| 21 |
+
import os
|
| 22 |
+
import pickle
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
import PIL.Image
|
| 26 |
+
import pycocotools.coco
|
| 27 |
+
import tensorflow_datasets as tfds
|
| 28 |
+
|
| 29 |
+
_ROOT_PATH = '/tmp/data/'
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class RefCocoConfig(tfds.core.BuilderConfig):
  """BuilderConfig identifying one RefCoco dataset variant/partition pair."""

  def __init__(self, dataset, dataset_partition, **kwargs):
    # The config name is derived from the variant and its split partition,
    # e.g. 'refcoco_unc' or 'refcocog_umd'.
    super().__init__(name=f'{dataset}_{dataset_partition}', **kwargs)
    self.dataset = dataset
    self.dataset_partition = dataset_partition
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
_DESCRIPTION = """RefCOCO, RefCOCO+, RefCOCOg datasets.
|
| 43 |
+
|
| 44 |
+
Images, boxes and segmentations are from the original COCO dataset
|
| 45 |
+
(Lin et al, ECCV 2014). The referential segmentations are from two different
|
| 46 |
+
sources:
|
| 47 |
+
|
| 48 |
+
1) RefCOCOg (Mao et al, CVPR 2016):
|
| 49 |
+
- https://github.com/mjhucla/Google_Refexp_toolbox
|
| 50 |
+
- This is the split used in the "refcocog_google" dataset. Note that this
|
| 51 |
+
split has overlapping images in train/validation. The same split is also
|
| 52 |
+
provided in 2).
|
| 53 |
+
|
| 54 |
+
2) Source of RefCOCO and RefCOCO+ (Yu et al, ECCV 2016):
|
| 55 |
+
- https://github.com/lichengunc/refer
|
| 56 |
+
- Apache License 2.0
|
| 57 |
+
- Provides all the splits used for generation of these datasets, including the
|
| 58 |
+
"refcocog_google" split that is identical with the split from 1).
|
| 59 |
+
|
| 60 |
+
For convenience, we provide an additional dataset "refcocox_combined" that
|
| 61 |
+
combines the datasets "refcoco_unc", "refcocoplus_unc", and "refcocog_umd",
|
| 62 |
+
unifying "testA" and "testB" into a single "test" split, and removing any images
|
| 63 |
+
from "train" that appear either in "validation" or "test".
|
| 64 |
+
|
| 65 |
+
Also for convenience, every split is unrolled twice (at the "objects" level and
|
| 66 |
+
at the "object/refs" level) and saved as "{split}_flat".
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
# pylint: disable=line-too-long
|
| 70 |
+
_CITATION = r"""
|
| 71 |
+
@inproceedings{DBLP:conf/cvpr/MaoHTCY016,
|
| 72 |
+
author = {Junhua Mao and
|
| 73 |
+
Jonathan Huang and
|
| 74 |
+
Alexander Toshev and
|
| 75 |
+
Oana Camburu and
|
| 76 |
+
Alan L. Yuille and
|
| 77 |
+
Kevin Murphy},
|
| 78 |
+
title = {Generation and Comprehension of Unambiguous Object Descriptions},
|
| 79 |
+
booktitle = {2016 {IEEE} Conference on Computer Vision and Pattern Recognition,
|
| 80 |
+
{CVPR} 2016, Las Vegas, NV, USA, June 27-30, 2016},
|
| 81 |
+
pages = {11--20},
|
| 82 |
+
publisher = {{IEEE} Computer Society},
|
| 83 |
+
year = {2016},
|
| 84 |
+
url = {https://doi.org/10.1109/CVPR.2016.9},
|
| 85 |
+
doi = {10.1109/CVPR.2016.9},
|
| 86 |
+
timestamp = {Fri, 24 Mar 2023 00:02:52 +0100},
|
| 87 |
+
biburl = {https://dblp.org/rec/conf/cvpr/MaoHTCY016.bib},
|
| 88 |
+
bibsource = {dblp computer science bibliography, https://dblp.org}
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
@inproceedings{DBLP:conf/eccv/YuPYBB16,
|
| 92 |
+
author = {Licheng Yu and
|
| 93 |
+
Patrick Poirson and
|
| 94 |
+
Shan Yang and
|
| 95 |
+
Alexander C. Berg and
|
| 96 |
+
Tamara L. Berg},
|
| 97 |
+
editor = {Bastian Leibe and
|
| 98 |
+
Jiri Matas and
|
| 99 |
+
Nicu Sebe and
|
| 100 |
+
Max Welling},
|
| 101 |
+
title = {Modeling Context in Referring Expressions},
|
| 102 |
+
booktitle = {Computer Vision - {ECCV} 2016 - 14th European Conference, Amsterdam,
|
| 103 |
+
The Netherlands, October 11-14, 2016, Proceedings, Part {II}},
|
| 104 |
+
series = {Lecture Notes in Computer Science},
|
| 105 |
+
volume = {9906},
|
| 106 |
+
pages = {69--85},
|
| 107 |
+
publisher = {Springer},
|
| 108 |
+
year = {2016},
|
| 109 |
+
url = {https://doi.org/10.1007/978-3-319-46475-6\_5},
|
| 110 |
+
doi = {10.1007/978-3-319-46475-6\_5},
|
| 111 |
+
timestamp = {Wed, 07 Dec 2022 23:10:23 +0100},
|
| 112 |
+
biburl = {https://dblp.org/rec/conf/eccv/YuPYBB16.bib},
|
| 113 |
+
bibsource = {dblp computer science bibliography, https://dblp.org}
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
@article{DBLP:journals/corr/LinMBHPRDZ14,
|
| 117 |
+
author = {Tsung{-}Yi Lin and
|
| 118 |
+
Michael Maire and
|
| 119 |
+
Serge J. Belongie and
|
| 120 |
+
Lubomir D. Bourdev and
|
| 121 |
+
Ross B. Girshick and
|
| 122 |
+
James Hays and
|
| 123 |
+
Pietro Perona and
|
| 124 |
+
Deva Ramanan and
|
| 125 |
+
Piotr Doll{\'{a}}r and
|
| 126 |
+
C. Lawrence Zitnick},
|
| 127 |
+
title = {Microsoft {COCO:} Common Objects in Context},
|
| 128 |
+
journal = {CoRR},
|
| 129 |
+
volume = {abs/1405.0312},
|
| 130 |
+
year = {2014},
|
| 131 |
+
url = {http://arxiv.org/abs/1405.0312},
|
| 132 |
+
archivePrefix = {arXiv},
|
| 133 |
+
eprint = {1405.0312},
|
| 134 |
+
timestamp = {Mon, 13 Aug 2018 16:48:13 +0200},
|
| 135 |
+
biburl = {https://dblp.org/rec/bib/journals/corr/LinMBHPRDZ14},
|
| 136 |
+
bibsource = {dblp computer science bibliography, https://dblp.org}
|
| 137 |
+
}
|
| 138 |
+
"""
|
| 139 |
+
|
| 140 |
+
# coco_data = json.load(open('annotations/instances_train2017.json'))
|
| 141 |
+
# [l['name'] for l in coco_data['licenses']]
|
| 142 |
+
LICENSES = [
|
| 143 |
+
'Attribution-NonCommercial-ShareAlike License',
|
| 144 |
+
'Attribution-NonCommercial License',
|
| 145 |
+
'Attribution-NonCommercial-NoDerivs License',
|
| 146 |
+
'Attribution License',
|
| 147 |
+
'Attribution-ShareAlike License',
|
| 148 |
+
'Attribution-NoDerivs License',
|
| 149 |
+
'No known copyright restrictions',
|
| 150 |
+
'United States Government Work',
|
| 151 |
+
]
|
| 152 |
+
# _licenses_map = {l['id']: i for i, l in enumerate(coco_data['licenses'])}
|
| 153 |
+
_licenses_map = {1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 8: 7}
|
| 154 |
+
|
| 155 |
+
# pyformat: disable
|
| 156 |
+
# [c['name'] for c in coco_data['categories']]
|
| 157 |
+
CATEGORIES = [
|
| 158 |
+
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
|
| 159 |
+
'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
|
| 160 |
+
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
|
| 161 |
+
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
|
| 162 |
+
'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
|
| 163 |
+
'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
|
| 164 |
+
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
|
| 165 |
+
'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
|
| 166 |
+
'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
|
| 167 |
+
'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
|
| 168 |
+
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
|
| 169 |
+
'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
|
| 170 |
+
'hair drier', 'toothbrush',
|
| 171 |
+
]
|
| 172 |
+
# sorted(set(c['supercategory'] for c in coco_data['categories']))
|
| 173 |
+
SUPERCATEGORIES = [
|
| 174 |
+
'accessory', 'animal', 'appliance', 'electronic', 'food', 'furniture',
|
| 175 |
+
'indoor', 'kitchen', 'outdoor', 'person', 'sports', 'vehicle',
|
| 176 |
+
]
|
| 177 |
+
# pyformat: enable
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
# Will be exported into directory `$TFDS_DATA_DIR/ref_coco_bv`
|
| 181 |
+
# If the class name was `RefCOCO` then it would be exported into
|
| 182 |
+
# `$TFDS_DATA_DIR/ref_coco`, which would collide with the default TFDS dataset
|
| 183 |
+
# also named `ref_coco` (which has precedence over `data_dir` builder arg).
|
| 184 |
+
# Will be exported into directory `$TFDS_DATA_DIR/ref_coco_bv`
# If the class name was `RefCOCO` then it would be exported into
# `$TFDS_DATA_DIR/ref_coco`, which would collide with the default TFDS dataset
# also named `ref_coco` (which has precedence over `data_dir` builder arg).
class RefCocoBv(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for RefCoco datasets."""

  VERSION = tfds.core.Version('1.4.0')
  RELEASE_NOTES = {
      '1.4.0': 'Added flat versions of all dataset splits.',
      '1.3.0': 'Added "refcocox_combined" dataset.',
      '1.2.0': 'Added "train_flat" splits.',
      '1.1.0': 'Added more features (mask etc), nested "refs" in "objects".',
      '1.0.0': 'Initial release.',
  }

  MANUAL_DOWNLOAD_INSTRUCTIONS = """
  1. Install https://pypi.org/project/pycocotools/.

  2. Download data (requires ~20G for COCO images):

  (mkdir -p /tmp/tfds/downloads/manual &&
  cd /tmp/tfds/downloads/manual &&
  wget http://images.cocodataset.org/zips/train2017.zip &&
  wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip &&
  wget https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco.zip &&
  wget https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco+.zip &&
  wget https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcocog.zip &&
  for zip in *.zip; do unzip $zip; done
  )

  3. Run the generation script with `TFDS_DATA_DIR=/tmp/tfds`
  """

  BUILDER_CONFIGS = [
      RefCocoConfig(dataset='refcoco', dataset_partition='unc'),
      RefCocoConfig(dataset='refcoco', dataset_partition='google'),
      RefCocoConfig(dataset='refcocoplus', dataset_partition='unc'),
      RefCocoConfig(dataset='refcocog', dataset_partition='google'),
      RefCocoConfig(dataset='refcocog', dataset_partition='umd'),
      RefCocoConfig(dataset='refcocox', dataset_partition='combined'),
  ]

  def _info(self) -> tfds.core.DatasetInfo:
    """Returns the dataset metadata (per-image features with nested objects)."""
    return tfds.core.DatasetInfo(
        builder=self,
        features=tfds.features.FeaturesDict({
            'id': tfds.features.Scalar(np.int32),
            'image': tfds.features.Image(encoding_format='jpeg'),
            'height': tfds.features.Scalar(np.int32),
            'width': tfds.features.Scalar(np.int32),
            'license': tfds.features.ClassLabel(names=LICENSES),
            'file_name': tfds.features.Text(),
            'flickr_url': tfds.features.Text(),
            'coco_url': tfds.features.Text(),
            'objects': tfds.features.Sequence({
                'id': tfds.features.Scalar(np.int64),
                'area': tfds.features.Scalar(np.float32),
                'bbox': tfds.features.BBoxFeature(),
                'mask': tfds.features.Image(encoding_format='png'),
                'category': tfds.features.ClassLabel(names=CATEGORIES),
                'supercategory': tfds.features.ClassLabel(
                    names=SUPERCATEGORIES
                ),
                'iscrowd': tfds.features.Scalar(np.bool_),
                # refcoco, refcoco+, refcocog features:
                'refs': tfds.features.Sequence({
                    'id': tfds.features.Scalar(np.int32),
                    'sentence': tfds.features.Text(),
                }),
            }),
        }),
        supervised_keys=None,  # Set to `None` to disable
        citation=_CITATION,
        description=_DESCRIPTION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Validates manually-downloaded data and returns per-split generators.

    Raises:
      FileNotFoundError: if any required download is missing from manual_dir.
    """
    # Which splits each (dataset, partition) combination provides.
    allowed_splits = {
        ('refcoco', 'google'): [
            tfds.Split.TRAIN,
            tfds.Split.VALIDATION,
            tfds.Split.TEST,
        ],
        ('refcoco', 'unc'): [
            tfds.Split.TRAIN,
            tfds.Split.VALIDATION,
            'testA',
            'testB',
        ],
        ('refcocoplus', 'unc'): [
            tfds.Split.TRAIN,
            tfds.Split.VALIDATION,
            'testA',
            'testB',
        ],
        # Verified manually that image and annotation IDs match the ones in
        # https://storage.googleapis.com/refexp/google_refexp_dataset_release.zip
        ('refcocog', 'google'): [
            tfds.Split.TRAIN,
            tfds.Split.VALIDATION,
        ],
        ('refcocog', 'umd'): [
            tfds.Split.TRAIN,
            tfds.Split.VALIDATION,
            tfds.Split.TEST,
        ],
        ('refcocox', 'combined'): [
            tfds.Split.TRAIN,
            tfds.Split.VALIDATION,
            tfds.Split.TEST,
        ],
    }
    bc = self.builder_config
    splits = allowed_splits[(bc.dataset, bc.dataset_partition)]

    data_dir = dl_manager.manual_dir
    for url, components in (
        # pylint: disable=line-too-long
        # pyformat: disable
        ('http://images.cocodataset.org/zips/train2017.zip', ('train2017', '000000147328.jpg')),
        ('http://images.cocodataset.org/annotations/annotations_trainval2017.zip', ('annotations', 'instances_train2017.json')),
        ('https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco.zip', ('refcoco', 'refs(unc).p')),
        ('https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco+.zip', ('refcoco+', 'refs(unc).p')),
        ('https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcocog.zip', ('refcocog', 'refs(umd).p')),
        # pyformat: enable
        # pylint: enable=line-too-long
    ):
      # BUG FIX: previously `path` was assigned the boolean result of
      # os.path.exists(...), so the error message printed "Could not find
      # False". Bind the joined path first, then test for existence.
      path = os.path.join(data_dir, *components)
      if not os.path.exists(path):
        raise FileNotFoundError(
            f'Could not find {path}: please download {url} and unzip into'
            f' {data_dir}'
        )

    coco = pycocotools.coco.COCO(
        os.path.join(data_dir, 'annotations', 'instances_train2017.json')
    )

    # Every split is also emitted in an unrolled "{split}_flat" variant.
    return {
        split + suffix: self._generate_examples(
            coco, data_dir, bc.dataset, bc.dataset_partition, split + suffix,
        )
        for split in splits
        for suffix in ('', '_flat')
    }

  # Builder must overwrite all abstract methods.
  def _generate_examples(
      self, coco, data_dir, dataset, dataset_partition, split):
    """Delegates to the module-level generator function."""
    return _generate_examples(coco, data_dir, dataset, dataset_partition, split)
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def _get_ids(data_dir, dataset, dataset_partition, split):
  """Returns `img_ids, ann_to_refs` for specified dataset/partition/split.

  Args:
    data_dir: directory containing the unzipped refcoco* downloads.
    dataset: one of 'refcoco', 'refcocoplus', 'refcocog', 'refcocox'.
    dataset_partition: partition name, e.g. 'unc', 'google', 'umd', 'combined'.
    split: split name ('train', 'val'/VALIDATION, 'test', 'testA', 'testB').

  Returns:
    img_ids: set of COCO image ids belonging to the split.
    ann_to_refs: dict of annotation id -> list of {'id', 'sentence'} refs.
  """

  def load(dataset, dataset_partition):
    fname = f'refs({dataset_partition}).p'
    path = os.path.join(data_dir, dataset, fname)
    # BUG FIX: previously the file handle from `open(path, 'rb')` was never
    # closed; use a context manager to release it deterministically.
    with open(path, 'rb') as f:
      return pickle.load(f)

  # The pickle files use 'val' where TFDS uses 'validation'.
  if split == tfds.Split.VALIDATION:
    split = 'val'

  if (dataset, dataset_partition) == ('refcocox', 'combined'):
    # "combined" merges refcocog_umd, refcoco_unc, and refcoco+_unc.
    refcoco = (
        load('refcocog', 'umd')
        + load('refcoco', 'unc')
        + load('refcoco+', 'unc')
    )
    # Unify 'testA'/'testB' into a single 'test' split.
    if split == 'test':
      splits = ('test', 'testA', 'testB')
    else:
      splits = (split,)

    exclude_img_ids = set()
    if split == 'train':
      # Exclude all images with val/test annotations from train set.
      exclude_img_ids = {
          r['image_id'] for r in refcoco if r['split'] != 'train'
      }
    refcoco = [
        r
        for r in refcoco
        if r['split'] in splits and r['image_id'] not in exclude_img_ids
    ]

  else:
    # The download directory for refcocoplus is named 'refcoco+'.
    if dataset == 'refcocoplus':
      dataset = 'refcoco+'
    refcoco = load(dataset, dataset_partition)
    refcoco = [r for r in refcoco if r['split'] == split]

  img_ids = {r['image_id'] for r in refcoco}
  ann_to_refs = {}
  for r in refcoco:
    for sent in r['sentences']:
      ann_to_refs.setdefault(r['ann_id'], []).append(dict(
          id=sent['sent_id'],
          sentence=sent['sent']
      ))

  return img_ids, ann_to_refs
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def _generate_examples(coco, data_dir, dataset, dataset_partition, split):
  """Generates examples for a given split.

  If the split name contains '_flat', one example is yielded per referring
  expression (with a single object holding a single ref); otherwise one
  example is yielded per image, with all referred objects attached.
  """

  # '_flat' is a pseudo-split suffix, not part of the underlying split name.
  flat = '_flat' in split
  split = split.replace('_flat', '')
  img_ids, ann_to_refs = _get_ids(data_dir, dataset, dataset_partition, split)

  for img_id in coco.getImgIds():

    if img_id not in img_ids:
      continue
    # Trailing comma unpacks the single-element list returned by loadImgs.
    img, = coco.loadImgs([img_id])

    example = {
        'id': img_id,
        'image': os.path.join(data_dir, 'train2017', img['file_name']),
        'height': img['height'],
        'width': img['width'],
        'license': LICENSES[_licenses_map[img['license']]],
        'file_name': img['file_name'],
        'flickr_url': img['flickr_url'],
        'coco_url': img['coco_url'],
        'objects': [],
    }
    for ann in coco.loadAnns(coco.getAnnIds(img_id)):
      # Only keep annotations that carry referring expressions for this split.
      refs = ann_to_refs.get(ann['id'])
      if not refs:
        continue
      cat, = coco.loadCats([ann['category_id']])
      # Encode the binary segmentation mask as an in-memory PNG.
      mask = coco.annToMask(ann).astype(np.bool_)
      mask_buf = io.BytesIO()
      PIL.Image.fromarray(mask).save(mask_buf, 'png')
      mask_buf.seek(0)
      object_ = {
          'id': ann['id'],
          'mask': mask_buf,
          'category': cat['name'],
          'supercategory': cat['supercategory'],
          'iscrowd': ann['iscrowd'],
          'area': ann['area'],
          'bbox': _convert_bbox(img, *ann['bbox']),
          'refs': refs,
      }
      if flat:
        # Flat mode: reuse (mutate) the same `example`/`object_` dicts and
        # yield once per ref. TFDS serializes each yield immediately, so the
        # in-place mutation between yields is safe here.
        example['objects'] = [object_]
        for ref_i, ref in enumerate(refs):
          object_['refs'] = [ref]
          # Rewind the PNG buffer so it can be read again for each yield.
          mask_buf.seek(0)
          yield f'{img_id}_{ann["id"]}_{ref_i}', example
      else:
        example['objects'].append(object_)

    if not flat:
      yield img_id, example
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
def _convert_bbox(img, x, y, w, h):
  """Converts an absolute COCO box (x, y, w, h) to a relative tfds BBox."""
  height = img['height']
  width = img['width']
  return tfds.features.BBox(
      ymin=y / height,
      xmin=x / width,
      ymax=(y + h) / height,
      xmax=(x + w) / width,
  )
|
Tipsomaly/model/big_vision/datasets/rsvqa_hr/rsvqa_hr.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# pylint: disable=line-too-long
|
| 16 |
+
r"""Implements RSVQA-HR dataset in TFDS.
|
| 17 |
+
|
| 18 |
+
Remote sensing visual question answering task, using high-resolution airborne
|
| 19 |
+
image data at 15cm resolution per pixel.
|
| 20 |
+
|
| 21 |
+
It's a small dataset at source (14G), so simple to run locally.
|
| 22 |
+
First, download and unzip the dataset from https://zenodo.org/records/6344367
|
| 23 |
+
and place it in /tmp/data/rsvqa_hr.
|
| 24 |
+
|
| 25 |
+
Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util):
|
| 26 |
+
|
| 27 |
+
cd third_party/py/big_vision/datasets
|
| 28 |
+
env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=rsvqa_hr
|
| 29 |
+
|
| 30 |
+
Example to load:
|
| 31 |
+
|
| 32 |
+
import tensorflow_datasets as tfds
|
| 33 |
+
dataset = tfds.load('rsvqa_hr', split='train', data_dir='/tmp/tfds')
|
| 34 |
+
|
| 35 |
+
Dataset splits (all):
|
| 36 |
+
train: 625,340 examples/questions
|
| 37 |
+
val: 102,843 examples/questions
|
| 38 |
+
test: 222,684 examples/questions
|
| 39 |
+
test_2: 105,647 examples/questions (other area, unknown instrument)
|
| 40 |
+
Non-numeric data splits (nonum):
|
| 41 |
+
train: 371,834 examples/questions
|
| 42 |
+
val: 60,405 examples/questions
|
| 43 |
+
test: 131,468 examples/questions
|
| 44 |
+
test_2: 62,554 examples/questions
|
| 45 |
+
|
| 46 |
+
Note: due to image duplication with each question, the dataset size is
|
| 47 |
+
significantly increased by the number of questions per image.
|
| 48 |
+
|
| 49 |
+
Recommended training splits:
|
| 50 |
+
train: train
|
| 51 |
+
minitrain: train[:5%]
|
| 52 |
+
eval: val
|
| 53 |
+
full_train: train+val
|
| 54 |
+
test: test
|
| 55 |
+
|
| 56 |
+
Image sizes: 512x512
|
| 57 |
+
Number of answers per question: 1
|
| 58 |
+
Question types distribution in train split:
|
| 59 |
+
- Area (area): 14.6% (integers, binned into {0m2, 1-10m2, 11-100m2, 101-1000m2, >1000m2})
|
| 60 |
+
- Comparison(comp): 33.5%
|
| 61 |
+
- Count (count): 26.0% (integers, not binned, maximum number of objects is 89)
|
| 62 |
+
- Presence (presence): 26.0%
|
| 63 |
+
"""
|
| 64 |
+
import json
|
| 65 |
+
import os
|
| 66 |
+
|
| 67 |
+
import numpy as np
|
| 68 |
+
import tensorflow_datasets as tfds
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
_DESCRIPTION = """RSVQA-HR dataset."""
|
| 72 |
+
|
| 73 |
+
# pylint: disable=line-too-long
|
| 74 |
+
_CITATION = """
|
| 75 |
+
@article{Lobry_2020,
|
| 76 |
+
title={RSVQA: Visual Question Answering for Remote Sensing Data},
|
| 77 |
+
volume={58},
|
| 78 |
+
ISSN={1558-0644},
|
| 79 |
+
url={http://dx.doi.org/10.1109/TGRS.2020.2988782},
|
| 80 |
+
DOI={10.1109/tgrs.2020.2988782},
|
| 81 |
+
number={12},
|
| 82 |
+
journal={IEEE Transactions on Geoscience and Remote Sensing},
|
| 83 |
+
publisher={Institute of Electrical and Electronics Engineers (IEEE)},
|
| 84 |
+
author={Lobry, Sylvain and Marcos, Diego and Murray, Jesse and Tuia, Devis},
|
| 85 |
+
year={2020},
|
| 86 |
+
month=dec, pages={8555-8566} }
|
| 87 |
+
"""
|
| 88 |
+
# pylint: enable=line-too-long
|
| 89 |
+
|
| 90 |
+
# When running locally (recommended), copy files as above and use these:
|
| 91 |
+
PATH = '/tmp/data/rsvqa_hr/'
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class RsvqaHrConfig(tfds.core.BuilderConfig):
  """Config to specify each variant.

  The 'nonum' variant drops numeric question types; 'all' keeps everything.
  """

  def __init__(self, nonum, **kwargs):
    super().__init__(name='nonum' if nonum else 'all', **kwargs)
    self.nonum = nonum
|
| 102 |
+
|
| 103 |
+
class RsvqaHr(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for RSVQA-HR dataset."""

  VERSION = tfds.core.Version('1.0.2')
  RELEASE_NOTES = {
      '1.0.0': 'First release.',
      '1.0.1': 'Rename binned values.',
      '1.0.2': 'Removed explicit png image encoding.',
  }

  BUILDER_CONFIGS = [
      RsvqaHrConfig(nonum=False),
      RsvqaHrConfig(nonum=True),
  ]

  def _info(self):
    """Returns the metadata."""

    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            'question_id': tfds.features.Scalar(np.int32),
            'filename': tfds.features.Text(),
            'image': tfds.features.Image(),
            'question': tfds.features.Text(),
            'question_type': tfds.features.Text(),
            # 'answers' holds binned values; 'raw_answers' the originals.
            'answers': tfds.features.Sequence(tfds.features.Text()),
            'raw_answers': tfds.features.Sequence(tfds.features.Text()),
        }),
        supervised_keys=None,
        homepage='https://rsvqa.sylvainlobry.com/',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators."""
    return {
        split: self._generate_examples(split)
        for split in ('train', 'val', 'test', 'test_2')
    }

  def _generate_examples(self, split):
    """Yields (key, example) tuples for one split."""
    if split == 'test_2':
      # The on-disk files name the second test area 'test_phili'.
      split = 'test_phili'
    # Fixed: pass path components as separate os.path.join arguments instead
    # of pre-concatenating with '+' (which worked only because PATH happens
    # to end with a slash).
    questions_path = os.path.join(PATH, f'USGS_split_{split}_questions.json')
    answers_path = os.path.join(PATH, f'USGS_split_{split}_answers.json')
    images_path = os.path.join(PATH, 'Data')

    # JSON is UTF-8 by spec; be explicit so decoding is locale-independent.
    with open(questions_path, 'r', encoding='utf-8') as f:
      questions = json.load(f)['questions']
    with open(answers_path, 'r', encoding='utf-8') as f:
      answers = json.load(f)['answers']

    # Questions and answers are aligned 1:1 by position; the asserts below
    # verify that alignment on ids.
    for q, a in zip(questions, answers):
      assert q['active'] == a['active']
      if not q['active']:
        continue
      if self.builder_config.nonum and q['type'] in ('area', 'count'):
        continue
      assert q['answers_ids'][0] == a['id']
      assert q['id'] == a['question_id']

      filename = f'{q["img_id"]}.png'
      yield q['id'], {
          'question_id': q['id'],
          'filename': filename,
          'image': os.path.join(images_path, filename),
          'question': q['question'],
          'question_type': q['type'],
          'answers': [bin_answer(a['answer'], q['type'])],
          'raw_answers': [a['answer']],
      }
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def bin_answer(answer, question_type):
  """Bins 'area' answers into coarse ranges; other answers pass through.

  Area answers look like '<int>m2'; the numeric part is bucketed into
  {0 m2, 1-10 m2, 11-100 m2, 101-1000 m2, >1000 m2}.
  """
  if question_type != 'area':
    return answer
  area = int(answer[:-2])  # Strip the trailing 'm2' unit.
  if area == 0:
    return '0 m2'
  if area <= 10:
    return 'between 1 m2 and 10 m2'
  if area <= 100:
    return 'between 11 m2 and 100 m2'
  if area <= 1000:
    return 'between 101 m2 and 1000 m2'
  return 'more than 1000 m2'
|
Tipsomaly/model/big_vision/datasets/rsvqa_lr/rsvqa_lr.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# pylint: disable=line-too-long
|
| 16 |
+
r"""Implements RSVQA-LR dataset in TFDS.
|
| 17 |
+
|
| 18 |
+
Remote sensing visual question answering task, using low-resolution satellite
|
| 19 |
+
(Sentinel-2) RGB channels data at 10m resolution per pixel.
|
| 20 |
+
|
| 21 |
+
It's small dataset at source (200M), so simple to run locally.
|
| 22 |
+
First, download and unzip the dataset from https://zenodo.org/records/6344334
|
| 23 |
+
and place it in /tmp/data/rsvqa_lr.
|
| 24 |
+
|
| 25 |
+
Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util):
|
| 26 |
+
|
| 27 |
+
cd third_party/py/big_vision/datasets
|
| 28 |
+
env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=rsvqa_lr
|
| 29 |
+
|
| 30 |
+
Example to load:
|
| 31 |
+
|
| 32 |
+
import tensorflow_datasets as tfds
|
| 33 |
+
dataset = tfds.load('rsvqa_lr', split='train', data_dir='/tmp/tfds')
|
| 34 |
+
|
| 35 |
+
Dataset splits:
|
| 36 |
+
train: 57223 examples/questions
|
| 37 |
+
val: 10005 examples/questions
|
| 38 |
+
test: 10004 examples/questions
|
| 39 |
+
And the same splits are available excluding numeric questions:
|
| 40 |
+
train_nonum: 39441 examples/questions
|
| 41 |
+
val_nonum: 6782 examples/questions
|
| 42 |
+
test_nonum: 6782 examples/questions
|
| 43 |
+
|
| 44 |
+
Note: due to image duplication with each question, the dataset size is
|
| 45 |
+
significantly increased by the number of questions per image.
|
| 46 |
+
|
| 47 |
+
Recommended training splits:
|
| 48 |
+
train: train
|
| 49 |
+
minitrain: train[:5%]
|
| 50 |
+
eval: val
|
| 51 |
+
full_train: train+val
|
| 52 |
+
test: test
|
| 53 |
+
|
| 54 |
+
Image sizes: 256x256
|
| 55 |
+
Number of answers per question: 1
|
| 56 |
+
Question types distribution in train split:
|
| 57 |
+
- Comparison(comp): 39.4%
|
| 58 |
+
- Count (count): 29.9% (integers, binned at evaluation into
|
| 59 |
+
{0, 1-10, 11-100, 101-1000, >1000})
|
| 60 |
+
- Presence (presence): 29.7%
|
| 61 |
+
- Rural/Urban (rural_urban): 1%
|
| 62 |
+
"""
|
| 63 |
+
import io
|
| 64 |
+
import json
|
| 65 |
+
import os
|
| 66 |
+
|
| 67 |
+
import numpy as np
|
| 68 |
+
import tensorflow_datasets as tfds
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
_DESCRIPTION = """RSVQA-LR dataset."""
|
| 72 |
+
|
| 73 |
+
# pylint: disable=line-too-long
|
| 74 |
+
_CITATION = """
|
| 75 |
+
@article{Lobry_2020,
|
| 76 |
+
title={RSVQA: Visual Question Answering for Remote Sensing Data},
|
| 77 |
+
volume={58},
|
| 78 |
+
ISSN={1558-0644},
|
| 79 |
+
url={http://dx.doi.org/10.1109/TGRS.2020.2988782},
|
| 80 |
+
DOI={10.1109/tgrs.2020.2988782},
|
| 81 |
+
number={12},
|
| 82 |
+
journal={IEEE Transactions on Geoscience and Remote Sensing},
|
| 83 |
+
publisher={Institute of Electrical and Electronics Engineers (IEEE)},
|
| 84 |
+
author={Lobry, Sylvain and Marcos, Diego and Murray, Jesse and Tuia, Devis},
|
| 85 |
+
year={2020},
|
| 86 |
+
month=dec, pages={8555–8566} }
|
| 87 |
+
"""
|
| 88 |
+
# pylint: enable=line-too-long
|
| 89 |
+
|
| 90 |
+
# When running locally (recommended), copy files as above and use these:
|
| 91 |
+
PATH = '/tmp/data/rsvqa_lr/'
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class RsvqaLrConfig(tfds.core.BuilderConfig):
  """Config to specify each variant.

  The 'nonum' variant drops numeric (count) questions; 'all' keeps everything.
  """

  def __init__(self, nonum, **kwargs):
    super().__init__(name='nonum' if nonum else 'all', **kwargs)
    self.nonum = nonum
|
| 102 |
+
|
| 103 |
+
class RsvqaLr(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for RSVQA-LR dataset."""

  VERSION = tfds.core.Version('1.0.2')
  RELEASE_NOTES = {
      '1.0.0': 'First release.',
      '1.0.1': 'Rename binned values.',
      '1.0.2': 'Removed explicit png image encoding.',
  }

  BUILDER_CONFIGS = [
      RsvqaLrConfig(nonum=False),
      RsvqaLrConfig(nonum=True),
  ]

  def _info(self):
    """Returns the metadata."""

    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            'question_id': tfds.features.Scalar(np.int32),
            'filename': tfds.features.Text(),
            'image': tfds.features.Image(),
            'question': tfds.features.Text(),
            'question_type': tfds.features.Text(),
            # 'answers' holds binned values; 'raw_answers' the originals.
            'answers': tfds.features.Sequence(tfds.features.Text()),
            'raw_answers': tfds.features.Sequence(tfds.features.Text()),
        }),
        supervised_keys=None,
        homepage='https://rsvqa.sylvainlobry.com/',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators."""
    return {
        split: self._generate_examples(split)
        for split in ('train', 'val', 'test')
    }

  def _generate_examples(self, split):
    """Yields (key, example) tuples for one split."""
    # Fixed: pass path components as separate os.path.join arguments instead
    # of pre-concatenating with '+' (which worked only because PATH happens
    # to end with a slash).
    questions_path = os.path.join(PATH, f'LR_split_{split}_questions.json')
    answers_path = os.path.join(PATH, f'LR_split_{split}_answers.json')
    images_path = os.path.join(PATH, 'Images_LR')

    # JSON is UTF-8 by spec; be explicit so decoding is locale-independent.
    with open(questions_path, 'r', encoding='utf-8') as f:
      questions = json.load(f)['questions']
    with open(answers_path, 'r', encoding='utf-8') as f:
      answers = json.load(f)['answers']

    # Questions and answers are aligned 1:1 by position; the asserts below
    # verify that alignment on ids.
    for q, a in zip(questions, answers):
      assert q['active'] == a['active']
      if not q['active']:
        continue
      if self.builder_config.nonum and q['type'] == 'count':
        continue
      assert q['answers_ids'] == [a['id']]
      assert q['id'] == a['question_id']

      filename = f'{q["img_id"]}.tif'
      # TIFF is not natively decoded by tfds.features.Image, so decode here.
      img = read_tif(os.path.join(images_path, filename))
      yield q['id'], {
          'question_id': q['id'],
          'filename': filename,
          'image': img,
          'question': q['question'],
          'question_type': q['type'],
          'answers': [bin_answer(a['answer'], q['type'])],
          'raw_answers': [a['answer']],
      }
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def bin_answer(answer, question_type):
  """Bins 'count' answers into coarse ranges; other answers pass through.

  The integer count is bucketed into {0, 1-10, 11-100, 101-1000, >1000}.
  """
  if question_type != 'count':
    return answer
  count = int(answer)
  if count == 0:
    return '0'
  if count <= 10:
    return 'between 1 and 10'
  if count <= 100:
    return 'between 11 and 100'
  if count <= 1000:
    return 'between 101 and 1000'
  return 'more than 1000'
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def read_tif(path):
  """Reads a TIFF file from `path` and returns it as a uint8 numpy array."""
  with open(path, 'rb') as f:
    raw = f.read()
  img = tfds.core.lazy_imports.tifffile.imread(io.BytesIO(raw))
  return img.astype(np.uint8)
|
Tipsomaly/model/big_vision/datasets/scicap/scicap.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# pylint: disable=line-too-long
|
| 16 |
+
r"""Creates TFDS dataset for SciCap.
|
| 17 |
+
|
| 18 |
+
Preparing the data:
|
| 19 |
+
1) mkdir /tmp/data/scicap && cd /tmp/data/scicap
|
| 20 |
+
2) wget 'https://www.dropbox.com/s/t1sjqesl0pynaxo/scicap_data.zip?dl=0'
|
| 21 |
+
3) unzip -UU 'scicap_data.zip?dl=0' && rm 'scicap_data.zip?dl=0'
|
| 22 |
+
|
| 23 |
+
Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util):
|
| 24 |
+
|
| 25 |
+
cd big_vision/datasets
|
| 26 |
+
env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=scicap
|
| 27 |
+
|
| 28 |
+
Example to load:
|
| 29 |
+
|
| 30 |
+
import tensorflow_datasets as tfds
|
| 31 |
+
dataset = tfds.load('scicap', split='train', data_dir='/tmp/tfds')
|
| 32 |
+
"""
|
| 33 |
+
# pylint: enable=line-too-long
|
| 34 |
+
import enum
|
| 35 |
+
import functools
|
| 36 |
+
import json
|
| 37 |
+
import os
|
| 38 |
+
|
| 39 |
+
import tensorflow_datasets as tfds
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
_DESCRIPTION = """SciCap dataset."""
|
| 43 |
+
_CITATION = """
|
| 44 |
+
@article{hsu2021scicap,
|
| 45 |
+
title={SciCap: Generating captions for scientific figures},
|
| 46 |
+
author={Hsu, Ting-Yao and Giles, C Lee and Huang, Ting-Hao'Kenneth'},
|
| 47 |
+
journal={arXiv preprint arXiv:2110.11624},
|
| 48 |
+
year={2021}
|
| 49 |
+
}
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
# When running locally (recommended), copy files as above and use these:
|
| 53 |
+
_SCICAP_DIR = "/tmp/data/scicap/scicap_data"
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class ScicapSubset(enum.Enum):
  """Versions of the SciCap dataset."""
  # Captions that consist of a single sentence.
  SINGLE_SENTENCE = "single_sentence"
  # Only the first sentence of each caption.
  FIRST_SENTENCE = "first_sentence"
  # Captions with at most 100 tokens.
  LEQ_100_TOKENS = "leq_100_tokens"
|
| 62 |
+
# Splits produced by the builder.
_SPLITS_TO_GENERATE = ["train", "test", "val"]
# Maps (subset, include_subfigures) to the sub-directory (under
# "List-of-Files-for-Each-Experiments") that lists the file ids belonging to
# that configuration.
_CONFIG_TO_IDS_PATH = {
    (ScicapSubset.SINGLE_SENTENCE, True): "Single-Sentence-Caption/Yes-Subfig",
    (ScicapSubset.SINGLE_SENTENCE, False): "Single-Sentence-Caption/No-Subfig",
    (ScicapSubset.FIRST_SENTENCE, True): "First-Sentence/Yes-Subfig",
    (ScicapSubset.FIRST_SENTENCE, False): "First-Sentence/No-Subfig",
    (ScicapSubset.LEQ_100_TOKENS, True):
        "Caption-No-More-Than-100-Tokens/Yes-Subfig",
    (ScicapSubset.LEQ_100_TOKENS, False):
        "Caption-No-More-Than-100-Tokens/No-Subfig",
}
# Maps include_subfigures to the image directory name.
_SUBFIG_TO_PATH = {
    True: "SciCap-Yes-Subfig-Img", False: "SciCap-No-Subfig-Img"
}
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class ScicapConfig(tfds.core.BuilderConfig):
  """"Configuration for SciCap caption length and subfigure inclusion."""

  def __init__(self, *, subset: ScicapSubset, subfig: bool, **kwargs):
    """Parameters specifying how the dataset will be processed.

    Args:
      subset: Subset of the Scicap data (see enum above).
      subfig: Whether or not figure with subfigures are included.
      **kwargs: Passed on to the constructor of `BuilderConfig`.
    """
    super().__init__(**kwargs)
    self.subset = subset
    self.subfig = subfig
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@functools.cache
def _read_annotations(split: str, image_id: str):
  """Reads the annotation JSON for a single figure.

  Cached because the same annotation file is read once per builder config
  (six configs share the caption files). NOTE(review): the cache is unbounded
  and keeps every annotation dict alive for the process lifetime — acceptable
  for a one-shot dataset build, but worth revisiting if reused elsewhere.

  Args:
    split: Split name ('train', 'val' or 'test').
    image_id: Figure id (file name without the '.png' suffix).

  Returns:
    The parsed annotation dict.
  """
  path = os.path.join(_SCICAP_DIR, "SciCap-Caption-All", split)
  fname = os.path.join(path, image_id + ".json")
  # Fixed: JSON is UTF-8 by spec; an explicit encoding avoids depending on
  # the machine's locale default.
  with open(fname, "r", encoding="utf-8") as fin:
    return json.load(fin)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class Scicap(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for the SciCap dataset."""

  VERSION = tfds.core.Version("1.0.0")
  RELEASE_NOTES = {"1.0.0": "First release."}

  # One config per (caption subset, subfigures allowed) combination.
  BUILDER_CONFIGS = [
      ScicapConfig(
          name="single_sentence_subfig_yes",
          description="Single sentence caption with subfigures allowed.",
          subset=ScicapSubset.SINGLE_SENTENCE,
          subfig=True
      ),
      ScicapConfig(
          name="single_sentence_subfig_no",
          description="Single sentence caption with subfigures not allowed.",
          subset=ScicapSubset.SINGLE_SENTENCE,
          subfig=False
      ),
      ScicapConfig(
          name="first_sentence_subfig_yes",
          description="First sentence of captions with subfigures allowed.",
          subset=ScicapSubset.FIRST_SENTENCE,
          subfig=True
      ),
      ScicapConfig(
          name="first_sentence_subfig_no",
          description="First sentence of captions with subfigures not allowed.",
          subset=ScicapSubset.FIRST_SENTENCE,
          subfig=False
      ),
      ScicapConfig(
          name="leq_100_tokens_subfig_yes",
          description="Captions with <= 100 tokens with subfigures allowed.",
          subset=ScicapSubset.LEQ_100_TOKENS,
          subfig=True
      ),
      ScicapConfig(
          name="leq_100_tokens_subfig_no",
          description=("Captions with <= 100 tokens with subfigures"
                       " not allowed."),
          subset=ScicapSubset.LEQ_100_TOKENS,
          subfig=False
      ),
  ]

  def _info(self):
    """Returns the metadata."""

    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            "image/id": tfds.features.Text(),
            "image/filename": tfds.features.Text(),
            "image": tfds.features.Image(encoding_format="png"),
            # The three caption variants below correspond to the upstream
            # normalization stages (0, 1 and 2).
            "caption/originally_extracted": tfds.features.Text(),
            "caption/lowercase_and_token_and_remove_figure_index":
                tfds.features.Text(),
            "caption/normalized/basic_num": tfds.features.Text(),
            "caption/normalized/advanced_equation_bracket":
                tfds.features.Text(),
        }),
        supervised_keys=None,
        homepage="https://github.com/tingyaohsu/SciCap",
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators."""
    return {split: self._generate_examples(split)
            for split in _SPLITS_TO_GENERATE}

  def _generate_examples(self, split: str):
    """Yields (key, example) tuples for the given split."""
    # Directory (under "List-of-Files-for-Each-Experiments") that lists the
    # file ids belonging to this builder config.
    config_path = _CONFIG_TO_IDS_PATH[
        (self.builder_config.subset, self.builder_config.subfig)]
    image_path = os.path.join(
        _SCICAP_DIR, _SUBFIG_TO_PATH[self.builder_config.subfig], split)
    id_list_fname = os.path.join(
        _SCICAP_DIR, "List-of-Files-for-Each-Experiments",
        config_path, split, "file_idx.json")
    with open(id_list_fname, "r") as fin:
      split_images = json.load(fin)

    for fname in split_images:
      assert fname.endswith(".png")
      image_id = fname[:-len(".png")]
      annotations = _read_annotations(split, image_id)
      yield fname, {
          "image/id": image_id,
          "image/filename": fname,
          "image": os.path.join(image_path, fname),
          "caption/originally_extracted": annotations["0-originally-extracted"],
          "caption/lowercase_and_token_and_remove_figure_index":
              annotations["1-lowercase-and-token-and-remove-figure-index"][
                  "caption"],
          "caption/normalized/basic_num": annotations["2-normalized"][
              "2-1-basic-num"]["caption"],
          # NOTE: "euqation" is misspelled in the upstream SciCap JSON keys;
          # it must stay misspelled here to match the data on disk.
          "caption/normalized/advanced_equation_bracket":
              annotations["2-normalized"][
                  "2-2-advanced-euqation-bracket"]["caption"]
      }
|
Tipsomaly/model/big_vision/datasets/science_qa/science_qa.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# pylint: disable=line-too-long
|
| 16 |
+
r"""Implements ScienceQA train/val/test-set in TFDS structure.
|
| 17 |
+
|
| 18 |
+
First, download the science QA dataset from their website https://scienceqa.github.io/#download
|
| 19 |
+
- mkdir -p /tmp/data/ScienceQA_DATA
|
| 20 |
+
- From Google Drive: https://drive.google.com/corp/drive/folders/1w8imCXWYn2LxajmGeGH_g5DaL2rabHev
|
| 21 |
+
Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util):
|
| 22 |
+
- cd big_vision/datasets
|
| 23 |
+
- env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=science_qa
|
| 24 |
+
|
| 25 |
+
Example to load:
|
| 26 |
+
|
| 27 |
+
import tensorflow_datasets as tfds
|
| 28 |
+
dataset = tfds.load(
|
| 29 |
+
'science_qa', split='train',
|
| 30 |
+
data_dir='/tmp/tfds')
|
| 31 |
+
|
| 32 |
+
"""
|
| 33 |
+
import json
|
| 34 |
+
import os
|
| 35 |
+
|
| 36 |
+
import numpy as np
|
| 37 |
+
import tensorflow_datasets as tfds
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
_DESCRIPTION = """Sci QA test-set."""
|
| 41 |
+
|
| 42 |
+
# pylint: disable=line-too-long
|
| 43 |
+
_CITATION = """
|
| 44 |
+
@inproceedings{lu2022learn,
|
| 45 |
+
title={Learn to Explain: Multimodal Reasoning via Thought Chains for Science Question Answering},
|
| 46 |
+
author={Lu, Pan and Mishra, Swaroop and Xia, Tony and Qiu, Liang and Chang, Kai-Wei and Zhu, Song-Chun and Tafjord, Oyvind and Clark, Peter and Ashwin Kalyan},
|
| 47 |
+
booktitle={The 36th Conference on Neural Information Processing Systems (NeurIPS)},
|
| 48 |
+
year={2022}
|
| 49 |
+
}
|
| 50 |
+
"""
|
| 51 |
+
# pylint: enable=line-too-long
|
| 52 |
+
|
| 53 |
+
# When running locally (recommended), copy files as above and use these:
|
| 54 |
+
_SCIQA_PATH = '/tmp/data/ScienceQA_DATA/'
|
| 55 |
+
# _IMAGE_COCO_PATH = '/tmp/data/val2014'
|
| 56 |
+
|
| 57 |
+
_ALPHABETS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class ScienceQA(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for the ScienceQA dataset.

  Reads `problems.json` from `_SCIQA_PATH` and yields one example per
  question that has an associated image. Examples without an image are
  skipped because this conversion targets VQA tasks.
  """

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {'1.0.0': 'First release.'}

  def _info(self):
    """Returns the dataset metadata."""
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            'question': tfds.features.Text(),
            'choices': tfds.features.Sequence(tfds.features.Text()),
            'answer': tfds.features.Scalar(np.int32),
            'hint': tfds.features.Text(),
            'task': tfds.features.Text(),
            'grade': tfds.features.Text(),
            'subject': tfds.features.Text(),
            'topic': tfds.features.Text(),
            'category': tfds.features.Text(),
            'skill': tfds.features.Text(),
            'lecture': tfds.features.Text(),
            'solution': tfds.features.Text(),
            'image': tfds.features.Image(encoding_format='png'),
            'indexed_choices': tfds.features.Text(),
            'indexed_answer': tfds.features.Text(),
        }),
        supervised_keys=None,
        homepage='https://github.com/lupantech/ScienceQA/tree/main',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators."""
    return {
        split: self._generate_examples(split)
        for split in ('train', 'test', 'val')
    }

  def _generate_examples(self, split):
    """Yields (key, example) tuples for the requested split.

    Args:
      split: One of 'train', 'val', 'test'; matched against each problem's
        own 'split' field in `problems.json`.

    Yields:
      (int key, dict example) tuples; only problems with an image.
    """
    annot_fname = os.path.join(_SCIQA_PATH, 'problems.json')

    with open(annot_fname, 'r') as f:
      data = json.load(f)

    for k, v in data.items():
      if v['split'] != split:  # "split":"train"
        continue
      # ScienceQA also contains examples without an image. As this
      # conversion is for VQA tasks, we drop the examples without image.
      # TODO: Include the examples without image, and update the
      # downstream pipeline to skip the examples without image, instead of
      # doing it at pre-processing.
      image = v['image']
      if not image:
        continue
      image = os.path.join(_SCIQA_PATH, split, k, image)
      # Align with the original github implementation: empty hint -> 'N/A'.
      hint = v['hint'] or 'N/A'
      choices = v['choices']
      answer = v['answer']
      # Render choices as "(A) foo, (B) bar, ..." and the answer as its
      # letter, for prompt-style training.
      indexed_choices = ', '.join(
          f'({_ALPHABETS[i]}) {c}' for i, c in enumerate(choices)
      )
      indexed_answer = _ALPHABETS[int(answer)]
      yield int(k), {
          'question': v['question'],
          'choices': choices,
          'answer': answer,
          'hint': hint,
          'task': v['task'],
          'grade': v['grade'],
          'subject': v['subject'],
          'topic': v['topic'],
          'category': v['category'],
          'skill': v['skill'],
          'lecture': v['lecture'],
          'solution': v['solution'],
          'image': image,
          'indexed_choices': indexed_choices,
          'indexed_answer': indexed_answer,
      }
|
Tipsomaly/model/big_vision/datasets/screen2words/screen2words.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# pylint: disable=line-too-long
|
| 16 |
+
r"""Creates TFDS dataset for Screen2words.
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
Preparing the data:
|
| 20 |
+
1) mkdir /tmp/data/rico && cd /tmp/data/rico
|
| 21 |
+
2) wget https://storage.googleapis.com/crowdstf-rico-uiuc-4540/rico_dataset_v0.1/unique_uis.tar.gz
|
| 22 |
+
3) tar xvfz unique_uis.tar.gz && rm unique_uis.tar.gz
|
| 23 |
+
4) git clone https://github.com/google-research-datasets/screen2words.git
|
| 24 |
+
|
| 25 |
+
Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util):
|
| 26 |
+
|
| 27 |
+
cd big_vision/datasets
|
| 28 |
+
env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=screen2words
|
| 29 |
+
|
| 30 |
+
Example to load:
|
| 31 |
+
|
| 32 |
+
import tensorflow_datasets as tfds
|
| 33 |
+
dataset = tfds.load('screen2_words', split='train', data_dir='/tmp/tfds')
|
| 34 |
+
"""
|
| 35 |
+
# pylint: enable=line-too-long
|
| 36 |
+
import collections
|
| 37 |
+
import csv
|
| 38 |
+
import os
|
| 39 |
+
|
| 40 |
+
import numpy as np
|
| 41 |
+
import tensorflow_datasets as tfds
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
_DESCRIPTION = """Screen2words dataset."""
|
| 45 |
+
_CITATION = """
|
| 46 |
+
@inproceedings{wang2021screen2words,
|
| 47 |
+
title={Screen2words: Automatic mobile UI summarization with multimodal
|
| 48 |
+
learning},
|
| 49 |
+
author={Wang, Bryan and
|
| 50 |
+
Li, Gang and
|
| 51 |
+
Zhou, Xin and
|
| 52 |
+
Chen, Zhourong and
|
| 53 |
+
Grossman, Tovi and
|
| 54 |
+
Li, Yang},
|
| 55 |
+
booktitle={The 34th Annual ACM Symposium on User Interface Software
|
| 56 |
+
and Technology},
|
| 57 |
+
pages={498--510},
|
| 58 |
+
year={2021}
|
| 59 |
+
}
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
# When running locally (recommended), copy files as above and use these:
|
| 63 |
+
_SCREEN2WORDS_DIR = "/tmp/data/rico/screen2words"
|
| 64 |
+
_RICO_DIR = "/tmp/data/rico/combined"
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# (name, path) tuples for splits to be generated.
|
| 68 |
+
_SPLITS_TO_GENERATE = ["train", "dev", "test"]
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class Screen2Words(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for the Screen2words dataset."""

  VERSION = tfds.core.Version("1.0.0")
  RELEASE_NOTES = {"1.0.0": "First release."}

  def _info(self):
    """Returns the dataset metadata."""
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            "image/id": tfds.features.Scalar(np.int32),
            "image/filename": tfds.features.Text(),
            "image": tfds.features.Image(encoding_format="jpeg"),
            "summary": tfds.features.Sequence(tfds.features.Text()),
        }),
        supervised_keys=None,
        homepage="https://github.com/google-research-datasets/screen2words",
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators."""
    return {name: self._generate_examples(name)
            for name in _SPLITS_TO_GENERATE}

  def _generate_examples(self, split: str):
    """Yields (key, example) tuples for one split."""
    # Screen ids belonging to this split, one id per line.
    ids_path = os.path.join(_SCREEN2WORDS_DIR, "split", f"{split}_screens.txt")
    with open(ids_path, "r") as ids_file:
      raw_ids = ids_file.readlines()

    # Collect all human-written summaries, grouped by screen id.
    csv_path = os.path.join(_SCREEN2WORDS_DIR, "screen_summaries.csv")
    per_screen_summaries = collections.defaultdict(list)
    with open(csv_path, "r") as csv_file:
      for row in csv.DictReader(csv_file):
        per_screen_summaries[int(row["screenId"])].append(row["summary"])

    for raw_id in raw_ids:
      screen_id = int(raw_id.strip())
      filename = f"{screen_id}.jpg"
      yield screen_id, {
          "image/id": screen_id,
          "image/filename": filename,
          "image": os.path.join(_RICO_DIR, filename),
          "summary": per_screen_summaries[screen_id],
      }
|
Tipsomaly/model/big_vision/datasets/stvqa/stvqa.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# pylint: disable=line-too-long
|
| 16 |
+
r"""Implements ST-VQA dataset in TFDS.
|
| 17 |
+
|
| 18 |
+
It's small data, so simple to run locally.
|
| 19 |
+
First, download and unzip the dataset from https://rrc.cvc.uab.es/?ch=11
|
| 20 |
+
and place it in /tmp/data/stvqa.
|
| 21 |
+
|
| 22 |
+
Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util):
|
| 23 |
+
|
| 24 |
+
cd third_party/py/big_vision/datasets
|
| 25 |
+
env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=stvqa
|
| 26 |
+
|
| 27 |
+
Example to load:
|
| 28 |
+
|
| 29 |
+
import tensorflow_datasets as tfds
|
| 30 |
+
dataset = tfds.load('stvqa', split='train', data_dir='/tmp/tfds')
|
| 31 |
+
|
| 32 |
+
Dataset splits:
|
| 33 |
+
train: 23446 examples/questions (subset of original train)
|
| 34 |
+
val: 2628 examples/questions (subset of original train)
|
| 35 |
+
test: 4070 examples/questions (no answers)
|
| 36 |
+
|
| 37 |
+
Note: original source data has no val/holdout split, and we therefore split the
|
| 38 |
+
original train split (26074 examples/questions) by ourselves into train & val
|
| 39 |
+
splits.
|
| 40 |
+
|
| 41 |
+
Recommended training splits:
|
| 42 |
+
train: train
|
| 43 |
+
minitrain: train[:5%]
|
| 44 |
+
eval: val
|
| 45 |
+
fulltrain: train+val
|
| 46 |
+
"""
|
| 47 |
+
import json
|
| 48 |
+
import os
|
| 49 |
+
|
| 50 |
+
from big_vision.datasets.stvqa import val_ids
|
| 51 |
+
import numpy as np
|
| 52 |
+
import tensorflow_datasets as tfds
|
| 53 |
+
|
| 54 |
+
_VAL_IDS = val_ids.PSEUDO_VAL_IMAGE_PATHS
|
| 55 |
+
|
| 56 |
+
_DESCRIPTION = """ST-VQA dataset."""
|
| 57 |
+
|
| 58 |
+
# pylint: disable=line-too-long
|
| 59 |
+
_CITATION = """
|
| 60 |
+
@inproceedings{Biten_2019,
|
| 61 |
+
title={Scene Text Visual Question Answering},
|
| 62 |
+
url={http://dx.doi.org/10.1109/ICCV.2019.00439},
|
| 63 |
+
DOI={10.1109/iccv.2019.00439},
|
| 64 |
+
booktitle={2019 IEEE/CVF International Conference on Computer Vision (ICCV)},
|
| 65 |
+
publisher={IEEE},
|
| 66 |
+
author={Biten, Ali Furkan and Tito, Ruben and Mafla, Andres and Gomez, Lluis and Rusinol, Marcal and Jawahar, C.V. and Valveny, Ernest and Karatzas, Dimosthenis},
|
| 67 |
+
year={2019},
|
| 68 |
+
month=oct }
|
| 69 |
+
"""
|
| 70 |
+
# pylint: enable=line-too-long
|
| 71 |
+
|
| 72 |
+
# When running locally (recommended), copy files as above and use these:
|
| 73 |
+
_STVQA_PATH = '/tmp/data/stvqa/'
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class Stvqa(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for ST-VQA dataset."""

  VERSION = tfds.core.Version('1.2.0')
  RELEASE_NOTES = {
      '1.0.0': 'First release.',
      '1.1.0': 'Switch to COCO high-res images and lower-case answers.',
      '1.2.0': 'Rename pseudo splits and remove lower-case answers.',
  }

  def _info(self):
    """Returns the dataset metadata."""
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            'question_id': tfds.features.Scalar(np.int32),
            'filename': tfds.features.Text(),
            'image': tfds.features.Image(encoding_format='jpeg'),
            'question': tfds.features.Text(),
            'answers': tfds.features.Sequence(tfds.features.Text()),
        }),
        supervised_keys=None,
        homepage='https://rrc.cvc.uab.es/?ch=11',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators."""
    return {name: self._generate_examples(name)
            for name in ('train', 'val', 'test')}

  def _generate_examples(self, split):
    """Yields (key, example) tuples."""
    # Test examples have their own annotation file and image directory;
    # the 'train' and 'val' pseudo splits are both carved out of the
    # original train split via the _VAL_IDS image-path list.
    source = 'test' if split == 'test' else 'train'
    image_dir = f'{source}{"_task3" if source == "test" else ""}_images'

    with open(os.path.join(_STVQA_PATH, f'{source}_task_3.json'), 'r') as f:
      annotations = json.load(f)

    for entry in annotations['data']:
      file_path = entry['file_path']
      in_val = file_path in _VAL_IDS
      # Keep only entries that belong to the requested pseudo split.
      if (split == 'val' and not in_val) or (split == 'train' and in_val):
        continue
      image_path = os.path.join(_STVQA_PATH, image_dir, file_path)
      # Always use high-res COCO images from train2014 directory.
      if file_path.startswith('coco-text'):
        image_path = image_path.replace(
            os.path.join(image_dir, 'coco-text'), 'train2014')
      yield entry['question_id'], {
          'question_id': entry['question_id'],
          'filename': file_path,
          'image': image_path,
          'question': entry['question'],
          'answers': entry.get('answers', []),
      }
|
Tipsomaly/model/big_vision/datasets/tallyqa/tallyqa.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Import TallyQA into TFDS format. Uses Visual Genome and COCO images.
|
| 16 |
+
|
| 17 |
+
It's small data, so simple to run locally. First, download all the data:
|
| 18 |
+
|
| 19 |
+
mkdir /tmp/data/ ; cd /tmp/data
|
| 20 |
+
wget http://images.cocodataset.org/zips/{train2014,val2014}.zip
|
| 21 |
+
wget https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip
|
| 22 |
+
wget https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip
|
| 23 |
+
wget https://github.com/manoja328/tallyqa/blob/master/tallyqa.zip?raw=true
|
| 24 |
+
unzip *.zip
|
| 25 |
+
|
| 26 |
+
Then, update the PATHs below and run conversion locally like so (make sure to
|
| 27 |
+
install tensorflow-datasets for the `tfds` util):
|
| 28 |
+
|
| 29 |
+
cd big_vision/datasets
|
| 30 |
+
env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=tallyqa
|
| 31 |
+
|
| 32 |
+
Example to load:
|
| 33 |
+
import tensorflow_datasets as tfds
|
| 34 |
+
dataset = tfds.load('tallyqa', split='train', data_dir='/tmp/tfds')
|
| 35 |
+
|
| 36 |
+
The test split distinguishes between simple and complex questions. The train
|
| 37 |
+
split does not contain this information. We therefore set issimple to `-1` in
|
| 38 |
+
the train split to indicate it is not known.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
import json
|
| 42 |
+
|
| 43 |
+
import numpy as np
|
| 44 |
+
import tensorflow_datasets as tfds
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
_TALLYQA_PATH = '/tmp/data/tallyQA/'
|
| 48 |
+
_VISUAL_GENOME_PATH = '/tmp/data/visual_genome/'
|
| 49 |
+
|
| 50 |
+
_COCO_PATH = '/tmp/data/coco/'
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
_DESCRIPTION = """
|
| 54 |
+
TallyQA: Answering Complex Counting Questions
|
| 55 |
+
Most counting questions in visual question answering (VQA) datasets are simple
|
| 56 |
+
and require no more than object detection. Here, we study algorithms for complex
|
| 57 |
+
counting questions that involve relationships between objects, attribute
|
| 58 |
+
identification, reasoning, and more. To do this, we created TallyQA, the world's
|
| 59 |
+
largest dataset for open-ended counting.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
_CITATION = """
|
| 63 |
+
@inproceedings{acharya2019tallyqa,
|
| 64 |
+
title={TallyQA: Answering Complex Counting Questions},
|
| 65 |
+
author={Acharya, Manoj and Kafle, Kushal and Kanan, Christopher},
|
| 66 |
+
booktitle={AAAI},
|
| 67 |
+
year={2019}
|
| 68 |
+
}
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
_HOMEPAGE = 'https://github.com/manoja328/TallyQA_dataset'
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class TallyQA(tfds.core.GeneratorBasedBuilder):
  """Import TallyQA dataset."""

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {'1.0.0': 'Initial release.'}
  MANUAL_DOWNLOAD_INSTRUCTIONS = """
  There are three parts which should be downloaded:
  * TallyQA (train / test json files)
  * Visual Genome images (needed for train and test split)
  * COCO (2014) train / val images (only needed for train split)
  """

  def _info(self) -> tfds.core.DatasetInfo:
    """Returns the dataset metadata."""
    return tfds.core.DatasetInfo(
        builder=self,
        features=tfds.features.FeaturesDict({
            'image': tfds.features.Image(shape=(None, None, 3)),
            'image_id': tfds.features.Scalar(dtype=np.int32),
            'image_source': tfds.features.Text(),
            'question': tfds.features.Text(),
            'question_id': tfds.features.Scalar(dtype=np.int32),
            'answer': tfds.features.Scalar(dtype=np.int32),
            'issimple': tfds.features.Scalar(dtype=np.int32),
        }),
        description=_DESCRIPTION,
        supervised_keys=None,
        homepage=_HOMEPAGE,
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager) -> ...:
    """Call the function which defines the splits."""
    del dl_manager
    return {name: self._generate_examples(split=name)
            for name in ('train', 'test')}

  def _generate_examples(self, split: str) -> ...:
    """Yields (key, example) tuples for one split."""
    with open(f'{_TALLYQA_PATH}/{split}.json', 'r') as f:
      entries = json.load(f)

    for entry in entries:
      # The TallyQA images come from two sources: Visual Genome and COCO.
      # Determine the correct dataset by inspecting the path prefix.
      rel_path = entry['image']
      if rel_path.startswith('VG_100K'):
        full_path = _VISUAL_GENOME_PATH + rel_path
      elif rel_path.startswith('train2014') or rel_path.startswith('val2014'):
        full_path = _COCO_PATH + rel_path
      else:
        raise ValueError(f'Unknown image path: {rel_path}')

      example = {
          'image': full_path,
          'image_id': entry['image_id'],
          'image_source': entry['data_source'],
          'question': entry['question'],
          'question_id': entry['question_id'],
          'answer': int(entry['answer']),
          # 'issimple' is only present in the test split; -1 in the train
          # split indicates it is not known.
          'issimple': entry['issimple'] if split == 'test' else -1,
      }
      key = f'{example["image_id"]} / {example["question_id"]}'
      yield key, example
|
Tipsomaly/model/big_vision/datasets/textcaps/textcaps.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# pylint: disable=line-too-long
|
| 16 |
+
r"""Implements textcaps val-set in TFDS structure.
|
| 17 |
+
|
| 18 |
+
It's small data, so simple to run locally. First, copy the data to local disk:
|
| 19 |
+
|
| 20 |
+
mkdir -p /tmp/data/textcaps
|
| 21 |
+
cd /tmp/data/textcaps
|
| 22 |
+
curl -O https://dl.fbaipublicfiles.com/textvqa/data/textcaps/TextCaps_0.1_train.json
|
| 23 |
+
curl -O https://dl.fbaipublicfiles.com/textvqa/data/textcaps/TextCaps_0.1_val.json
|
| 24 |
+
curl -O https://dl.fbaipublicfiles.com/textvqa/data/textcaps/TextCaps_0.1_test.json
|
| 25 |
+
curl -O https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip
|
| 26 |
+
curl -O https://dl.fbaipublicfiles.com/textvqa/images/test_images.zip
|
| 27 |
+
unzip train_val_images.zip
|
| 28 |
+
rm train_val_images.zip
|
| 29 |
+
unzip test_images.zip
|
| 30 |
+
rm test_images.zip
|
| 31 |
+
|
| 32 |
+
Then, run conversion locally (make sure to install tensorflow-datasets for the
|
| 33 |
+
`tfds` util):
|
| 34 |
+
|
| 35 |
+
cd big_vision/datasets
|
| 36 |
+
env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=textcaps
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
Example to load:
|
| 40 |
+
|
| 41 |
+
import tensorflow_datasets as tfds
|
| 42 |
+
dataset = tfds.load('text_caps', split='val', data_dir='/tmp/tfds')
|
| 43 |
+
"""
|
| 44 |
+
import collections
|
| 45 |
+
import json
|
| 46 |
+
import os
|
| 47 |
+
|
| 48 |
+
from absl import logging
|
| 49 |
+
import numpy as np
|
| 50 |
+
import tensorflow_datasets as tfds
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
_DESCRIPTION = """TextCaps dataset."""
|
| 54 |
+
|
| 55 |
+
# pylint: disable=line-too-long
|
| 56 |
+
_CITATION = (
|
| 57 |
+
'@inproceedings{sidorov2019textcaps,'
|
| 58 |
+
'title={TextCaps: a Dataset for Image Captioningwith Reading Comprehension},'
|
| 59 |
+
'author={Sidorov, Oleksii and Hu, Ronghang and Rohrbach, Marcus and Singh, Amanpreet},'
|
| 60 |
+
'journal={European Conference on Computer Vision},'
|
| 61 |
+
'year={2020}}')
|
| 62 |
+
# pylint: enable=line-too-long
|
| 63 |
+
|
| 64 |
+
# When running locally (recommended), copy files as above and use these:
|
| 65 |
+
_FILEPATH = '/tmp/data/textcaps/'
|
| 66 |
+
_TRAIN_FILES = '/tmp/data/textcaps/TextCaps_0.1_train.json'
|
| 67 |
+
_VAL_FILES = '/tmp/data/textcaps/TextCaps_0.1_val.json'
|
| 68 |
+
_TEST_FILES = '/tmp/data/textcaps/TextCaps_0.1_test.json'
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class TextCaps(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for TextCaps dataset."""

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {
      '1.0.0': 'Initial release.',
  }

  def _info(self) -> tfds.core.DatasetInfo:
    """Returns the dataset metadata.

    (tfds.core.DatasetInfo object)
    These are the features of your dataset like images, labels, etc.
    """
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            'image/id': tfds.features.Text(),
            'image_filepath': tfds.features.Text(),
            'image': tfds.features.Image(encoding_format='jpeg'),
            'texts': tfds.features.Sequence(tfds.features.Text()),
        }),
        supervised_keys=None,  # Set to `None` to disable
        homepage='https://textvqa.org/textcaps/',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators."""

    def group_by_id(data, image_dir):
      # Bucket the raw caption entries per image id, then collapse each
      # bucket into a single example with all its captions.
      buckets = collections.defaultdict(list)
      for entry in data:
        buckets[entry['image_id']].append(entry)

      grouped = {}
      for image_id, entries in buckets.items():
        # Every entry in a bucket must refer to the same image.
        assert len({e['image_id'] for e in entries}) == 1
        assert len({e['image_name'] for e in entries}) == 1
        # The test split has no 'caption_str' fields; keep only non-empty.
        captions = [e.get('caption_str') for e in entries
                    if e.get('caption_str')]
        image_filepath = os.path.join(
            _FILEPATH, image_dir, str(entries[0]['image_name']) + '.jpg')
        grouped[image_id] = {
            'image/id': entries[0]['image_id'],
            'image_filepath': image_filepath,
            'image': image_filepath,
            'texts': captions,
        }
      return grouped

    # Returns the Dict[split names, Iterator[Key, Example]].
    # Note: val images also live in the train_images directory.
    with open(_TRAIN_FILES) as f:
      train_data = group_by_id(json.load(f)['data'], 'train_images')
    with open(_VAL_FILES) as f:
      val_data = group_by_id(json.load(f)['data'], 'train_images')
    with open(_TEST_FILES) as f:
      test_data = group_by_id(json.load(f)['data'], 'test_images')
    return {
        'train': self._generate_examples(train_data),
        'val': self._generate_examples(val_data),
        'test': self._generate_examples(test_data),
    }

  def _generate_examples(self, data):
    """Generate a tf.Example object.

    This contains the image, objects, attributes, regions and relationships.

    Args:
      data: a dictionary with the image/id.

    Yields:
      (key, example) tuples from dataset. The example has format specified in
      the above DatasetInfo.
    """
    yield from data.items()
|
Tipsomaly/model/big_vision/datasets/textvqa/textvqa.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# pylint: disable=line-too-long
|
| 16 |
+
r"""Implements textvqa in TFDS structure.
|
| 17 |
+
|
| 18 |
+
It's small data, so simple to run locally. First, copy the data to local disk:
|
| 19 |
+
|
| 20 |
+
mkdir -p /tmp/data/textvqa
|
| 21 |
+
cd /tmp/data/textvqa
|
| 22 |
+
curl -O https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip
|
| 23 |
+
curl -O https://dl.fbaipublicfiles.com/textvqa/images/test_images.zip
|
| 24 |
+
curl -O https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_train.json
|
| 25 |
+
curl -O https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json
|
| 26 |
+
curl -O https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_test.json
|
| 27 |
+
# The Rosetta_OCR files are probably not needed.
|
| 28 |
+
# curl -O https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_Rosetta_OCR_v0.2_train.json
|
| 29 |
+
# curl -O https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_Rosetta_OCR_v0.2_val.json
|
| 30 |
+
# curl -O https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_Rosetta_OCR_v0.2_test.json
|
| 31 |
+
unzip train_val_images.zip
|
| 32 |
+
rm train_val_images.zip
|
| 33 |
+
unzip test_images.zip
|
| 34 |
+
rm test_images.zip
|
| 35 |
+
# Background: at https://textvqa.org/dataset/ it says:
|
| 36 |
+
# "Note: Some of the images in OpenImages are rotated,
|
| 37 |
+
# please make sure to check the Rotation field in the Image IDs files
|
| 38 |
+
# for train and test."
|
| 39 |
+
curl -O https://storage.googleapis.com/openimages/2018_04/train/train-images-boxable-with-rotation.csv
|
| 40 |
+
curl -O https://storage.googleapis.com/openimages/2018_04/test/test-images-with-rotation.csv
|
| 41 |
+
mv train-images-boxable-with-rotation.csv train_images/rotation.csv
|
| 42 |
+
mv test-images-with-rotation.csv test_images/rotation.csv
|
| 43 |
+
|
| 44 |
+
Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util):
|
| 45 |
+
|
| 46 |
+
cd big_vision/datasets
|
| 47 |
+
env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=textvqa
|
| 48 |
+
|
| 49 |
+
Example to load:
|
| 50 |
+
|
| 51 |
+
import tensorflow_datasets as tfds
|
| 52 |
+
dataset = tfds.load('textvqa', split='train', data_dir='/tmp/tfds')
|
| 53 |
+
"""
|
| 54 |
+
import json
|
| 55 |
+
import os
|
| 56 |
+
|
| 57 |
+
from absl import logging
|
| 58 |
+
import numpy as np
|
| 59 |
+
import pandas as pd
|
| 60 |
+
import tensorflow as tf
|
| 61 |
+
import tensorflow_datasets as tfds
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
_DESCRIPTION = """TextVqa dataset."""
|
| 65 |
+
|
| 66 |
+
# pylint: disable=line-too-long
|
| 67 |
+
_CITATION = (
|
| 68 |
+
'@inproceedings{singh2019towards,'
|
| 69 |
+
'title={Towards VQA Models That Can Read},'
|
| 70 |
+
'author={Singh, Amanpreet and Natarjan, Vivek and Shah, Meet and Jiang, Yu and Chen, Xinlei and Parikh, Devi and Rohrbach, Marcus},'
|
| 71 |
+
'booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},'
|
| 72 |
+
'pages={8317-8326},'
|
| 73 |
+
'year={2019}}'
|
| 74 |
+
)
|
| 75 |
+
# pylint: enable=line-too-long
|
| 76 |
+
|
| 77 |
+
# When running locally (recommended), copy files as above and use these:
|
| 78 |
+
_FILEPATH = '/tmp/data/textvqa/'
|
| 79 |
+
_TRAIN_FILES = '/tmp/data/textvqa/TextVQA_0.5.1_train.json'
|
| 80 |
+
_VAL_FILES = '/tmp/data/textvqa/TextVQA_0.5.1_val.json'
|
| 81 |
+
_TEST_FILES = '/tmp/data/textvqa/TextVQA_0.5.1_test.json'
|
| 82 |
+
_ROTATION_CSV = 'rotation.csv'
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class TextVqa(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for textvqa dataset."""

  VERSION = tfds.core.Version('1.0.1')
  RELEASE_NOTES = {
      '1.0.0': 'Initial release.',
      '1.0.1': 'Undo rotation for known rotated images.',
  }

  def _info(self) -> tfds.core.DatasetInfo:
    """Returns the dataset metadata.

    (tfds.core.DatasetInfo object)
    These are the features of your dataset like images, labels, etc.
    """
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            'image/id': tfds.features.Scalar(np.int32),
            'image_filepath': tfds.features.Text(),
            'image': tfds.features.Image(encoding_format='jpeg'),
            'question_id': tfds.features.Scalar(np.int32),
            'question': tfds.features.Text(),
            'answers': tfds.features.Sequence(tfds.features.Text()),
        }),
        supervised_keys=None,  # Set to `None` to disable.
        homepage='https://textvqa.org/',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators."""

    def json_to_examples(data, image_dir):
      """Converts raw question records to example dicts keyed by question id."""
      logging.info('Processing %d items in %s', len(data), image_dir)
      # OpenImages rotation metadata: some source images are stored rotated
      # and must be un-rotated before use (see module docstring).
      rot = pd.read_csv(os.path.join(_FILEPATH, image_dir, _ROTATION_CSV))
      rotation_by_id = {}
      for row in rot.itertuples():
        rotation = int(row.Rotation) if not np.isnan(row.Rotation) else 0
        rotation_by_id[row.ImageID] = rotation

      examples = {}
      for v in data:
        image_id = str(v['image_id'])
        image_filepath = os.path.join(_FILEPATH, image_dir, image_id + '.jpg')
        question_id = v['question_id']
        examples[question_id] = {
            'image/id': question_id,
            'image_filepath': image_filepath,
            'image': image_filepath,
            # 'rotation' is consumed (and removed) in _generate_examples.
            'rotation': rotation_by_id[image_id],
            'question_id': question_id,
            'question': v['question'],
            'answers': v.get('answers', []),  # No answers in test set.
        }
      return examples

    # Returns the Dict[split names, Iterator[Key, Example]].
    with open(_TRAIN_FILES) as f:
      train_data = json_to_examples(json.load(f)['data'], 'train_images')
    with open(_VAL_FILES) as f:
      # Validation images are stored in the train_images folder.
      val_data = json_to_examples(json.load(f)['data'], 'train_images')
    with open(_TEST_FILES) as f:
      test_data = json_to_examples(json.load(f)['data'], 'test_images')
    return {
        'train': self._generate_examples(train_data),
        'val': self._generate_examples(val_data),
        'test': self._generate_examples(test_data),
    }

  def _generate_examples(self, data):
    """Generate a tf.Example object.

    Args:
      data: dict mapping question id to an example dict (with a transient
        'rotation' entry added by _split_generators).

    Yields:
      (key, example) tuples from dataset. The example has format specified in
      the above DatasetInfo.
    """
    for k, v in data.items():
      # Fix: read via a context manager so the file handle is closed promptly
      # (the original `open(...).read()` leaked one fd per example until GC).
      with open(v['image_filepath'], 'rb') as f:
        image_bytes = f.read()
      # If the image is rotated, we undo the rotation here and re-encode.
      if v['rotation'] != 0:
        rotation = v['rotation']
        assert rotation % 90 == 0
        turns = rotation // 90
        image = tf.image.decode_jpeg(image_bytes)
        image_bytes = tf.io.encode_jpeg(
            tf.image.rot90(image, turns), quality=100
        ).numpy()
      # If no rotation was needed, we just pass along the unchanged bytes.
      v['image'] = image_bytes

      # Now all rotation should have been accounted for. And we don't want to
      # pass on the (now obsolete) rotation info as features.
      del v['rotation']

      yield k, v
|
Tipsomaly/model/big_vision/datasets/vizwizvqa/vizwizvqa.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# pylint: disable=line-too-long
|
| 16 |
+
r"""Implements VizWizVQA dataset in TFDS structure.
|
| 17 |
+
|
| 18 |
+
It's small data, so simple to run locally. First, copy the data to local disk:
|
| 19 |
+
|
| 20 |
+
mkdir -p /tmp/data/vizwizvqa
|
| 21 |
+
|
| 22 |
+
  wget -P /tmp/data/vizwizvqa https://vizwiz.cs.colorado.edu/VizWiz_final/images/train.zip
|
| 23 |
+
  wget -P /tmp/data/vizwizvqa https://vizwiz.cs.colorado.edu/VizWiz_final/images/val.zip
|
| 24 |
+
  wget -P /tmp/data/vizwizvqa https://vizwiz.cs.colorado.edu/VizWiz_final/images/test.zip
|
| 25 |
+
|
| 26 |
+
Then, run conversion locally
|
| 27 |
+
(make sure to install tensorflow-datasets for the `tfds` util):
|
| 28 |
+
|
| 29 |
+
cd big_vision/datasets
|
| 30 |
+
env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=vizwizvqa
|
| 31 |
+
|
| 32 |
+
Example to load:
|
| 33 |
+
|
| 34 |
+
import tensorflow_datasets as tfds
|
| 35 |
+
dataset = tfds.load('vizwizvqa', split='train', data_dir='/tmp/tfds')
|
| 36 |
+
"""
|
| 37 |
+
import json
|
| 38 |
+
import os
|
| 39 |
+
|
| 40 |
+
import numpy as np
|
| 41 |
+
import tensorflow_datasets as tfds
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
_DESCRIPTION = """VizWiz VQA Dataset."""
|
| 45 |
+
|
| 46 |
+
# pylint: disable=line-too-long
|
| 47 |
+
_CITATION = """
|
| 48 |
+
@inproceedings{gurari2018vizwiz,
|
| 49 |
+
title={Vizwiz grand challenge: Answering visual questions from blind people},
|
| 50 |
+
author={Gurari, Danna and Li, Qing and Stangl, Abigale J and Guo, Anhong and Lin, Chi and Grauman, Kristen and Luo, Jiebo and Bigham, Jeffrey P},
|
| 51 |
+
booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
|
| 52 |
+
pages={3608--3617},
|
| 53 |
+
year={2018}
|
| 54 |
+
}
|
| 55 |
+
}
|
| 56 |
+
"""
|
| 57 |
+
# pylint: enable=line-too-long
|
| 58 |
+
|
| 59 |
+
# When running locally (recommended), copy files as above and use these:
|
| 60 |
+
_VIZWIZVQA_PATH = '/tmp/data/vizwizvqa/'
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class VizWizVQA(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for VizWizVQA dataset."""

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {'1.0.0': 'First release.'}

  def _info(self):
    """Returns the metadata."""
    features = tfds.features.FeaturesDict({
        'question': tfds.features.Text(),
        'image/filename': tfds.features.Text(),
        'image': tfds.features.Image(encoding_format='jpeg'),
        'answers': tfds.features.Sequence(tfds.features.Text()),
        # Each confidence is one of the strings "yes", "no" or "maybe".
        'answer_confidences': tfds.features.Sequence(tfds.features.Text()),
        'answerable': tfds.features.Scalar(np.int32),
        'question_id': tfds.features.Scalar(np.int32),
    })
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=features,
        supervised_keys=None,
        homepage='https://vizwiz.org/tasks-and-datasets/vqa/',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators."""
    return {name: self._generate_examples(name)
            for name in ('val', 'train', 'test')}

  def _generate_examples(self, split: str):
    """Yields (key, example) tuples for the given split."""
    annot_fname = os.path.join(_VIZWIZVQA_PATH, 'annotations', f'{split}.json')

    with open(annot_fname, 'r') as f:
      data = json.load(f)

    for entry in data:

      answers = []
      confidences = []

      image_file = entry['image']
      answerable = -1  # Sentinel: the test split has no answer annotations.
      if split != 'test':
        for answer in entry['answers']:
          # A couple of answers in the train set are empty strings.
          if not answer['answer']:
            continue
          answers.append(answer['answer'])
          confidences.append(answer['answer_confidence'])
        answerable = entry['answerable']

      # The question id is the numeric suffix of the image filename
      # (e.g. "VizWiz_train_00001234.jpg" -> 1234).
      question_id = int(image_file[:-4].split('_')[-1])

      yield entry['image'], {
          'question': entry['question'],
          'image/filename': image_file,
          'question_id': question_id,
          'image': os.path.join(_VIZWIZVQA_PATH, split, image_file),
          'answers': answers,
          'answer_confidences': confidences,
          'answerable': answerable,
      }
|
Tipsomaly/model/big_vision/datasets/vqa/vqa.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# pylint: disable=line-too-long
|
| 16 |
+
r"""Import VQAv2 into TFDS format. Uses coco-2014 images.
|
| 17 |
+
|
| 18 |
+
It's small data, so simple to run locally. First, download all the data:
|
| 19 |
+
|
| 20 |
+
mkdir /tmp/data/ ; cd /tmp/data
|
| 21 |
+
wget http://images.cocodataset.org/zips/{train2014,val2014,test2015}.zip
|
| 22 |
+
wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_{Train,Val,Test}_mscoco.zip
|
| 23 |
+
wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_{Train,Val}_mscoco.zip
|
| 24 |
+
unzip '*.zip'
|
| 25 |
+
|
| 26 |
+
Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util):
|
| 27 |
+
|
| 28 |
+
cd big_vision/datasets
|
| 29 |
+
env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=vqa
|
| 30 |
+
|
| 31 |
+
It runs at around 750 examples/sec, so takes around 25min for the 1.2M questions.
|
| 32 |
+
Each question is an example; images are repeated, a bit wasteful, but disk is cheap.
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
Example to load:
|
| 36 |
+
|
| 37 |
+
import tensorflow_datasets as tfds
|
| 38 |
+
dataset = tfds.load('vqa', split='train', data_dir='/tmp/tfds')
|
| 39 |
+
"""
|
| 40 |
+
import json
|
| 41 |
+
import os
|
| 42 |
+
|
| 43 |
+
import numpy as np
|
| 44 |
+
import tensorflow_datasets as tfds
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
_VQAV2_PATH = '/tmp/data'
|
| 48 |
+
_IMAGE_PATH = '/tmp/data'
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
_CITATION = (
|
| 52 |
+
'@InProceedings{balanced_vqa_v2,'
|
| 53 |
+
'author = {Yash Goyal and Tejas Khot and '
|
| 54 |
+
'Douglas Summers{-}Stay and Dhruv Batra and Devi Parikh},'
|
| 55 |
+
'title = {Making the {V} in {VQA} Matter: Elevating the Role of Image'
|
| 56 |
+
'Understanding in {V}isual {Q}uestion {A}nswering},'
|
| 57 |
+
'booktitle = {Computer Vision and Pattern Recognition (CVPR)},'
|
| 58 |
+
'year = {2017},}')
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class Vqa(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for VQAv2 dataset."""

  VERSION = tfds.core.Version('3.0.0')
  RELEASE_NOTES = {'3.0.0': 'Format as needed for PaliGemma'}

  def _info(self) -> tfds.core.DatasetInfo:
    """Returns the metadata."""

    return tfds.core.DatasetInfo(
        builder=self,
        description='The VQAv2 dataset.',
        features=tfds.features.FeaturesDict({
            'image/id': np.int32,
            'image/filename': tfds.features.Text(),
            'image': tfds.features.Image(encoding_format='jpeg'),
            'question_id': np.int32,
            'question_type': tfds.features.Text(),
            'question_text': tfds.features.Text(),
            'answer_type': tfds.features.Text(),
            'answers': tfds.features.Sequence(tfds.features.Text()),
            'answer_confidences': tfds.features.Sequence(
                tfds.features.ClassLabel(names=['no', 'maybe', 'yes'])),
            'top_answer': tfds.features.Text(),
        }),
        homepage='https://visualqa.org/',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators."""
    return {
        'train': self._generate_examples('train2014'),
        'validation': self._generate_examples('val2014'),
        'test': self._generate_examples('test2015'),
        # test-dev questions reference images stored in the test2015 folder.
        'test-dev': self._generate_examples('test-dev2015', 'test2015'),
    }

  def _generate_examples(self, split, image_folder=None):
    """Yields (key, example) tuples for one split.

    Args:
      split: name used in the question/annotation json filenames,
        e.g. 'train2014' or 'test-dev2015'.
      image_folder: directory holding the images; defaults to `split`
        (test-dev reuses the test2015 images).

    Yields:
      (question_id, example) tuples matching the DatasetInfo features.
    """
    image_folder = image_folder or split

    # The questions file has fields image_id, question, question_id.
    with open(os.path.join(
        _VQAV2_PATH, f'v2_OpenEnded_mscoco_{split}_questions.json')) as f:
      examples = json.load(f)['questions']

    # The annotations file has fields: image_id, question_id, answers,
    # answer_type, question_type, multiple_choice_answer.
    # Test splits ship without annotations, so `annots` only exists for
    # train/val; the loop below only reads it on the non-test branch.
    if 'test' not in split:
      with open(os.path.join(
          _VQAV2_PATH, f'v2_mscoco_{split}_annotations.json')) as f:
        annots = {a['question_id']: a for a in json.load(f)['annotations']}

    for ex in examples:
      qid = ex['question_id']
      # Rebind `ex` to the output dict; the raw record is no longer needed.
      ex = {
          'image/id': ex['image_id'],
          'question_id': qid,
          'question_text': ex['question'],
      }
      if 'test' not in split:
        fname = f'COCO_{image_folder}_{ex["image/id"]:012d}.jpg'
        ex['image/filename'] = fname
        ex['image'] = os.path.join(_IMAGE_PATH, image_folder, fname)
        ann = annots[qid]
        ex['question_type'] = ann['question_type']
        ex['answer_type'] = ann['answer_type']
        ex['answers'] = [a['answer'] for a in ann['answers']]
        ex['answer_confidences'] = [a['answer_confidence']
                                    for a in ann['answers']]
        ex['top_answer'] = ann['multiple_choice_answer']
      else:
        # For test images, a few are from the wrong year...
        fname = f'COCO_{image_folder}_{ex["image/id"]:012d}.jpg'
        ex['image/filename'] = fname
        # ...so skip (and log) questions whose image file is absent.
        if os.path.isfile(path := os.path.join(_IMAGE_PATH, image_folder, fname)):
          ex['image'] = path
        else:
          print(ex['image/id'])
          continue
        # Test split carries no answer annotations; emit empty placeholders
        # so all splits share the same feature schema.
        ex['question_type'] = ''
        ex['answer_type'] = ''
        ex['answers'] = []
        ex['answer_confidences'] = []
        ex['top_answer'] = ''
      yield qid, ex
|
Tipsomaly/model/big_vision/datasets/widgetcap/widgetcap.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# pylint: disable=line-too-long
|
| 16 |
+
r"""Import widgetcap into TFDS format.
|
| 17 |
+
|
| 18 |
+
Widget Captioning all requires images from the RICO dataset:
|
| 19 |
+
mkdir -p /tmp/data/rico_images ; cd /tmp/data/rico_images
|
| 20 |
+
wget
|
| 21 |
+
https://storage.googleapis.com/crowdstf-rico-uiuc-4540/rico_dataset_v0.1/unique_uis.tar.gz
|
| 22 |
+
tar xvfz unique_uis.tar.gz
|
| 23 |
+
rm unique_uis.tar.gz
|
| 24 |
+
|
| 25 |
+
Widget Captioning:
|
| 26 |
+
mkdir - /tmp/data/widget_captioning ; cd /tmp/data/widget_captioning
|
| 27 |
+
git clone https://github.com/google-research-datasets/widget-caption.git
|
| 28 |
+
cp widget-caption/widget_captions.csv ./
|
| 29 |
+
cp widget-caption/split/*.txt ./
|
| 30 |
+
rm -rf widget-caption
|
| 31 |
+
|
| 32 |
+
Then, run conversion locally (make sure to install tensorflow-datasets for the
|
| 33 |
+
`tfds` util):
|
| 34 |
+
|
| 35 |
+
cd big_vision/datasets
|
| 36 |
+
env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=widgetcap
|
| 37 |
+
|
| 38 |
+
Example to load:
|
| 39 |
+
|
| 40 |
+
import tensorflow_datasets as tfds
|
| 41 |
+
dataset_augmented = tfds.load('widgetcap', split='train',
|
| 42 |
+
data_dir='/tmp/tfds')
|
| 43 |
+
"""
|
| 44 |
+
import csv
|
| 45 |
+
import json
|
| 46 |
+
import os
|
| 47 |
+
|
| 48 |
+
import numpy as np
|
| 49 |
+
from PIL import Image
|
| 50 |
+
import tensorflow_datasets as tfds
|
| 51 |
+
|
| 52 |
+
_DATASET_DIR = '/tmp/data/widget_captioning'
|
| 53 |
+
# Dataset property indicating the y-dim of the canvas
|
| 54 |
+
_RICO_CANVAS_Y = 2560
|
| 55 |
+
_IMAGE_DIR = '/tmp/data/rico_images/combined'
|
| 56 |
+
|
| 57 |
+
_CITATION = (
|
| 58 |
+
'@inproceedings{Li2020WidgetCG,title={Widget Captioning: Generating Natural'
|
| 59 |
+
' Language Description for MobileUser Interface Elements},author={Y. Li and'
|
| 60 |
+
' Gang Li and Luheng He and Jingjie Zheng and Hong Li and Zhiwei'
|
| 61 |
+
' Guan},booktitle={Conference on Empirical Methods in Natural Language'
|
| 62 |
+
' Processing},year={2020},}'
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class Widgetcap(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for widgetcap dataset."""

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {'1.0.0': 'Format as needed for PaliGemma'}

  def _info(self) -> tfds.core.DatasetInfo:
    """Returns the metadata."""

    return tfds.core.DatasetInfo(
        builder=self,
        description='The widgetcap dataset.',
        features=tfds.features.FeaturesDict({
            'image/id': tfds.features.Text(),
            'image/filename': tfds.features.Text(),
            'image': tfds.features.Image(encoding_format='jpeg'),
            'texts': tfds.features.Sequence(tfds.features.Text()),
            'bbox': tfds.features.BBoxFeature(),
            'screen_id': tfds.features.Text(),
            'node_id': tfds.features.Text(),
            'height': np.int32,
            'width': np.int32,
        }),
        homepage='https://github.com/google-research-datasets/widget-caption',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators."""
    return {
        'train': self._generate_examples('train'),
        'dev': self._generate_examples('dev'),
        'test': self._generate_examples('test'),
    }

  def _generate_examples(self, split):
    """Yields (key, example) tuples from the dataset.

    Args:
      split: one of 'train', 'dev', 'test'; selects the screen-id list file.
    """
    # Screen ids belonging to this split, one per line in <split>.txt.
    split_screen_ids = set()
    with open(os.path.join(_DATASET_DIR, split + '.txt')) as f:
      for line in f:
        split_screen_ids.add(line.strip())

    with open(os.path.join(_DATASET_DIR, 'widget_captions.csv')) as f:
      for row in csv.DictReader(f):
        if row['screenId'] in split_screen_ids:
          id_, example = self._get_example(
              row['screenId'], row['nodeId'], row['captions']
          )
          yield id_, example

  def _get_node_box(self, screen_id, node_id, height):
    """Returns (left, top, right, bottom) bounds of a view node, in pixels.

    Args:
      screen_id: RICO screen id; selects the view-hierarchy json file.
      node_id: dotted child-index path into the hierarchy, e.g. '0.1.3'.
      height: screenshot height in pixels, used to rescale the bounds.
    """
    index_list = [int(i) for i in node_id.split('.')[1:]]
    with open(os.path.join(_IMAGE_DIR, screen_id + '.json')) as f:
      view = json.load(f)
    curr_node = view['activity']['root']
    for index in index_list:
      curr_node = curr_node['children'][index]
    # Bounds are expressed on the RICO canvas; rescale to screenshot pixels.
    normalized_bounds = map(
        lambda x: x * height / _RICO_CANVAS_Y, curr_node['bounds']
    )
    return normalized_bounds

  def _get_example(self, screen_id, node_id, captions):
    """Builds one (key, example) pair for a captioned widget."""
    # Fix: open the screenshot with a context manager so PIL releases the
    # file descriptor (lazy loading otherwise keeps it open indefinitely).
    with Image.open(os.path.join(_IMAGE_DIR, screen_id + '.jpg')) as image:
      width, height = image.size
    # Get bounding box coordinates (in pixels).
    xmin, ymin, xmax, ymax = self._get_node_box(screen_id, node_id, height)

    image_id = f'{screen_id}_{node_id}'
    example = {
        'image/id': image_id,
        'image/filename': screen_id + '.jpg',
        'image': os.path.join(_IMAGE_DIR, screen_id + '.jpg'),
        # Multiple human captions are '|'-separated in the csv.
        'texts': captions.split('|'),
        'bbox': tfds.features.BBox(
            ymin=ymin / height,
            xmin=xmin / width,
            ymax=ymax / height,
            xmax=xmax / width,
        ),
        'screen_id': screen_id,
        'node_id': node_id,
        'height': height,
        'width': width,
    }
    return image_id, example
|
Tipsomaly/model/big_vision/datasets/xgqa/xgqa.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# pylint: disable=line-too-long
|
| 16 |
+
"""Generates xGQA in a TFDS-ready structure.
|
| 17 |
+
|
| 18 |
+
First, download the data:
|
| 19 |
+
mkdir -p /tmp/data/xgqa/annotations
|
| 20 |
+
wget https://raw.githubusercontent.com/e-bug/iglue/main/datasets/xGQA/annotations/zero_shot/testdev_balanced_questions_bn.json -P /tmp/data/xgqa/annotations
|
| 21 |
+
wget https://raw.githubusercontent.com/e-bug/iglue/main/datasets/xGQA/annotations/zero_shot/testdev_balanced_questions_de.json -P /tmp/data/xgqa/annotations
|
| 22 |
+
wget https://raw.githubusercontent.com/e-bug/iglue/main/datasets/xGQA/annotations/zero_shot/testdev_balanced_questions_en.json -P /tmp/data/xgqa/annotations
|
| 23 |
+
wget https://raw.githubusercontent.com/e-bug/iglue/main/datasets/xGQA/annotations/zero_shot/testdev_balanced_questions_id.json -P /tmp/data/xgqa/annotations
|
| 24 |
+
wget https://raw.githubusercontent.com/e-bug/iglue/main/datasets/xGQA/annotations/zero_shot/testdev_balanced_questions_ko.json -P /tmp/data/xgqa/annotations
|
| 25 |
+
wget https://raw.githubusercontent.com/e-bug/iglue/main/datasets/xGQA/annotations/zero_shot/testdev_balanced_questions_pt.json -P /tmp/data/xgqa/annotations
|
| 26 |
+
wget https://raw.githubusercontent.com/e-bug/iglue/main/datasets/xGQA/annotations/zero_shot/testdev_balanced_questions_ru.json -P /tmp/data/xgqa/annotations
|
| 27 |
+
wget https://raw.githubusercontent.com/e-bug/iglue/main/datasets/xGQA/annotations/zero_shot/testdev_balanced_questions_zh.json -P /tmp/data/xgqa/annotations
|
| 28 |
+
wget https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip -P /tmp/data/xgqa/
|
| 29 |
+
unzip /tmp/data/xgqa/images.zip -d /tmp/data/xgqa/
|
| 30 |
+
|
| 31 |
+
Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util):
|
| 32 |
+
|
| 33 |
+
cd big_vision/datasets
|
| 34 |
+
env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=xgqa
|
| 35 |
+
|
| 36 |
+
Example to load:
|
| 37 |
+
|
| 38 |
+
import tensorflow_datasets as tfds
|
| 39 |
+
dataset = tfds.load(
|
| 40 |
+
'xgqa', split='test_zs_en',
|
| 41 |
+
data_dir='/tmp/tfds')
|
| 42 |
+
"""
|
| 43 |
+
import json
|
| 44 |
+
import os
|
| 45 |
+
|
| 46 |
+
import tensorflow_datasets as tfds
|
| 47 |
+
|
| 48 |
+
_DESCRIPTION = """xGQA (uses GQA images)."""
|
| 49 |
+
|
| 50 |
+
# pylint: disable=line-too-long
|
| 51 |
+
_CITATION = (
|
| 52 |
+
'@inproceedings{pfeiffer-etal-2022-xgqa,'
|
| 53 |
+
'title = "x{GQA}: Cross-Lingual Visual Question Answering",'
|
| 54 |
+
'author = "Pfeiffer, Jonas and'
|
| 55 |
+
' Geigle, Gregor and'
|
| 56 |
+
' Kamath, Aishwarya and'
|
| 57 |
+
' Steitz, Jan-Martin and'
|
| 58 |
+
' Roth, Stefan and'
|
| 59 |
+
' Vuli{\'c}, Ivan and'
|
| 60 |
+
' Gurevych, Iryna",'
|
| 61 |
+
'booktitle = "Findings of the Association for Computational Linguistics: '
|
| 62 |
+
'ACL 2022",'
|
| 63 |
+
'month = may,'
|
| 64 |
+
'year = "2022",'
|
| 65 |
+
'address = "Dublin, Ireland",'
|
| 66 |
+
'publisher = "Association for Computational Linguistics",'
|
| 67 |
+
'url = "https://aclanthology.org/2022.findings-acl.196",'
|
| 68 |
+
'doi = "10.18653/v1/2022.findings-acl.196",'
|
| 69 |
+
'pages = "2497--2511",'
|
| 70 |
+
'}'
|
| 71 |
+
)
|
| 72 |
+
# pylint: enable=line-too-long
|
| 73 |
+
|
| 74 |
+
# When running locally (recommended), copy files as above an use these:
|
| 75 |
+
_DATA_PATH = '/tmp/data/xgqa/'
|
| 76 |
+
_IMAGE_PATH = '/tmp/data/xgqa/images/'
|
| 77 |
+
|
| 78 |
+
LANGUAGES = frozenset(['bn', 'de', 'en', 'id', 'ko', 'pt', 'ru', 'zh'])
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class XGQA(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for the xGQA dataset (GQA questions in 8 languages)."""

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {'1.0.0': 'First release.'}

  def _info(self):
    """Returns the dataset metadata (features, homepage, citation)."""

    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            'example_id': tfds.features.Text(),
            'image/id': tfds.features.Text(),
            'image': tfds.features.Image(encoding_format='jpeg'),
            'question': tfds.features.Text(),
            'answer': tfds.features.Text(),
        }),
        supervised_keys=None,
        homepage='https://github.com/adapter-hub/xGQA',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators: {test,dev,train_fs*} x 8 languages."""
    d = dict()
    for l in LANGUAGES:
      d.update({
          f'test_zs_{l}': self._generate_examples('test', 'zero_shot', l),
          f'test_fs_{l}': self._generate_examples('test', 'few_shot', l),
          # Bug fix: the dev split must read the 'dev' annotation file; it
          # previously read 'test', silently serving test data as dev.
          f'dev_fs_{l}': self._generate_examples('dev', 'few_shot', l),
          f'train_fs1_{l}': self._generate_examples('train_1', 'few_shot', l),
          f'train_fs5_{l}': self._generate_examples('train_5', 'few_shot', l),
          f'train_fs10_{l}': self._generate_examples('train_10', 'few_shot', l),
          f'train_fs20_{l}': self._generate_examples('train_20', 'few_shot', l),
          f'train_fs25_{l}': self._generate_examples('train_25', 'few_shot', l),
          f'train_fs48_{l}': self._generate_examples('train_48', 'few_shot', l),
      })
    return d

  def _generate_examples(self, split, num_shots, lang):
    """Yields (key, example) tuples.

    Args:
      split: Annotation file stem, e.g. 'test', 'dev' or 'train_1'. Ignored
        for 'zero_shot', which has a single fixed annotation file.
      num_shots: 'few_shot' or 'zero_shot' (selects the annotation layout).
      lang: Two-letter language code from `LANGUAGES`.
    """
    # Loads the questions for each image.
    if num_shots == 'few_shot':
      file_path = os.path.join(_DATA_PATH, 'annotations', 'few_shot', lang,
                               f'{split}.json')
    elif num_shots == 'zero_shot':
      file_path = os.path.join(_DATA_PATH, 'annotations', 'zero_shot',
                               f'testdev_balanced_questions_{lang}.json')
    else:
      raise ValueError(f'Unknown num_shots: {num_shots}')
    with open(file_path, 'r') as f:
      entries = json.load(f)

    # Make one entry per question-answer pair.
    for question_id, question_data in entries.items():
      example_id = f'{question_id}_{lang}'
      yield example_id, {
          'example_id': example_id,
          'image/id': question_data['imageId'],
          'image': os.path.join(_IMAGE_PATH, f'{question_data["imageId"]}.jpg'),
          'question': question_data['question'],
          'answer': question_data['answer'],
      }
|
Tipsomaly/model/big_vision/datasets/xm3600/xm3600.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# pylint: disable=line-too-long
|
| 16 |
+
r"""Generates XM3600 in a TFDS-ready structure.
|
| 17 |
+
|
| 18 |
+
First, download the captions from https://google.github.io/crossmodal-3600/ and the images from https://cocodataset.org/#download.
|
| 19 |
+
The coco Karpathy split is available at http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip:
|
| 20 |
+
mkdir -p /tmp/data/xm3600
|
| 21 |
+
wget https://google.github.io/crossmodal-3600/web-data/captions.zip -P /tmp/data/xm3600
|
| 22 |
+
unzip /tmp/data/xm3600/captions.zip -d /tmp/data/xm3600/
|
| 23 |
+
wget https://open-images-dataset.s3.amazonaws.com/crossmodal-3600/images.tgz -P /tmp/data/xm3600
|
| 24 |
+
mkdir /tmp/data/xm3600/images
|
| 25 |
+
tar -xzf /tmp/data/xm3600/images.tgz -C /tmp/data/xm3600/images
|
| 26 |
+
|
| 27 |
+
Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util):
|
| 28 |
+
|
| 29 |
+
cd big_vision/datasets
|
| 30 |
+
env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=xm3600
|
| 31 |
+
|
| 32 |
+
Example to load:
|
| 33 |
+
|
| 34 |
+
import tensorflow_datasets as tfds
|
| 35 |
+
dataset = tfds.load(
|
| 36 |
+
'xm3600', split='en',
|
| 37 |
+
data_dir='/tmp/tfds')
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
import json
|
| 41 |
+
import os.path
|
| 42 |
+
|
| 43 |
+
import tensorflow_datasets as tfds
|
| 44 |
+
|
| 45 |
+
_DESCRIPTION = """
|
| 46 |
+
COCO image + captions, translated from English to 35 languages (English incl.).
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
# pylint: disable=line-too-long
|
| 50 |
+
_CITATION = """
|
| 51 |
+
@inproceedings{thapliyal-etal-2022-crossmodal,
|
| 52 |
+
title = "Crossmodal-3600: A Massively Multilingual Multimodal Evaluation Dataset",
|
| 53 |
+
author = "Thapliyal, Ashish V. and
|
| 54 |
+
Pont Tuset, Jordi and
|
| 55 |
+
Chen, Xi and
|
| 56 |
+
Soricut, Radu",
|
| 57 |
+
editor = "Goldberg, Yoav and
|
| 58 |
+
Kozareva, Zornitsa and
|
| 59 |
+
Zhang, Yue",
|
| 60 |
+
booktitle = "Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing",
|
| 61 |
+
month = dec,
|
| 62 |
+
year = "2022",
|
| 63 |
+
address = "Abu Dhabi, United Arab Emirates",
|
| 64 |
+
publisher = "Association for Computational Linguistics",
|
| 65 |
+
url = "https://aclanthology.org/2022.emnlp-main.45",
|
| 66 |
+
doi = "10.18653/v1/2022.emnlp-main.45",
|
| 67 |
+
pages = "715--729",
|
| 68 |
+
}
|
| 69 |
+
"""
|
| 70 |
+
# pylint: enable=line-too-long
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
_CAPTIONS_PATH = '/tmp/data/xm3600'
|
| 74 |
+
_IMAGES_PATH = '/tmp/data/xm3600/images'
|
| 75 |
+
|
| 76 |
+
XM3600_LANGUAGES = [
|
| 77 |
+
'ar', 'bn', 'cs', 'da', 'de', 'el', 'en', 'es', 'fa', 'fi', 'fil', 'fr',
|
| 78 |
+
'he', 'hi', 'hr', 'hu', 'id', 'it', 'ja', 'ko', 'mi', 'nl', 'no', 'pl',
|
| 79 |
+
'pt', 'quz', 'ro', 'ru', 'sv', 'sw', 'te', 'th', 'tr', 'uk', 'vi', 'zh'
|
| 80 |
+
]
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class Xm3600(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for the Crossmodal-3600 (XM3600) captioning dataset."""

  VERSION = tfds.core.Version('1.0.1')
  RELEASE_NOTES = {
      '1.0.0': 'First release.',
      '1.0.1': 'Add captions/tokenized feature to compute metrics (eg CIDEr).',
  }

  def _info(self):
    """Returns the dataset metadata (features, homepage, citation)."""

    feature_spec = {
        'image/id': tfds.features.Text(),
        'image': tfds.features.Image(encoding_format='jpeg'),
        'captions': tfds.features.Sequence(tfds.features.Text()),
        'captions/tokenized': tfds.features.Sequence(tfds.features.Text()),
        'language': tfds.features.Text(),
    }
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict(feature_spec),
        supervised_keys=None,
        homepage='https://google.github.io/crossmodal-3600/',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators: one split per language code."""
    return {lang: self._generate_examples(lang) for lang in XM3600_LANGUAGES}

  def _generate_examples(self, split: str):
    """Yields (key, example) tuples; the split name is the language code."""
    lang = split

    annot_fname = os.path.join(_CAPTIONS_PATH, 'captions.jsonl')
    # First pass: collect (captions, tokenized captions) per image key.
    per_image = {}
    with open(annot_fname, 'r') as f:
      for raw_line in f:
        record = json.loads(raw_line)
        key = f'{record["image/key"]}_{lang}'
        per_image[key] = (record[lang]['caption'],
                          record[lang]['caption/tokenized'])

    for key, (caps, tok_caps) in per_image.items():
      # The image filename is the raw image key (the part before '_<lang>').
      image_key = key.split('_')[0]
      yield key, {
          'image/id': key,
          'image': os.path.join(_IMAGES_PATH, f'{image_key}.jpg'),
          'captions': caps,
          'captions/tokenized': tok_caps,
          'language': lang,
      }
|
Tipsomaly/model/big_vision/evaluators/proj/cappa/perplexity.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Evaluator for perplexity of a model."""
|
| 16 |
+
from big_vision.evaluators import mean
|
| 17 |
+
import big_vision.utils as u
|
| 18 |
+
import jax.numpy as jnp
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Temporary global flag to facilitate backwards compatibility. Will be removed
|
| 22 |
+
# by the end of year 2023.
|
| 23 |
+
API = 'jit'
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def perplexity(predict_fn, normalize_by_seqlen):
  """Returns a function computing per-example perplexity (softmax xent)."""

  def _perplexity_fn(train_state, batch, pad_token=0, **kw):
    logits, _ = predict_fn(train_state, batch, **kw)

    # Padding labels get zero weight so they don't contribute to the loss.
    weights = (batch['labels'] != pad_token).astype(jnp.float32)
    label_masks = batch.get('label_masks')
    if label_masks is not None:
      weights = weights * label_masks

    losses = u.weighted_softmax_xent(
        logits=logits, labels=batch['labels'],
        weights=weights, label_smoothing=0.0,
        reduction=False, normalize=normalize_by_seqlen)
    return {'perplexity': losses}

  return _perplexity_fn
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class Evaluator(mean.Evaluator):
  """Perplexity evaluator.

  Thin wrapper around `mean.Evaluator` that averages the per-example
  perplexity produced by `perplexity(...)` over the eval dataset.
  """

  def __init__(self, predict_fn, *a, normalize_by_seqlen=False, **kw):
    # All remaining positional/keyword args are forwarded to mean.Evaluator
    # (data spec, preprocessing, devices, ...).
    super().__init__(perplexity(predict_fn, normalize_by_seqlen), *a, **kw)
|
Tipsomaly/model/big_vision/evaluators/proj/cappa/scoring_classifier.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Scoring classifier.
|
| 16 |
+
|
| 17 |
+
This one is based on a generative perspective for image classification.
|
| 18 |
+
Here we input the image as well as all the tokenized labels to compute their
|
| 19 |
+
perplexity and select the one with minimum loss as the prediction.
|
| 20 |
+
"""
|
| 21 |
+
import functools
|
| 22 |
+
from big_vision.datasets.imagenet import class_names as imagenet_class_names
|
| 23 |
+
from big_vision.evaluators import mean
|
| 24 |
+
from big_vision.pp import builder as pp_builder
|
| 25 |
+
import jax.numpy as jnp
|
| 26 |
+
import numpy as np
|
| 27 |
+
|
| 28 |
+
# Temporary global flag to facilitate backwards compatability. Will be removed
|
| 29 |
+
# by the end of year 2023.
|
| 30 |
+
API = "jit"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
CLASS_NAMES = {
|
| 34 |
+
"imagenet2012": imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES,
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# As a separate function to cache result across instances.
|
| 39 |
+
@functools.lru_cache(maxsize=None)
def get_classes(dataset_name, pp_txt):
  """Returns an array of tokenized class-name labels for `dataset_name`.

  Cached (module-level) so that multiple evaluator instances with the same
  dataset/tokenizer config share a single tokenization pass.
  """
  tokenize = pp_builder.get_preprocess_fn(pp_txt, log_data=False)
  tokenized = []
  for class_name in CLASS_NAMES[dataset_name]:
    tokenized.append(tokenize({"label": class_name})["labels"])
  return np.array(tokenized)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def scoring(predict_fn, tokenized_labels):
  """Returns an eval fn that scores every class and checks the argmax."""

  def _scoring_fn(train_state, batch, *a, **kw):
    # Inject the tokenized class names; existing batch entries win on clash.
    scored_batch = {"_label_tokens": tokenized_labels, **batch}
    class_scores = predict_fn(train_state, scored_batch, *a, **kw)
    top1 = jnp.argmax(class_scores, axis=-1)
    return {"prec@1": top1 == scored_batch["label"]}

  return _scoring_fn
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class Evaluator(mean.Evaluator):
  """Evaluator for classification accuracy based on scoring all classes.

  Tokenizes the class names once (cached via `get_classes`) and averages the
  resulting top-1 indicator over the eval dataset via `mean.Evaluator`.
  """

  def __init__(self, predict_fn, data, pp_fn, pp_txt, *a, **kw):
    # data["name"] selects the class-name list from CLASS_NAMES; pp_txt is
    # the text-preprocessing spec used to tokenize each class name.
    cls_tokens = get_classes(data["name"], pp_txt)
    super().__init__(scoring(predict_fn, cls_tokens), data, pp_fn, *a, **kw)
|
Tipsomaly/model/big_vision/evaluators/proj/distill/distance.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Evaluator for the classfication task."""
|
| 16 |
+
from functools import partial, lru_cache
|
| 17 |
+
|
| 18 |
+
from big_vision import input_pipeline
|
| 19 |
+
import big_vision.datasets.core as ds_core
|
| 20 |
+
import big_vision.pp.builder as pp_builder
|
| 21 |
+
import big_vision.utils as u
|
| 22 |
+
|
| 23 |
+
import einops
|
| 24 |
+
import jax
|
| 25 |
+
import jax.numpy as jnp
|
| 26 |
+
from jax.sharding import NamedSharding
|
| 27 |
+
from jax.sharding import PartitionSpec as P
|
| 28 |
+
import numpy as np
|
| 29 |
+
|
| 30 |
+
# Temporary global flag to facilitate backwards compatability. Will be removed
|
| 31 |
+
# by the end of year 2023.
|
| 32 |
+
API = 'jit'
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def dist(student, teacher, kind, feat_axis=-1,
         epsilon=1e-12, t=1, ls=0.0, k=1):
  """Distance function used for distillation.

  Args:
    student: Student outputs (e.g. logits).
    teacher: Teacher outputs, same shape as `student`.
    kind: Which distance to compute. One of 'euclidean', 'l2', 'hard', 'kl',
      'logsoftmax_euclidean' or 'agree'.
    feat_axis: Axis holding the features/classes; it is reduced away.
    epsilon: Stabilizer added under the sqrt for the euclidean variants.
    t: Softmax temperature (only used by 'kl').
    ls: Label-smoothing factor (only used by 'hard').
    k: Student top-k compared against teacher top-1 (only used by 'agree').

  Returns:
    Array of distances with `feat_axis` reduced away.

  Raises:
    ValueError: If `kind` is not one of the supported distances.
  """
  diff = student - teacher
  if kind == 'euclidean':
    return jnp.sqrt(jnp.sum(diff * diff, axis=feat_axis) + epsilon)
  elif kind == 'l2':
    return jnp.sum(diff * diff, axis=feat_axis)
  elif kind == 'hard':
    # Cross-entropy against the teacher's argmax pseudo-labels.
    pseudolabels = jnp.argmax(teacher, feat_axis)
    pl = u.onehot(pseudolabels, teacher.shape[feat_axis])
    if ls:
      pl = (1.0 - ls) * pl + (ls / (pl.shape[-1] - 1)) * (1.0 - pl)
    return u.softmax_xent(logits=student, labels=pl,
                          reduction=False, kl=True, axis=feat_axis)
  elif kind == 'kl':
    # Temperature-scaled KL; the t**2 factor compensates the 1/t scaling.
    return t**2 * u.softmax_xent(
        logits=student / t,
        labels=jax.nn.softmax(teacher / t),
        reduction=False, kl=True, axis=feat_axis)
  elif kind == 'logsoftmax_euclidean':
    logsoftmax_diff = (
        jax.nn.log_softmax(student, axis=feat_axis) -
        jax.nn.log_softmax(teacher, axis=feat_axis))
    return jnp.sqrt(
        jnp.sum(logsoftmax_diff * logsoftmax_diff, axis=feat_axis) + epsilon)
  elif kind == 'agree':
    def get_top_k(arr, k, ax):
      return jax.lax.top_k(arr.swapaxes(ax, -1), k)[1].swapaxes(ax, -1)
    return (get_top_k(student, k, feat_axis) ==
            get_top_k(teacher, 1, feat_axis)).sum(feat_axis)
  else:
    # Raise instead of `assert False`: asserts are stripped under `python -O`,
    # and a ValueError matches the validation style used elsewhere in the repo.
    raise ValueError(f'Unknown kind of distance {kind}.')
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# Cached so that identical distance configs map to the *same* function object;
# jax uses function identity in its compilation cache, so this avoids
# re-tracing/re-compiling when several evaluators share a distance config.
@lru_cache(None)
def get_dist_fn(**kw):
  """Returns `dist` with the given config baked in (cached per config)."""
  return partial(dist, **kw)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# To avoid re-compiling the function for every new instance of the same
|
| 76 |
+
# evaluator on a different dataset!
|
| 77 |
+
# To avoid re-compiling the function for every new instance of the same
# evaluator on a different dataset!
@lru_cache(None)
def get_eval_fn(student_teacher_fwd, what, mesh, distances):
  """Builds the jit-compiled eval function (cached to avoid re-compiles)."""
  @partial(jax.jit, out_shardings=NamedSharding(mesh, P()))
  def _eval_fn(train_state, batch, mask):
    (_, student_out), (_, teacher_out) = student_teacher_fwd(train_state, batch)
    feats_s = u.tree_get(student_out, what[0])
    feats_t = u.tree_get(teacher_out, what[1])

    # Flatten anything that isn't already a per-example vector (e.g.
    # feature-maps) into shape (batch, -1).
    feats_s = einops.rearrange(feats_s, 'b ... -> b (...)')
    feats_t = einops.rearrange(feats_t, 'b ... -> b (...)')

    # NOTE: we compute and return every per-example distance; if this ever
    # becomes too slow, switch to returning summary stats instead.
    per_dist = [dist_fn(feats_s, feats_t) for dist_fn in distances]
    return per_dist, mask

  return _eval_fn
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class Evaluator:
  """Distillation distance evaluator.

  Runs the student and teacher forward passes over an eval dataset and
  reports, for each configured distance, the full per-example distance
  array plus its avg/min/max.
  """

  def __init__(
      self,
      student_teacher_fwd,  # fn(train_state, batch) -> ((..., out_s), (..., out_t))
      data,  # dict of dataset specs, passed to ds_core.get.
      pp_fn,  # preprocessing spec string.
      distances,  # sequence of dicts, each a kwargs-config for `dist`.
      what=('logits', 'logits'),  # (student, teacher) output keys to compare.
      *,
      devices,  # list of jax devices to shard the data over.
      **data_kw,  # extra input_pipeline args (e.g. batch_size, prefetch).
  ):
    data = ds_core.get(**data)
    pp_fn = pp_builder.get_preprocess_fn(pp_fn)
    # 'prefetch' is consumed here; the rest of data_kw goes to the pipeline.
    prefetch = data_kw.pop('prefetch', 1)
    self.ds, self.steps = input_pipeline.make_for_inference(
        data.get_tfdata(ordered=True),
        pp_fn,
        num_ex_per_process=data.num_examples_per_process(),
        **data_kw,
    )
    self.data_iter = input_pipeline.start_global(self.ds, devices, prefetch)
    # Must be hashable (tuple) because get_eval_fn is lru_cached.
    dist_fns = tuple(get_dist_fn(**dist) for dist in distances)
    # Human-readable metric prefix per distance, e.g. 'kind=kl_t=2'.
    self.dist_names = [
        '_'.join(f'{k}={v}' for k, v in dist.items()) for dist in distances
    ]
    mesh = jax.sharding.Mesh(devices, ('data',))
    self.eval_fn = get_eval_fn(student_teacher_fwd, what, mesh, dist_fns)

  def run(self, train_state):
    """Computes all metrics; yields (metric_name, value) pairs."""
    all_ds = [[] for _ in self.dist_names]
    for _, batch in zip(range(self.steps), self.data_iter):
      mask = batch.pop('_mask')
      batch_ds, batch_ms = self.eval_fn(train_state, batch, mask)
      # eval_fn's outputs are fully replicated (out_shardings=P()), so
      # device_get/np.array brings a single host copy over; the mask selects
      # real examples (padding examples from make_for_inference are dropped).
      batch_ms = np.array(batch_ms)
      for i, val in enumerate(batch_ds):
        all_ds[i].append(np.array(val)[batch_ms == 1])
    for name, ds in zip(self.dist_names, all_ds):
      ds = np.concatenate(ds)
      yield f'{name}/all', ds
      yield f'{name}/avg', np.mean(ds)
      yield f'{name}/min', np.min(ds)
      yield f'{name}/max', np.max(ds)
|
Tipsomaly/model/big_vision/evaluators/proj/givt/coco_panoptic.py
ADDED
|
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""COCO17 panoptic evaluation.
|
| 16 |
+
|
| 17 |
+
jax.jit-compatible fork of the evaluator from evaluators/proj/uvim.
|
| 18 |
+
"""
|
| 19 |
+
import functools
|
| 20 |
+
import itertools
|
| 21 |
+
import json
|
| 22 |
+
import os
|
| 23 |
+
import tempfile
|
| 24 |
+
import time
|
| 25 |
+
from typing import Any
|
| 26 |
+
import zipfile
|
| 27 |
+
|
| 28 |
+
from absl import flags
|
| 29 |
+
from absl import logging
|
| 30 |
+
from big_vision import input_pipeline
|
| 31 |
+
from big_vision import utils
|
| 32 |
+
from big_vision.datasets import core as ds_core
|
| 33 |
+
import big_vision.pp.builder as pp_builder
|
| 34 |
+
import jax
|
| 35 |
+
import jax.numpy as jnp
|
| 36 |
+
import numpy as np
|
| 37 |
+
from pycocotools.panopticapi import evaluation
|
| 38 |
+
import panopticapi_converters.twochannels2panoptic_coco_format as converter
|
| 39 |
+
import tensorflow as tf
|
| 40 |
+
import tensorflow_datasets as tfds
|
| 41 |
+
|
| 42 |
+
from tensorflow.io import gfile
|
| 43 |
+
|
| 44 |
+
# Temporary global flag to facilitate backwards compatability.
|
| 45 |
+
API = 'jit'
|
| 46 |
+
|
| 47 |
+
ROOT = os.environ.get('COCO_DATA_DIR', '.')
|
| 48 |
+
|
| 49 |
+
PANOPTIC_COCO_CATS_FILE = f'{ROOT}/panoptic_coco_categories.json'
|
| 50 |
+
PANOPTIC_2017 = {
|
| 51 |
+
'train': f'{ROOT}/panoptic_train2017.json',
|
| 52 |
+
'validation': f'{ROOT}/panoptic_val2017.json',
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
PANOPTIC_GT_ZIP = {
|
| 56 |
+
'train': f'{ROOT}/panoptic_train2017.zip',
|
| 57 |
+
'validation': f'{ROOT}/panoptic_val2017.zip',
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# Note: global to avoid jax re-compiling across different evaluator instances.
|
| 62 |
+
# Note: global to avoid jax re-compiling across different evaluator instances.
@functools.cache
def _get_predict_fn(predict_fn, mesh=None):
  """Wrapper for jit-compiled predict function.

  NOTE(review): the `mesh=None` default would fail inside NamedSharding if a
  caller omitted the argument — all call sites appear to pass a mesh; confirm
  before relying on the default.
  """

  # `out_shardings` annotation is needed because of the `all_gather` ops in the
  # pmap implementation. An empty PartitionSpec fully replicates the outputs.
  @functools.partial(jax.jit,
                     out_shardings=jax.sharding.NamedSharding(
                         mesh, jax.sharding.PartitionSpec()))
  def _run_predict_fn(train_state, batch):
    """Run predict_fn and gather all outputs on all devices."""
    y = predict_fn(train_state, batch)
    res = {
        'image/id': batch['image/id'],
        'mask': batch['_mask'],
        # Pack semantics/instances into one array: last axis is [sem, inst].
        'y': jnp.stack([y['semantics'], y['instances']], axis=-1),
    }
    return res
  return _run_predict_fn
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class Evaluator:
|
| 84 |
+
"""Panoptic segmentation evaluator: calls official COCO API."""
|
| 85 |
+
|
| 86 |
+
def __init__(
|
| 87 |
+
self,
|
| 88 |
+
predict_fn,
|
| 89 |
+
pp_fn,
|
| 90 |
+
batch_size,
|
| 91 |
+
data=None,
|
| 92 |
+
cache_final=True,
|
| 93 |
+
cache_raw=False,
|
| 94 |
+
prefetch=1,
|
| 95 |
+
save_dir=None,
|
| 96 |
+
*,
|
| 97 |
+
devices,
|
| 98 |
+
):
|
| 99 |
+
"""Panoptic segmentation evaluator: calls official COCO API.
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
predict_fn: jit-compilable function, which accepts arbitrary dictionaries
|
| 103 |
+
of parameters and data, where the data dictionary is produced by the
|
| 104 |
+
`pp_fn`. It is expected to output a 2-channel mask, where the first
|
| 105 |
+
channel encodes semantics, and the second channel encodes instance ids.
|
| 106 |
+
pp_fn: Preprocessing function, sepcified as string.
|
| 107 |
+
batch_size: Batch size.
|
| 108 |
+
data: Dict specifying name and split of the data set. Defaults to the
|
| 109 |
+
standard COCO (2017).
|
| 110 |
+
cache_final: Whether to cache the data after preprocessing - see
|
| 111 |
+
input_pipeline for details.
|
| 112 |
+
cache_raw: Whether to cache the raw data - see input_pipline for details.
|
| 113 |
+
prefetch: Number of batches to prefetch
|
| 114 |
+
save_dir: Directory to save the results in.
|
| 115 |
+
devices: List of jax devices.
|
| 116 |
+
"""
|
| 117 |
+
self.predict_fn = _get_predict_fn(
|
| 118 |
+
predict_fn, jax.sharding.Mesh(devices, ('devices',)))
|
| 119 |
+
|
| 120 |
+
data_specs = dict(name='coco/2017_panoptic',
|
| 121 |
+
data_dir=None, split='validation')
|
| 122 |
+
data_specs.update(data or {})
|
| 123 |
+
data = ds_core.get(**data_specs)
|
| 124 |
+
self.dataset, self.steps = input_pipeline.make_for_inference(
|
| 125 |
+
data.get_tfdata(ordered=True), batch_size=batch_size,
|
| 126 |
+
num_ex_per_process=data.num_examples_per_process(),
|
| 127 |
+
preprocess_fn=pp_builder.get_preprocess_fn(pp_fn),
|
| 128 |
+
cache_final=cache_final, cache_raw=cache_raw)
|
| 129 |
+
self.data_iter = input_pipeline.start_global(
|
| 130 |
+
self.dataset, devices, prefetch)
|
| 131 |
+
|
| 132 |
+
# Only process 0 runs conversion to png and calls into coco api.
|
| 133 |
+
if jax.process_index() == 0:
|
| 134 |
+
self.result_dir = tempfile.TemporaryDirectory()
|
| 135 |
+
(self.gt_folder, self.gt_json, self.categories_json,
|
| 136 |
+
self.remap, self.size_map) = _prepare_ground_truth(
|
| 137 |
+
data_specs['name'], data_specs['split'],
|
| 138 |
+
data_specs.get('data_dir'))
|
| 139 |
+
if save_dir:
|
| 140 |
+
self.save_dir = save_dir.format(workdir=flags.FLAGS.workdir)
|
| 141 |
+
gfile.makedirs(self.save_dir)
|
| 142 |
+
else:
|
| 143 |
+
self.save_dir = None
|
| 144 |
+
|
| 145 |
+
def _compute_png_predictions(
    self, train_state: Any) -> Any:
  """Runs inference and writes per-image panoptic PNGs to the temp dir.

  Converting each prediction to a PNG on disk immediately keeps peak host
  memory low. Returns the result directory on host 0, None elsewhere.
  """
  num_done = 0
  logging.info('Panoptic eval: running inference.')
  for batch in itertools.islice(self.data_iter, self.steps):
    out = self.predict_fn(train_state, batch)

    # Only process 0 converts predictions; other hosts just drive inference.
    if jax.process_index():
      continue

    out = jax.device_get(out)
    keep = out['mask']
    pan_recs = out['y'][keep]
    ids = out['image/id'][keep]

    for pan_rec, image_id in zip(pan_recs, ids):
      sem, ins = pan_rec[..., 0], pan_rec[..., 1]

      # Remap model semantic class ids to COCO category ids.
      sem_remapped = np.array(sem)
      for v in np.unique(sem):
        sem_remapped[sem == v] = self.remap[v]
      sem = sem_remapped

      pan_mask = np.stack([sem, ins, np.zeros_like(sem)], axis=-1)
      pan_mask = utils.put_cpu(pan_mask)
      pan_mask = _resize_nearest(pan_mask, self.size_map[image_id])
      pan_mask_png = tf.io.encode_png(pan_mask.astype('uint8')).numpy()

      fname = f'{self.result_dir.name}/{image_id:012d}.png'
      with open(fname, 'wb') as f:
        f.write(pan_mask_png)
      num_done += 1

      logging.log_every_n_seconds(
          logging.INFO, 'Panoptic eval: processed %i examples so far.', 30,
          num_done)

  if jax.process_index():
    return None

  logging.info('Panoptic eval: inference done. Processed %d examples.',
               num_done)
  return self.result_dir
|
| 189 |
+
|
| 190 |
+
def run(self, train_state):
  """Run panoptic segmentation evaluation.

  Args:
    train_state: pytree containing the model parameters.

  Yields:
    Tuples consisting of metric name and value.
  """
  # result_dir itself is constant; the PNG files inside it are rewritten on
  # every call.
  result_dir = self._compute_png_predictions(train_state)

  if jax.process_index():  # Only host 0 scores the predictions.
    return

  if self.save_dir:
    gfile.RecursivelyCopyDir(result_dir.name, self.save_dir, overwrite=True)

  with tempfile.TemporaryDirectory() as pred_folder, \
      tempfile.NamedTemporaryFile(mode='w') as pred_json:

    logging.info('Panoptic eval: running conversion.')
    converter.converter(
        source_folder=result_dir.name,
        images_json_file=self.gt_json,
        categories_json_file=self.categories_json,
        segmentations_folder=pred_folder,
        predictions_json_file=pred_json.name)
    logging.info('Panoptic eval: conversion done.')

    logging.info('Panoptic eval: running metrics computation.')
    res = evaluation.pq_compute(
        gt_json_file=self.gt_json,
        gt_folder=self.gt_folder,
        pred_json_file=pred_json.name,
        pred_folder=pred_folder)
    logging.info('Panoptic eval: metrics computation done.')

    # Report pq/rq/sq for the full set as well as stuff/things subsets.
    for k, m in itertools.product(['All', 'Stuff', 'Things'],
                                  ['pq', 'rq', 'sq']):
      yield f'{k}_{m}', res[k][m]
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def _prepare_ground_truth(dataset, split, data_dir):
  """Dispatches ground-truth preparation to the zip-file or TFDS path."""
  # Canonical COCO panoptic with default data dir: reuse the original zip
  # files; any other dataset/location is rebuilt from the TFDS dataset.
  if dataset == 'coco/2017_panoptic' and data_dir is None:
    return _prepare_ground_truth_from_zipfiles(split)
  return _prepare_ground_truth_from_dataset(dataset, split, data_dir)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
@functools.lru_cache(maxsize=None)
def _prepare_ground_truth_from_dataset(dataset, split, data_dir):
  """Prepare ground truth from a tf.data.Dataset.

  Args:
    dataset: TFDS-compatible dataset specification.
    split: Data set split to use.
    data_dir: Folder containing the data.

  Returns:
    A tuple containing the folder containing the ground-truth data, the
    ground truth annotations loaded from json, the categories loaded from
    json, a map for remapping, and a map mapping image id to image size.
  """
  tfds_dataset = tfds.builder(
      dataset, data_dir=data_dir).as_dataset(split=split)

  categories_json = _make_local_copy(PANOPTIC_COCO_CATS_FILE)
  with gfile.GFile(categories_json, 'rb') as f:
    categories = json.loads(f.read())

  # Build map from tfds class ids to COCO class ids. Fix: the original
  # re-opened categories_json here with an unused file handle; the
  # categories were already loaded above.
  remap = {0: 0}
  remap.update({(i + 1): x['id'] for i, x in enumerate(categories)})

  gt_folder = tempfile.mkdtemp()
  gfile.makedirs(gt_folder)
  size_map = {}
  annotations = []
  images = []
  for example in tfds_dataset:
    image_id = int(example['image/id'])
    panoptic_image = example['panoptic_image']
    objs = example['panoptic_objects']

    fname = f'{image_id:012d}.png'
    with gfile.GFile(os.path.join(gt_folder, fname), 'wb') as f:
      f.write(tf.io.encode_png(panoptic_image).numpy())

    size_map[image_id] = (panoptic_image.shape[0], panoptic_image.shape[1])

    # Zip the parallel annotation fields instead of indexing by position.
    segments_info = []
    for ann_id, label, iscrowd, area in zip(
        objs['id'], objs['label'], objs['is_crowd'], objs['area']):
      segments_info.append({
          'id': int(ann_id),
          'category_id': remap[int(label + 1)],
          'iscrowd': int(iscrowd),
          'area': int(area),
      })

    annotations.append({
        'file_name': str(fname),
        'image_id': int(image_id),
        'segments_info': segments_info
    })
    images.append({
        'id': image_id,
        'file_name': f'{image_id:012d}.jpg',
    })

  # Write annotations.json needed for pq_compute.
  gt_json = os.path.join(gt_folder, 'annotations.json')
  with gfile.GFile(gt_json, 'wb') as f:
    f.write(json.dumps({
        'images': images,
        'annotations': annotations,
        'categories': categories,
    }))

  return gt_folder, gt_json, categories_json, remap, size_map
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def _prepare_ground_truth_from_zipfiles(split):
  """Prepare ground truth from coco zip files.

  Args:
    split: dataset split to prepare ground truth for.

  Returns:
    A tuple containing the folder containing the ground-truth data, the
    ground truth annotations loaded from json, the categories loaded from
    json, a map for remapping, and a map mapping image id to image size.

  Raises:
    ValueError: if `split` is neither a train nor a validation split.
  """
  split_prefix = split.split('[')[0]
  if split_prefix not in ('train', 'validation'):
    raise ValueError(f'Split {split} not supported')

  # The following 4 calls are cached. This allows to save significant time
  # in use cases like sweeping predict_fn hparams on the same run.
  gt_json = _make_local_copy(PANOPTIC_2017[split_prefix])
  gt_folder = _make_local_unzip_copy(PANOPTIC_GT_ZIP[split_prefix])
  categories_json = _make_local_copy(PANOPTIC_COCO_CATS_FILE)
  image_ids = _list_image_ids('coco/2017_panoptic', split)

  gt_folder = os.path.join(
      gt_folder, 'panoptic_val2017'
      if split_prefix == 'validation' else 'panoptic_train2017')

  # Build map from tfds class ids to COCO class ids.
  remap = {0: 0}
  with gfile.GFile(categories_json, 'r') as f:
    remap = {**remap, **{(i + 1): x['id'] for i, x in enumerate(json.load(f))}}

  # Filters gt_json to contain only annotations for images in dataset.
  with gfile.GFile(gt_json) as f:
    data = json.load(f)
  logging.info(
      'Panoptic eval: pre-filter %d annotations.',
      len(data['annotations'])
  )
  data['images'] = [x for x in data['images'] if x['id'] in image_ids]
  data['annotations'] = [
      x for x in data['annotations'] if x['image_id'] in image_ids
  ]
  logging.info(
      'Panoptic eval: post-filter %d annotations.',
      len(data['annotations'])
  )
  # Fix: the original `tempfile.NamedTemporaryFile(delete=False).name` left
  # the file descriptor open and unreferenced; mkstemp + fdopen closes it.
  fd, filtered_gt_json = tempfile.mkstemp()
  with os.fdopen(fd, 'w') as f:
    json.dump(data, f)

  # Precompute images sizes.
  size_map = {x['id']: (x['height'], x['width']) for x in data['images']}

  return gt_folder, filtered_gt_json, categories_json, remap, size_map
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
@functools.lru_cache(maxsize=None)
def _list_image_ids(dataset, split):
  """Returns the frozen set of `image/id` values present in the split."""
  ids_ds = tfds.load(dataset, split=split).map(lambda x: x['image/id'])
  return frozenset(ids_ds.as_numpy_iterator())
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
@functools.lru_cache(maxsize=None)
def _make_local_copy(fname) -> str:
  """Copies `fname` to a local temp file and returns its path (cached).

  Fix: the original created an unclosed `NamedTemporaryFile(delete=False)`,
  leaking the open file descriptor for the lifetime of the process.
  """
  start = time.monotonic()
  fd, local_path = tempfile.mkstemp()
  os.close(fd)  # Only the path is needed; gfile opens the file itself.
  gfile.copy(fname, local_path, overwrite=True)
  logging.info('Copy %s in %d seconds.', fname, time.monotonic() - start)
  return local_path
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
@functools.lru_cache(maxsize=None)
def _make_local_unzip_copy(fname) -> str:
  """Downloads zip `fname` and extracts it into a temp folder (cached)."""
  start = time.monotonic()
  out_dir = tempfile.mkdtemp()
  with tempfile.NamedTemporaryFile() as tmp_zip_file:
    gfile.copy(fname, tmp_zip_file.name, overwrite=True)
    with zipfile.ZipFile(tmp_zip_file.name, 'r') as zf:
      zf.extractall(out_dir)
  logging.info('Copy %s in %d seconds.', fname, time.monotonic() - start)
  return out_dir
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
@utils.jit_cpu(static_argnums=(1,))
def _resize_nearest(image, shape):
  """Nearest-neighbour resize to spatial `shape`, keeping the channel dim."""
  target_shape = shape + image.shape[-1:]
  return jax.image.resize(image, target_shape, 'nearest')
|
Tipsomaly/model/big_vision/evaluators/proj/givt/nyu_depth.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Evaluation for NYU depth.
|
| 16 |
+
|
| 17 |
+
jax.jit-compatible fork of the evaluator from evaluators/proj/uvim.
|
| 18 |
+
|
| 19 |
+
At evaluation time the ground truth is cropped and clipped. Values outside of
|
| 20 |
+
the test crop or clipping range are not included in eval calculations.
|
| 21 |
+
|
| 22 |
+
In this evaluator, it is assumed that the ground truth is already cropped, so the
|
| 23 |
+
entire image is evaluated. However, the evaluator does perform the clipping.
|
| 24 |
+
|
| 25 |
+
Reference implementations:
|
| 26 |
+
https://github.com/zhyever/Monocular-Depth-Estimation-Toolbox/blo(internal link)a0f341244260ff61541191a613dd74bc/depth/datasets/nyu.py
|
| 27 |
+
https://github.com/vinvino02/GLPDepth/blob/7f3c78df4ecd6e7c79fd0c4b73c95d61f4aa2121/code/utils/metrics.py
|
| 28 |
+
https://github.com/shariqfarooq123/AdaBins/blob/2fb686a66a304f0a719bc53d77412460af97fd61/evaluate.py
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
import functools
|
| 32 |
+
import itertools
|
| 33 |
+
|
| 34 |
+
from big_vision import input_pipeline
|
| 35 |
+
from big_vision import utils
|
| 36 |
+
from big_vision.datasets import core as ds_core
|
| 37 |
+
import big_vision.pp.builder as pp_builder
|
| 38 |
+
import jax
|
| 39 |
+
import jax.numpy as jnp
|
| 40 |
+
import numpy as np
|
| 41 |
+
|
| 42 |
+
# Temporary global flag to facilitate backwards compatability.
|
| 43 |
+
API = "jit"
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# Note: global to avoid jax re-compiling across different evaluator instances.
|
| 47 |
+
@functools.cache
def _get_predict_fn(predict_fn, mesh=None):
  """Wrapper for jit-compiled predict function."""

  # Fully-replicated output sharding is needed because of the `all_gather`
  # ops in the pmap implementation.
  replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

  @functools.partial(jax.jit, out_shardings=replicated)
  def _run_predict_fn(train_state, batch):
    """Run predict_fn and gather all outputs on all devices."""
    pred = predict_fn(train_state, batch)
    return {"y": pred["depth"],
            "gt": jnp.squeeze(batch["ground_truth"], axis=-1),
            "mask": batch["_mask"]}

  return _run_predict_fn
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class Evaluator:
  """Evaluator for NYU depth."""

  def __init__(self,
               predict_fn,
               pp_fn,
               batch_size,
               data,
               cache_final=True,
               cache_raw=False,
               prefetch=1,
               min_depth=1e-3,
               max_depth=10,
               *,
               devices):
    """Evaluator for NYU depth.

    Args:
      predict_fn: jit-compilable function, accepts arbitrary dictionaries of
        parameters and data, where the data dictionary is produced by the
        `pp_fn` op. It is expected to output a dict with `depth` containing a
        2D array with the predicted depth. The prediction is resized to the
        ground_truth size with nearest neighbour.
      pp_fn: Preprocessing function, specified as string. `pp_fn` must also
        output a 'ground_truth' as a 2D array of ground truth. Further, it
        has to apply a crop, if one wants to compute metrics with the eval
        crop typically used for NYU Depth metrics.
      batch_size: Batch size.
      data: Dict specifying name and split of the data set.
      cache_final: Whether to cache the data after preprocessing - see
        input_pipeline for details.
      cache_raw: Whether to cache the raw data - see input_pipeline for
        details.
      prefetch: Number of batches to prefetch.
      min_depth: Minimum depth value; smaller ground truth is ignored.
      max_depth: Maximum depth value; larger ground truth is ignored.
      devices: List of jax devices.
    """
    self.min_depth = min_depth
    self.max_depth = max_depth
    self.predict_fn = _get_predict_fn(
        predict_fn, jax.sharding.Mesh(devices, ("devices",)))

    data = ds_core.get(**data)
    self.dataset, self.steps = input_pipeline.make_for_inference(
        data.get_tfdata(ordered=True), batch_size=batch_size,
        num_ex_per_process=data.num_examples_per_process(),
        preprocess_fn=pp_builder.get_preprocess_fn(pp_fn),
        cache_final=cache_final, cache_raw=cache_raw)
    self.data_iter = input_pipeline.start_global(
        self.dataset, devices, prefetch)

  def run(self, train_state):
    """Run NYU depth eval.

    Args:
      train_state: pytree containing the model parameters.

    Yields:
      Tuples consisting of metric name and value.
    """
    # Per-image metric values; insertion order fixes the yield order below.
    per_image = {"RMSE": [], "abs_RE": [], "log10": [],
                 "delta1": [], "delta2": [], "delta3": []}
    for batch in itertools.islice(self.data_iter, self.steps):
      # Outputs are replicated dict values shaped (devices, batch, ...).
      out = self.predict_fn(train_state, batch)

      if jax.process_index():  # Host 0 gathers all preds and does the eval.
        continue

      out = jax.device_get(out)
      # Bool-indexing with the mask flattens to (global_batch, ...).
      out = jax.tree_map(lambda x: x[out["mask"]], out)  # pylint:disable=cell-var-from-loop

      for gt, pred in zip(out["gt"], out["y"]):
        # put_cpu and the numpy conversion below avoid unwanted
        # host-to-device transfers.
        pred, gt = utils.put_cpu((pred, gt))
        pred = _resize_nearest(pred, (gt.shape[0], gt.shape[1]))
        pred, gt = np.array(pred), np.array(gt)

        # Only pixels with in-range ground truth enter the metrics.
        valid = np.logical_and(gt > self.min_depth, gt < self.max_depth)
        gt_v, pred_v = gt[valid], pred[valid]

        per_image["RMSE"].append(_compute_rmse(gt_v, pred_v))
        per_image["abs_RE"].append(_compute_abs_re(gt_v, pred_v))
        per_image["log10"].append(_compute_abs_log(gt_v, pred_v))
        per_image["delta1"].append(_compute_delta(gt_v, pred_v, order=1))
        per_image["delta2"].append(_compute_delta(gt_v, pred_v, order=2))
        per_image["delta3"].append(_compute_delta(gt_v, pred_v, order=3))

    if jax.process_index():  # Host 0 gets all preds and does eval.
      return

    for name, values in per_image.items():
      yield name, np.mean(values)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
@utils.jit_cpu(static_argnums=(1,))
def _resize_nearest(image, shape):
  """Nearest-neighbour resize of `image` to `shape`, jitted on CPU."""
  return jax.image.resize(image, shape, method="nearest")
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def _compute_rmse(gt, pred):
|
| 175 |
+
diff = gt - pred
|
| 176 |
+
return np.sqrt(np.mean(np.power(diff, 2)))
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def _compute_abs_re(gt, pred):
|
| 180 |
+
diff = np.abs(gt - pred)
|
| 181 |
+
return np.mean(diff / gt)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def _compute_abs_log(gt, pred):
|
| 185 |
+
diff = np.abs(np.log10(gt) - np.log10(pred))
|
| 186 |
+
return np.mean(diff)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def _compute_delta(gt, pred, order):
|
| 190 |
+
rel_diff = np.maximum(gt / pred, pred / gt)
|
| 191 |
+
return np.sum(rel_diff < 1.25**order) / rel_diff.size
|
Tipsomaly/model/big_vision/evaluators/proj/givt/save_predictions.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Evaluator to save predictions."""
|
| 16 |
+
# pylint: disable=consider-using-from-import
|
| 17 |
+
import functools
|
| 18 |
+
import io # pylint: disable=unused-import
|
| 19 |
+
import itertools
|
| 20 |
+
import os
|
| 21 |
+
|
| 22 |
+
from absl import flags
|
| 23 |
+
from absl import logging
|
| 24 |
+
from big_vision import input_pipeline
|
| 25 |
+
from big_vision.datasets import core as ds_core
|
| 26 |
+
import big_vision.pp.builder as pp_builder
|
| 27 |
+
import big_vision.utils as u
|
| 28 |
+
import jax
|
| 29 |
+
import numpy as np
|
| 30 |
+
|
| 31 |
+
from tensorflow.io import gfile # pylint: disable=unused-import
|
| 32 |
+
|
| 33 |
+
# Temporary global flag to facilitate backwards compatability.
|
| 34 |
+
API = 'jit'
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# Note: global to avoid jax re-compiling across different evaluator instances.
|
| 38 |
+
@functools.cache
def _get_predict_fn(predict_fn, mesh=None):
  """Wrapper for jit-compiled predict function."""

  # Fully-replicated output sharding is needed because of the `all_gather`
  # ops in the pmap implementation.
  out_sharding = jax.sharding.NamedSharding(
      mesh, jax.sharding.PartitionSpec())

  @functools.partial(jax.jit, out_shardings=out_sharding)
  def _run_predict_fn(train_state, batch):
    """Run predict_fn and gather all outputs on all devices."""
    outputs = predict_fn(train_state, batch)
    return {'mask': batch['_mask'], 'inputs': batch, 'outputs': outputs}

  return _run_predict_fn
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class Evaluator:
  """Save predictions in "{FLAGS.workdir}/{outfile}".

  Results can then be easily inspected in a notebook such as:

  ```
  results = utils.load_checkpoint("<full_path_to_outfile>")
  inputs, outputs = (results["inputs"], results["outputs"])
  ```
  """

  def __init__(self, predict_fn, pp_fn, batch_size, data, outfile,
               cache_final=True, cache_raw=False, prefetch=1, *, devices):
    self.predict_fn = _get_predict_fn(
        predict_fn, jax.sharding.Mesh(devices, ('devices',)))

    # Each process feeds its own shard, zero-padded so that all processes
    # run the same number of batches.
    data = ds_core.get(**data)
    self.dataset, self.steps = input_pipeline.make_for_inference(
        data.get_tfdata(ordered=True), batch_size=batch_size,
        num_ex_per_process=data.num_examples_per_process(),
        preprocess_fn=pp_builder.get_preprocess_fn(pp_fn),
        cache_final=cache_final, cache_raw=cache_raw)
    self.data_iter = input_pipeline.start_global(
        self.dataset, devices, prefetch)

    self.path = os.path.join(flags.FLAGS.workdir, outfile)

  def run(self, train_state):
    """Compute all predictions, gather in main host and save in outfile."""
    num_seen = 0
    collected = []
    for batch in itertools.islice(self.data_iter, self.steps):
      out = self.predict_fn(train_state, batch)
      if jax.process_index():  # Only host 0 collects and writes.
        continue

      out = jax.device_get(out)
      # `out['mask']` must be read here: `x` inside the tree map does not
      # carry that field.
      out = jax.tree_map(lambda x: x[out['mask']], out)  # pylint: disable=cell-var-from-loop
      num_seen += out['mask'].shape[0]
      out.pop('mask')
      collected.append(out)

      logging.log_every_n_seconds(
          logging.INFO, 'Save predictions: processed %i examples so far.', 30,
          num_seen)

    if jax.process_index():
      return

    logging.info('Save predictions: processed %d examples.', num_seen)

    # Serialize everything into an in-memory npz, then write it out once.
    collected = jax.tree_map(lambda *x: np.concatenate(x, axis=0), *collected)
    names_and_vals, _ = u.tree_flatten_with_names(collected)
    io_buffer = io.BytesIO()
    np.savez_compressed(io_buffer, **dict(names_and_vals))
    with gfile.GFile(self.path, 'wb') as f:
      f.write(io_buffer.getvalue())
    return

    # The unreachable yield keeps `run` a generator, matching the evaluator
    # protocol used by the training loop.
    yield None  # pylint: disable=unreachable
|
Tipsomaly/model/big_vision/evaluators/proj/image_text/contrastive.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Evaluator for the contrastive task.
|
| 16 |
+
|
| 17 |
+
DON'T COMPARE ACROSS RUNS, use for training health monitoring only.
|
| 18 |
+
|
| 19 |
+
Note that this evaluator's `ncorrect_minibatch` is only a rough proxy for
|
| 20 |
+
training progress and does not report the actual `ncorrect`: when the same
|
| 21 |
+
labels found multiple times in a batch, then the reported value is biased
|
| 22 |
+
towards lower values.
|
| 23 |
+
|
| 24 |
+
Also note that the `ncorrect_minibatch` is a function of batch size (it's a lot
|
| 25 |
+
easier to find correct values in small batches).
|
| 26 |
+
"""
|
| 27 |
+
import functools
|
| 28 |
+
|
| 29 |
+
from big_vision import input_pipeline
|
| 30 |
+
import big_vision.datasets.core as ds_core
|
| 31 |
+
import big_vision.pp.builder as pp_builder
|
| 32 |
+
import big_vision.utils as u
|
| 33 |
+
import jax
|
| 34 |
+
import jax.numpy as jnp
|
| 35 |
+
import numpy as np
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _all_gather(z):
  """All gather and flatten first two dims."""
  def _gather_flat(x):
    return jnp.concatenate(jax.lax.all_gather(x, "batch"), 0)
  return jax.tree_map(_gather_flat, z)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# To avoid re-compiling the function for every new instance of the same
|
| 45 |
+
# evaluator on a different dataset!
|
| 46 |
+
@functools.lru_cache(None)
def get_eval_fn(predict_fn, use_global_batch):
  """Produces eval function, also applies pmap."""

  @functools.partial(jax.pmap, axis_name="batch")
  def _eval_fn(params, images, labels, mask):
    zimg, ztxt, extras = predict_fn(params, images, labels)

    if use_global_batch:
      zimg, ztxt, mask = _all_gather((zimg, ztxt, mask))

    # Temperature does not change the accuracy ranking, only the loss
    # magnitude.
    losses, measurements = u.bidirectional_contrastive_loss(
        zimg, ztxt, extras["t"], mask, reduction=False)
    loss_sum = jax.lax.psum(losses * mask, axis_name="batch")
    ncorrect_sum = jax.lax.psum(
        measurements["ncorrect"] * mask, axis_name="batch")
    n_valid = jax.lax.psum(mask, axis_name="batch")
    return ncorrect_sum, loss_sum, n_valid

  return _eval_fn
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class Evaluator:
  """Contrastive evaluator."""

  def __init__(self, predict_fn, data, pp_fn, batch_size,
               use_global_batch, cache_final=True,
               cache_raw=False, prefetch=1, label_key="labels"):
    data = ds_core.get(**data)
    pp_fn = pp_builder.get_preprocess_fn(pp_fn)
    self.ds, self.steps = input_pipeline.make_for_inference(
        data.get_tfdata(ordered=True), pp_fn, batch_size,
        num_ex_per_process=data.num_examples_per_process(),
        cache_final=cache_final, cache_raw=cache_raw)
    self.data_iter = input_pipeline.start_input_pipeline(self.ds, prefetch)
    self.eval_fn = get_eval_fn(predict_fn, use_global_batch)
    self.label_key = label_key

  def run(self, params):
    """Computes all metrics."""
    total_loss = total_correct = total_seen = 0
    for _, batch in zip(range(self.steps), self.data_iter):
      labels, mask = batch.pop(self.label_key), batch.pop("_mask")
      batch_ncorrect, batch_losses, batch_n = self.eval_fn(
          params, batch["image"], labels, mask)
      # Every result is a replicated array shaped
      # (local_devices, per_device_batch_size, elem_shape...), and each
      # device's entry is identical after the psum — so device 0 suffices.
      total_correct += np.sum(np.array(batch_ncorrect[0]))
      total_loss += np.sum(np.array(batch_losses[0]))
      total_seen += np.sum(np.array(batch_n[0]))
    yield ("ncorrect_minibatch", total_correct / total_seen)
    yield ("loss", total_loss / total_seen)
|
Tipsomaly/model/big_vision/evaluators/proj/image_text/discriminative_classifier.py
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Discriminative zero-shot classification evaluator.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import functools
|
| 19 |
+
import time
|
| 20 |
+
|
| 21 |
+
from absl import logging
|
| 22 |
+
from big_vision import input_pipeline
|
| 23 |
+
from big_vision import utils
|
| 24 |
+
from big_vision.evaluators.proj.image_text import prompt_engineering
|
| 25 |
+
from big_vision.pp import ops_general # pylint: disable=unused-import
|
| 26 |
+
from big_vision.pp import ops_image # pylint: disable=unused-import
|
| 27 |
+
import big_vision.pp.builder as pp_builder
|
| 28 |
+
import jax
|
| 29 |
+
import jax.numpy as jnp
|
| 30 |
+
from jax.sharding import NamedSharding
|
| 31 |
+
from jax.sharding import PartitionSpec as P
|
| 32 |
+
import numpy as np
|
| 33 |
+
import tensorflow as tf
|
| 34 |
+
import tensorflow_datasets as tfds
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# Temporary global flag to facilitate backwards compatability. Will be removed
|
| 38 |
+
# by the end of year 2023.
|
| 39 |
+
API = "jit"
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
DATASET_NAMES = ("imagenet2012", "cifar100", "oxford_iiit_pet")
|
| 43 |
+
DEFAULT_OVERRIDES = (
|
| 44 |
+
("imagenet2012", (
|
| 45 |
+
("class_names", "clip"),
|
| 46 |
+
("split", "validation"),
|
| 47 |
+
)),
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _with_infinite_padding(dataset):
  """Adds "infinite padding" to the dataset.

  Every real example is tagged with `mask=True`; after the real data is
  exhausted, the dataset keeps yielding all-zero filler elements tagged with
  `mask=False`, forever. This lets multi-host evaluation loops run a fixed
  number of steps without ever raising `StopIteration`.
  """
  def _tag_real(features):
    # Real examples carry mask=True alongside their original features.
    return dict(mask=True, **features)

  # Build a single zero-valued element matching the dataset's structure
  # (with a leading batch dim of 1, sliced away by from_tensor_slices).
  zeros_like_spec = lambda spec: tf.zeros(spec.shape, spec.dtype)[None]
  filler_element = tf.nest.map_structure(zeros_like_spec, dataset.element_spec)
  filler_element["mask"] = [False]
  padding = tf.data.Dataset.from_tensor_slices(filler_element).repeat(None)

  tagged = dataset.map(
      _tag_real, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return tagged.concatenate(padding)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# This is needed so retrieval_test can replace dataset info.
|
| 64 |
+
def _get_dataset_info(builder):
|
| 65 |
+
return builder.info
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def prepare_datasets(img_dataset,
                     class_names,
                     *,
                     prompt_templates,
                     pp_img,
                     pp_txt,
                     cache_final=False,
                     pre_filter_fn=None,
                     class_name_offset=0):
  """Returns unbatched `ds_images, ds_texts` datasets.

  Args:
    img_dataset: tf.data dataset of examples with "image" and "label".
    class_names: Sequence of class-name strings; each entry may contain
      several comma-separated aliases.
    prompt_templates: Prompt strings, each containing exactly one "{}"
      placeholder for the class name.
    pp_img: Preprocessing spec applied to image examples (must keep/produce
      "label" and "image").
    pp_txt: Preprocessing spec applied to text examples (reads "texts",
      must keep/produce "label" and "labels").
    cache_final: Whether to cache the fully-preprocessed datasets.
    pre_filter_fn: Optional predicate filtering `img_dataset` records.
    class_name_offset: Added to the enumerated class index; used when class
      names are sharded across processes so labels stay globally consistent.

  Returns:
    `(ds_images, ds_texts)` where `ds_texts` yields one element per
    (class alias x prompt template) combination.
  """

  assert prompt_templates, "Must specify prompt templates (e.g. simply ['{}'])"

  def expand_aliases(idx, class_name):
    # One class-name string may contain several comma-separated aliases;
    # each alias becomes its own element with the same (offset) label.
    class_names = tf.strings.split(class_name, ",")
    return tf.data.Dataset.from_tensor_slices((
        tf.repeat([idx + class_name_offset], len(class_names), axis=0),
        class_names,
    ))

  def add_prompts(idx, class_name):
    # Cross every (label, alias) pair with every prompt template.
    return tf.data.Dataset.from_tensor_slices({
        "label": tf.repeat([idx], len(prompt_templates), axis=0),
        "class_name": tf.repeat([class_name], len(prompt_templates), axis=0),
        "prompt_template": prompt_templates,
    })

  def substitute_prompt(features):
    # Replace the single "{}" placeholder with the class name.
    parts = tf.strings.split(features["prompt_template"], "{}")
    tf.debugging.assert_equal(len(parts), 2, features["prompt_template"])
    return {
        "label": features["label"],
        "texts": tf.strings.join([parts[0], features["class_name"], parts[1]])
    }

  if pre_filter_fn:
    img_dataset = img_dataset.filter(pre_filter_fn)
  ds_images = img_dataset.map(
      pp_builder.get_preprocess_fn(f"{pp_img}|keep('label', 'image')"))
  ds_texts = tf.data.Dataset.from_tensor_slices(list(class_names)).enumerate(
  ).flat_map(expand_aliases).flat_map(add_prompts).map(substitute_prompt).map(
      pp_builder.get_preprocess_fn(f"{pp_txt}|keep('label', 'labels')"))

  if cache_final:
    ds_images, ds_texts = ds_images.cache(), ds_texts.cache()

  return ds_images, ds_texts
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _split_and_batch(dataset_name, data_dir, class_names, batch_size, split,
                     get_ds):
  """Splits dataset, calls `get_ds` and returns padded + batched datasets.

  Image examples are sharded across processes via
  `tfds.split_for_jax_process`; class names are sharded manually below, with
  `class_name_offset` passed to `get_ds` so that labels remain globally
  consistent across processes.
  """
  assert not batch_size % jax.device_count(), (
      f"batch_size={batch_size} % jax.device_count()={jax.device_count()}")
  builder = tfds.builder(dataset_name, data_dir=data_dir)

  # Split class names (last process gets remainder).
  if len(class_names) < jax.process_count():
    # Pad with empty names so every process gets at least one entry.
    # See (internal link) for more details.
    class_names += [""] * (jax.process_count() - len(class_names))
  per_process = len(class_names) // jax.process_count()
  class_name_offset = per_process * jax.process_index()
  if jax.process_index() == jax.process_count() - 1:
    # Last process also takes the division remainder.
    class_names = class_names[class_name_offset:]
  else:
    class_names = class_names[class_name_offset:class_name_offset + per_process]

  ds_images, ds_texts = get_ds(
      builder.as_dataset(split=tfds.split_for_jax_process(split)),
      class_names,
      class_name_offset=class_name_offset)
  # Infinite padding lets every process run the same number of batched steps
  # even when the per-process example counts differ.
  return (
      _with_infinite_padding(ds_images).batch(batch_size),
      _with_infinite_padding(ds_texts).batch(batch_size),
  )
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _average_embeddings(embeddings, *, labels, num_classes, normalize):
|
| 146 |
+
"""Computes per-class averages of `embeddings`."""
|
| 147 |
+
assert embeddings.ndim == 2, f"Expected {embeddings.ndim}==2"
|
| 148 |
+
assert labels.ndim == 1, f"Expected {labels.ndim}==1"
|
| 149 |
+
assert len(labels) == len(embeddings), (
|
| 150 |
+
f"Expected {len(labels)}=={len(embeddings)}")
|
| 151 |
+
|
| 152 |
+
byidx = [[] for _ in range(num_classes)]
|
| 153 |
+
for label, embedding in zip(labels, embeddings):
|
| 154 |
+
byidx[label].append(embedding)
|
| 155 |
+
missing = set(range(num_classes)) - set(
|
| 156 |
+
idx for idx, embs in enumerate(byidx) if len(embs))
|
| 157 |
+
assert not missing, f"Classes without embeddings: {missing}"
|
| 158 |
+
embeddings = [np.array(embedding).mean(axis=0) for embedding in byidx]
|
| 159 |
+
embeddings = np.stack(embeddings)
|
| 160 |
+
|
| 161 |
+
assert len(embeddings) == num_classes
|
| 162 |
+
if normalize:
|
| 163 |
+
embeddings /= 1e-8 + np.linalg.norm(embeddings, axis=1, keepdims=True)
|
| 164 |
+
return embeddings
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class Evaluator:
  """Zero-shot classification evaluator."""

  def __init__(self,
               predict_fn,
               *,
               batch_size,
               devices,
               dataset_names=DATASET_NAMES,
               data_dir=None,
               class_names="dataset_info:label",
               split="test",
               prompt_templates="clip_paper",
               canonicalize=True,
               pp_img="resize(224)|value_range(-1,1)",
               pp_txt="tokenize(max_len=16, eos='sticky', "
               "pad_value=1, inkey='texts', outkey='labels')",
               cache_final=False,
               pre_filter_fn=None,
               first_class_name_only=True,
               dataset_overrides=DEFAULT_OVERRIDES,
               async_delay=1):
    """Initializes a new zero-shot classification evaluator.

    See `prepare_datasets()` for details on how the dataset is pre-processed.

    Args:
      predict_fn: Prediction function with signature
        `zimg, ztxt, out = predict_fn(params, images, texts)`
      batch_size: Global batch size.
      devices: list of devices.
      dataset_names: Names of TFDS datasets to evaluate on.
      data_dir: Optional argument to `tfds.builder()`.
      class_names: Usually specified as a string that is interpreted by
        `prompt_engineering.get_class_names()` to look up class names.
        Alternatively, this attribute can be a list of class names (using ","
        to separate multiple aliases).
      split: Which dataset split to use for evaluation.
      prompt_templates: Specifies which prompt templates to use. See module
        big_vision.evaluators.proj.image_text.prompt_engineering
        for valid values.
      canonicalize: Whether class names and prompt templates should be
        canonicalized. See `prompt_engineering.py` for details.
      pp_img: Preprocessing string for images. Preprocessed features should
        contain key "image" with value that can be batched and is suitable for
        the `images` argument of `predict_fn` input.
      pp_txt: Preprocessing string for texts. Can expect "texts" key as an input
        (shape=[], dtype=string), and is expected to produce "labels" key that
        is suitable for the `text` argument of `predict_fn` input.
      cache_final: Whether the preprocessed dataset should be cached.
      pre_filter_fn: Predicate applied to the dataset for filtering records.
      first_class_name_only: Whether only the first class name should be
        considered (i.e. not using any aliases).
      dataset_overrides: Mapping `dataset_name` to an optional dictionary that
        can override parameters `dataset_name`, `data_dir`, `pp_img`, `pp_txt`,
        `class_names`, `split`, `pre_filter_fn`, and the extra
        `class_names_dataset_name`.
        Works with tuple/dict of tuples/dicts.
      async_delay: How many steps to wait before checking if all hosts have
        finished their batch. A value > 1 allows for more parallelized
        processing, but will result in more unnecessary steps with padded data.
    """
    t0 = time.monotonic()
    self.datasets = {}
    self.prompt_templates = prompt_engineering.get_prompt_templates(
        prompt_templates, canonicalize=canonicalize)
    self._axis_name = "batch"
    # Deep-copy the overrides so the per-dataset `pop()` calls below don't
    # mutate the caller's (possibly shared, possibly tuple-based) config.
    dataset_overrides = {k: dict(v) for k, v in dict(dataset_overrides).items()}

    for dataset_name in dataset_names:
      # Resolve per-dataset overrides; anything left in `overrides` after all
      # pops is an unknown key and triggers the assert below.
      overrides = dataset_overrides.pop(dataset_name, {})
      dataset_name_ = overrides.pop("dataset_name", dataset_name)
      data_dir_ = overrides.pop("data_dir", data_dir)
      class_names_dataset_name = overrides.pop("class_names_dataset_name",
                                               dataset_name_)
      class_names_ = overrides.pop("class_names", class_names)
      class_names_ = prompt_engineering.get_class_names(
          dataset_name=class_names_dataset_name,
          source=class_names_,
          canonicalize=canonicalize)
      pp_img_ = overrides.pop("pp_img", pp_img)
      pp_txt_ = overrides.pop("pp_txt", pp_txt)
      cache_final_ = overrides.pop("cache_final", cache_final)
      split_ = overrides.pop("split", split)
      pre_filter_fn_ = overrides.pop("pre_filter_fn", pre_filter_fn)
      prompt_templates_ = overrides.pop("prompt_templates", prompt_templates)
      canonicalize_ = overrides.pop("canonicalize", canonicalize)
      prompt_templates_ = prompt_engineering.get_prompt_templates(
          prompt_templates_, canonicalize=canonicalize_)
      assert not overrides, f"Unknown overrides {dataset_name}: {overrides}"

      if first_class_name_only:
        # Drop comma-separated aliases, keeping only the primary name.
        class_names_ = [name.split(",")[0] for name in class_names_]
      ds_images, ds_texts = _split_and_batch(
          dataset_name=dataset_name_,
          data_dir=data_dir_,
          class_names=class_names_,
          batch_size=batch_size,
          split=split_,
          get_ds=functools.partial(
              prepare_datasets,
              pp_img=pp_img_,
              pp_txt=pp_txt_,
              cache_final=cache_final_,
              pre_filter_fn=pre_filter_fn_,
              prompt_templates=prompt_templates_))
      self.datasets[dataset_name] = dict(
          images=ds_images, texts=ds_texts, class_names=class_names_,
          dataset_name=dataset_name_, split=split_)

    # Any remaining keys name datasets that were never requested.
    assert not dataset_overrides, f"Extra overrides: {dataset_overrides}"

    def embed_texts(train_state, texts):
      """Returns text embeddings."""
      _, ztxt, _ = predict_fn(train_state, {"labels": texts})
      return ztxt

    def count_correct(train_state, return_embeddings, *, mask, labels, image,
                      ztxt):
      """Returns count of correct predictions (and optionally embeddings)."""
      zimg, _, _ = predict_fn(train_state, {"image": image})
      # Nearest text embedding (by dot product) is the predicted class.
      best_txt = (zimg @ ztxt.T).argmax(axis=1)
      # labels has format [[1, -1, -1], [5, -1, -1], [7, 2, -1], ...]
      # so here we count "any" correct, such that the counting matches the
      # multilabel scenario described in "are we done with imagenet"
      # (http://arxiv.org/abs/2006.07159) section 3.1
      if labels.ndim == 1:
        labels = labels[..., None]
      assert labels.ndim == 2, labels.shape
      matching = (best_txt[:, None] == labels).sum(axis=1)
      # `mask` zeroes out padded examples so they don't count as correct.
      correct = jnp.where(mask, (matching > 0).astype(jnp.int32), 0).sum()
      correct = jnp.sum(correct)
      if return_embeddings:
        return correct, zimg
      else:
        return correct, None

    self.devices = devices
    self.mesh = jax.sharding.Mesh(devices, ("devices",))

    # All jitted functions output fully-replicated (unsharded) values.
    self._embed_texts_p = jax.jit(
        embed_texts, out_shardings=NamedSharding(self.mesh, P()))
    self._count_correct_p = jax.jit(count_correct, static_argnums=(1,),
                                    out_shardings=NamedSharding(self.mesh, P()))
    self._count_p = jax.jit(jnp.sum,
                            out_shardings=NamedSharding(self.mesh, P()))
    self._all_gather_p = jax.jit(
        lambda x: x, out_shardings=NamedSharding(self.mesh, P()))

    # Tracks which jitted functions have already been compiled, only for
    # logging compilation time once.
    self._compiled = set()
    assert async_delay > 0, f"async_delay must be >0, not {async_delay}"
    self._async_delay = async_delay
    logging.info("Initialized evaluator in %.1f seconds", time.monotonic() - t0)

  def _embed_texts(self, train_state, dataset_name):
    """Returns per-class averaged text embeddings."""
    t0 = time.monotonic()
    logging.info("Starting text embedding...")
    ns = []  # per-step count of real (non-padded) examples
    embeddings = []
    data = {"label": [], "mask": []}

    ds_b = input_pipeline.start_global(
        self.datasets[dataset_name]["texts"], self.devices)
    for batch in ds_b:
      ns.append(jax.device_get(self._count_p(batch["mask"])))
      # The dataset is infinitely padded; stop once a (possibly delayed)
      # step contained no real examples on any host.
      if len(ns) >= self._async_delay and ns[-self._async_delay] == 0:
        break

      embeddings.append(jax.device_get(self._embed_texts_p(
          train_state, batch["labels"])))
      for name in data:
        data[name].append(jax.device_get(self._all_gather_p(batch[name])))

      if self._embed_texts_p not in self._compiled:
        logging.info("Compiled text embeddings in %.1fs", time.monotonic() - t0)
        t0 = time.monotonic()
        self._compiled.add(self._embed_texts_p)

    ns = np.array(ns)
    n = ns.sum()
    data["embedding"] = embeddings
    data = {k: np.concatenate(v, axis=0) for k, v in data.items()}
    # Drop padded elements before averaging per class.
    mask = data.pop("mask").astype(bool)
    data = {k: v[mask] for k, v in data.items()}
    data["average_embedding"] = _average_embeddings(
        data["embedding"],
        labels=data["label"],
        num_classes=len(self.datasets[dataset_name]["class_names"]),
        normalize=True)

    logging.info("Embedded %s text in %d steps - ...%s", dataset_name, len(ns),
                 ns[-10:])
    logging.info("Totalling %d text in %.1fs", n, time.monotonic() - t0)
    logging.info("Total texts embeddings size %.1fM",
                 data["embedding"].nbytes / 1e6)
    return data

  def evaluate(self,
               train_state,
               dataset_name,
               *,
               return_embeddings=False):
    """Returns evaluation results."""
    texts = self._embed_texts(train_state, dataset_name)
    ztxt_p = texts["average_embedding"]
    # Replicate the class embeddings so every device can match against them.
    ztxt_p = utils.reshard(ztxt_p, NamedSharding(self.mesh, P()))

    t0 = time.monotonic()
    logging.info("Starting image embedding...")

    ns = []  # per-step count of real (non-padded) examples
    embeddings = []
    corrects = []
    data = {"mask": [], "label": []} if return_embeddings else {}

    ds_b = input_pipeline.start_global(
        self.datasets[dataset_name]["images"], self.devices)
    for batch in ds_b:
      ns.append(jax.device_get(self._count_p(batch["mask"])))
      # Same infinite-padding stop condition as in `_embed_texts`.
      if len(ns) >= self._async_delay and ns[-self._async_delay] == 0:
        break

      labels = batch["label"]
      correct_p, embs_p = self._count_correct_p(
          train_state,
          return_embeddings,
          mask=batch["mask"],
          labels=labels,
          image=batch["image"],
          ztxt=ztxt_p,
      )
      corrects.append(jax.device_get(correct_p))
      if self._count_correct_p not in self._compiled:
        logging.info("Compiled image embeddings in %.1fs",
                     time.monotonic() - t0)
        t0 = time.monotonic()
        self._compiled.add(self._count_correct_p)

      if return_embeddings:
        embeddings.append(jax.device_get(self._all_gather_p(embs_p)))
        for name in data:
          data[name].append(jax.device_get(self._all_gather_p(batch[name])))

    ns = np.array(ns)
    n = ns.sum()
    correct = np.array(corrects).sum()

    logging.info("Embedded %s image in %d steps - ...%s", dataset_name, len(ns),
                 ns[-10:])
    logging.info("Totalling %d image in %.1fs", n, time.monotonic() - t0)
    ret = {
        "accuracy": correct / n,
        "correct": correct,
        "count": n,
    }
    logging.info("Dataset %s, results %s", dataset_name, ret)

    if return_embeddings:
      data["embedding"] = embeddings
      data = {k: np.concatenate(v, axis=0) for k, v in data.items()}
      logging.info("Total images embeddings size %.1fM",
                   data["embedding"].nbytes / 1e6)
      # Drop padded elements before returning.
      mask = data.pop("mask").astype(bool)
      ret["images"] = {k: v[mask] for k, v in data.items()}
      ret["texts"] = texts

    return ret

  def run(self, train_state):
    """Returns metrics."""
    return [(f"{dataset_name}_accuracy",
             self.evaluate(train_state, dataset_name)["accuracy"])
            for dataset_name in self.datasets]
|
Tipsomaly/model/big_vision/evaluators/proj/image_text/discriminative_classifier_test.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Tests for discriminative_classifier."""
|
| 16 |
+
|
| 17 |
+
from unittest import mock
|
| 18 |
+
|
| 19 |
+
from big_vision.evaluators.proj.image_text import discriminative_classifier
|
| 20 |
+
from big_vision.pp import ops_general # pylint: disable=unused-import
|
| 21 |
+
from big_vision.pp import ops_image # pylint: disable=unused-import
|
| 22 |
+
from big_vision.pp.registry import Registry
|
| 23 |
+
import flax.linen as nn
|
| 24 |
+
import jax
|
| 25 |
+
import jax.numpy as jnp
|
| 26 |
+
import numpy as np
|
| 27 |
+
import tensorflow as tf
|
| 28 |
+
import tensorflow_datasets as tfds
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@Registry.register("preprocess_ops.test_texts2labels")
def _get_test_texts2labels():
  """Registers a test pp op that parses the "texts" string into "labels"."""

  def _parse(features):
    features["labels"] = tf.strings.to_number(features["texts"])
    return features

  return _parse
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@Registry.register("preprocess_ops.copy_from")
def _get_copy_from(**key_map):
  """Registers a test pp op copying feature `key_map` values: dst <- src."""

  def copy_from(d):
    out = dict(d)  # shallow copy; the input dict is left untouched
    for dst, src in key_map.items():
      out[dst] = out[src]
    return out

  return copy_from
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class _Model(nn.Module):
  """Tiny deterministic two-tower stand-in for a contrastive model.

  Maps scalar inputs onto the unit circle so that an image and a text derived
  from the same underlying value get identical (maximally similar) embeddings.
  """

  @nn.compact
  def __call__(self, image, texts):
    # Dummy parameter so the module has trainable state to initialize.
    self.param("x", lambda _: 0.)

    def z(x):
      # Returns None when x is None (tower not queried in this call).
      if x is not None:
        # Note that the returned vector is most similar with other vectors
        # generated from the same underlying `x[:]`.
        return jnp.stack([jnp.cos(x / 10.), jnp.sin(x / 10.)]).T

    if texts is not None:
      texts %= 5  # For testing `pre_filter_fn` below.
    return z(image), z(texts), None
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class DiscriminativeClassifierTest(tf.test.TestCase):
  """Tests for the discriminative zero-shot classification evaluator."""

  def test_prepare_datasets(self):
    """Checks image preprocessing and class-alias x prompt expansion order."""

    def generator():
      yield {
          "image": tf.ones([5, 5, 3], tf.float32),
          "label": 1,
      }
      yield {
          "image": tf.ones([4, 4, 3], tf.float32),
          "label": 2,
      }

    ds = tf.data.Dataset.from_generator(
        generator,
        output_signature={
            "image": tf.TensorSpec(shape=[None, None, 3], dtype=tf.float32),
            "label": tf.TensorSpec(shape=[], dtype=tf.int64),
        })
    class_names = [
        "class1,class1a",
        "class2",
    ]
    prompt_templates = [
        "test {}",
        "test {} test",
    ]
    ds_img, ds_txt = discriminative_classifier.prepare_datasets(
        ds,
        class_names,
        prompt_templates=prompt_templates,
        pp_img="resize(2)",
        pp_txt="copy_from(labels='texts')",
    )

    # Images keep their labels and are resized to 2x2.
    it_img = iter(ds_img)
    batch = next(it_img)
    self.assertAllEqual(1, batch["label"])
    self.assertAllEqual(tf.ones([2, 2, 3]), batch["image"])
    batch = next(it_img)
    self.assertAllEqual(2, batch["label"])
    self.assertAllEqual(tf.ones([2, 2, 3]), batch["image"])

    # Texts iterate aliases in order, each crossed with every prompt template.
    it_txt = iter(ds_txt)
    batch = next(it_txt)
    self.assertAllEqual(0, batch["label"])
    self.assertAllEqual("test class1", batch["labels"])
    batch = next(it_txt)
    self.assertAllEqual(0, batch["label"])
    self.assertAllEqual("test class1 test", batch["labels"])
    batch = next(it_txt)
    self.assertAllEqual(0, batch["label"])
    self.assertAllEqual("test class1a", batch["labels"])
    batch = next(it_txt)
    self.assertAllEqual(0, batch["label"])
    self.assertAllEqual("test class1a test", batch["labels"])
    batch = next(it_txt)
    self.assertAllEqual(1, batch["label"])
    self.assertAllEqual("test class2", batch["labels"])
    batch = next(it_txt)
    self.assertAllEqual(1, batch["label"])
    self.assertAllEqual("test class2 test", batch["labels"])

  def test_average_embeddings(self):
    """Checks per-class averaging and optional L2 normalization."""
    self.assertAllEqual(jnp.array([
        [2.], [4.], [8.],
    ]), discriminative_classifier._average_embeddings(
        embeddings=jnp.array([
            1., 3., 3., 1.,  # label1
            8., 0.,  # label2
            32., 0., 0., 0.,  # label3
        ])[..., None],
        labels=jnp.array([
            0, 0,  # label1
            0, 0,  # label1 (alias)
            1, 1,  # label2
            2, 2,  # label3
            2, 2,  # label3 (alias)
        ], jnp.int32),
        num_classes=3, normalize=False))
    self.assertAllEqual(
        jnp.array([
            [2**-.5, 2**-.5],
        ]),
        discriminative_classifier._average_embeddings(
            embeddings=jnp.array([[2., 2.]]),
            labels=jnp.array([0], jnp.int32),
            num_classes=1,
            normalize=True))

  @mock.patch("big_vision.evaluators.proj."
              "image_text.prompt_engineering.get_class_names")
  @mock.patch("big_vision.evaluators.proj."
              "image_text.prompt_engineering.get_prompt_templates")
  @mock.patch("big_vision.evaluators.proj."
              "image_text.discriminative_classifier._get_dataset_info")
  def test_evaluate(self, get_dataset_info_mock, get_prompt_templates_mock,
                    get_class_names_mock):
    """End-to-end run against mocked TFDS data and a perfect toy model."""
    per_device_batch_size = 10  # Make sure we have some unfiltered examples.
    global_batch_size = per_device_batch_size * jax.device_count()
    per_host_num_examples = int(
        np.ceil(global_batch_size / jax.process_count()))
    splits = {
        "test":
            tfds.core.SplitInfo(
                name="test", shard_lengths=[per_host_num_examples], num_bytes=0)
    }

    model = _Model()
    params = model.init(jax.random.PRNGKey(0), None, None)["params"]

    prompt_templates = [
        "test prompt 1 {}",
        "test prompt 2 {}",
    ]
    class_names = [
        f"test_class_{i}" for i in range(10)
    ]

    get_prompt_templates_mock.return_value = prompt_templates
    get_class_names_mock.return_value = class_names
    get_dataset_info_mock.return_value.splits = splits

    def pre_filter_fn(features):
      return features["label"] < 5  # matches `texts %= 5` above

    dataset_name = "cifar10_test"
    with tfds.testing.mock_data(num_examples=per_host_num_examples):
      evaluator = discriminative_classifier.Evaluator(
          lambda p, b: model.apply({"params": p},
                                   b.get("image", None),
                                   b.get("labels", None)),
          dataset_names=[dataset_name],
          prompt_templates="test_prompts",
          batch_size=global_batch_size,
          devices=jax.devices(),
          pp_img="copy_from(image='label')",
          pp_txt="copy_from(labels='label')",
          dataset_overrides={
              dataset_name: {
                  "dataset_name": "cifar10",
                  "class_names": "test_classes",
                  "pre_filter_fn": pre_filter_fn,
              }
          },
          first_class_name_only=True,
      )
      results = evaluator.evaluate(
          params,
          dataset_name,
          return_embeddings=True)
      metrics = dict(evaluator.run(params))

    # Assert all examples were processed.
    self.assertLen(results["texts"]["embedding"],
                   len(class_names) * len(prompt_templates))
    self.assertLen(results["texts"]["average_embedding"], len(class_names))
    self.assertAllEqual(
        sorted(results["texts"]["label"]),
        [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9])
    # Note that above model makes perfect predictions by design.
    self.assertEqual(1.0, results["accuracy"])
    self.assertEqual(1.0, metrics[f"{dataset_name}_accuracy"])
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
if __name__ == "__main__":
|
| 237 |
+
tf.test.main()
|
Tipsomaly/model/big_vision/evaluators/proj/image_text/image_text_retrieval.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Evaluates image-text retrieval results."""
|
| 16 |
+
from typing import List, Mapping
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
+
# Recall@k is reported for each of these cutoffs.
RECALL_THRESHOLDS = (1, 5, 10)


def text_to_image_retrieval_eval(
    dist_matrix: np.ndarray,
    text_image_correspondence: List[int]) -> Mapping[str, float]:
  """Runs the text-to-image retrieval eval from the distance matrix.

  Args:
    dist_matrix: Distance matrix between text and image embeddings (shape
      N_IMAGES x N_TEXTS).
    text_image_correspondence: Mapping between rows and columns of
      `dist_matrix`, that is, a list of N_TEXTS integers n_i that represent
      that the text embedding in column i corresponds to the image embedding
      in row n_i. Please note that many texts can be assigned to the same
      image. For instance, if we have 2 images and 4 texts (i.e. dist_matrix
      is 2x4), then `text_image_correspondence = [0, 0, 1, 1]` means that the
      two first texts correspond to the first image and the two last texts to
      the second image.

  Returns:
    A dictionary with the Recall@k scores for k in RECALL_THRESHOLDS.
  """
  # For every text (column), rank all images by increasing distance.
  per_text_ranks = dist_matrix.argsort(axis=0)
  text_image_correspondence = np.array(text_image_correspondence)

  def recall_at(k):
    # A text scores if its ground-truth image is among its k nearest images.
    wins = per_text_ranks[:k, :] == text_image_correspondence[None]
    return wins.any(axis=0).mean()

  return {f'Recall@{k}': recall_at(k) for k in RECALL_THRESHOLDS}


def image_to_text_retrieval_eval(
    dist_matrix: np.ndarray,
    text_image_correspondence: List[int]) -> Mapping[str, float]:
  """Runs the image-to-text retrieval eval from the distance matrix.

  Args:
    dist_matrix: Distance matrix between text and image embeddings (shape
      N_IMAGES x N_TEXTS).
    text_image_correspondence: Mapping between rows and columns of
      `dist_matrix`, that is, a list of N_TEXTS integers n_i that represent
      that the text embedding in column i corresponds to the image embedding
      in row n_i. Please note that many texts can be assigned to the same
      image. For instance, if we have 2 images and 4 texts (i.e. dist_matrix
      is 2x4), then `text_image_correspondence = [0, 0, 1, 1]` means that the
      two first texts correspond to the first image and the two last texts to
      the second image.

  Returns:
    A dictionary with the Recall@k scores for k in RECALL_THRESHOLDS.
  """
  # For every image (row), rank all texts by increasing distance.
  per_image_ranks = dist_matrix.argsort(axis=1)
  text_image_correspondence = np.array(text_image_correspondence)

  def recall_at(k):
    # Map each image's k nearest texts to the image those texts describe.
    top_k_images = text_image_correspondence[per_image_ranks[:, :k]]
    # An image scores if any of its k nearest texts belongs to it.
    wins = top_k_images == np.arange(len(per_image_ranks))[:, None]
    return wins.any(axis=1).mean()

  return {f'Recall@{k}': recall_at(k) for k in RECALL_THRESHOLDS}
|
Tipsomaly/model/big_vision/evaluators/proj/image_text/image_text_retrieval_test.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Unit tests for image_text_retrieval."""
|
| 16 |
+
from typing import Mapping
|
| 17 |
+
|
| 18 |
+
from absl.testing import absltest
|
| 19 |
+
from absl.testing import parameterized
|
| 20 |
+
from big_vision.evaluators.proj.image_text import image_text_retrieval
|
| 21 |
+
import numpy as np
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class ImTextRetrievalTest(parameterized.TestCase):
  """Tests both retrieval directions on small hand-crafted 4x8 matrices."""

  @parameterized.parameters(
      (np.array([[0.0, 0.0, 0.1, 0.5, 0.1, 0.2, 0.5, 0.1],
                 [0.5, 0.4, 0.0, 0.0, 0.4, 0.2, 0.6, 0.4],
                 [0.5, 0.4, 0.1, 0.5, 0.0, 0.0, 0.8, 0.3],
                 [0.5, 0.4, 0.1, 0.5, 0.3, 0.2, 0.0, 0.0]]), {
                     'Recall@1': 1.0,
                     'Recall@5': 1.0,
                     'Recall@10': 1.0
                 }),  #
      (np.array([[0.8, 0.8, 0.1, 0.5, 0.1, 0.2, 0.5, 0.1],
                 [0.5, 0.4, 0.0, 0.0, 0.4, 0.2, 0.6, 0.4],
                 [0.5, 0.4, 0.1, 0.5, 0.0, 0.8, 0.8, 0.3],
                 [0.5, 0.4, 0.1, 0.5, 0.4, 0.2, 0.3, 0.3]]), {
                     'Recall@1': 0.5,
                     'Recall@5': 0.75,
                     'Recall@10': 1.0
                 }))
  def test_image_to_text_retrieval_eval(self, dist_matrix: np.ndarray,
                                        expected: Mapping[str, float]):
    """Checks `image_to_text_retrieval_eval`.

    Args:
      dist_matrix: Distance matrix between image (rows) and text (columns).
      expected: Expected eval results.
    """
    self.assertEqual(
        image_text_retrieval.image_to_text_retrieval_eval(
            dist_matrix, [0, 0, 1, 1, 2, 2, 3, 3]), expected)

  @parameterized.parameters(
      (np.array([[0.0, 0.0, 0.1, 0.5, 0.1, 0.2, 0.5, 0.1],
                 [0.5, 0.4, 0.0, 0.0, 0.4, 0.2, 0.6, 0.4],
                 [0.5, 0.4, 0.1, 0.5, 0.0, 0.0, 0.8, 0.3],
                 [0.5, 0.4, 0.1, 0.5, 0.3, 0.2, 0.0, 0.0]]), {
                     'Recall@1': 1.0,
                     'Recall@5': 1.0,
                     'Recall@10': 1.0
                 }),  #
      (np.array([[0.8, 0.8, 0.1, 0.5, 0.1, 0.2, 0.1, 0.1],
                 [0.5, 0.4, 0.0, 0.0, 0.4, 0.2, 0.6, 0.4],
                 [0.5, 0.4, 0.1, 0.5, 0.0, 0.8, 0.8, 0.3],
                 [0.5, 0.4, 0.1, 0.5, 0.4, 0.2, 0.3, 0.3]]), {
                     'Recall@1': 0.375,
                     'Recall@5': 1.0,
                     'Recall@10': 1.0
                 }))
  def test_image_text_retrieval(self, dist_matrix: np.ndarray,
                                expected: Mapping[str, float]):
    """Checks `text_to_image_retrieval_eval`.

    Args:
      dist_matrix: Distance matrix between image (rows) and text (columns).
      expected: Expected eval results.
    """
    self.assertEqual(
        image_text_retrieval.text_to_image_retrieval_eval(
            dist_matrix, [0, 0, 1, 1, 2, 2, 3, 3]), expected)


if __name__ == '__main__':
  absltest.main()
|
Tipsomaly/model/big_vision/evaluators/proj/image_text/prompt_engineering.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Utilities for generating zero-shot prompts."""
|
| 16 |
+
|
| 17 |
+
import re
|
| 18 |
+
import string
|
| 19 |
+
from typing import Sequence
|
| 20 |
+
|
| 21 |
+
from absl import logging
|
| 22 |
+
from big_vision.datasets.imagenet import class_names as imagenet_class_names
|
| 23 |
+
from big_vision.evaluators.proj.image_text import prompt_engineering_constants
|
| 24 |
+
import tensorflow_datasets as tfds
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
_CLASS_NAMES = {  # For each dataset, maps from a source to its class names.
    "imagenet2012": {
        "clip": imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES,
    },
    "grand-vision:imagenet2012": {
        "clip": imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES,
    },
    # ImageNet-A/R use the CLIP names restricted to each variant's label set.
    "imagenet_a": {
        "clip": [
            imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES[i]
            for i in imagenet_class_names.IMAGENET_A_LABELSET
        ]
    },
    "imagenet_r": {
        "clip": [
            imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES[i]
            for i in imagenet_class_names.IMAGENET_R_LABELSET
        ]
    },
    "imagenet_v2": {
        "clip": imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES,
    },
}

# Named prompt-template sets selectable by `get_prompt_templates()`.
_PROMPT_TEMPLATES = {
    "class_name_only": ["{}"],
    "clip_paper": prompt_engineering_constants.CLIP_PAPER_PROMPT_TEMPLATES,
    "clip_best": prompt_engineering_constants.CLIP_BEST_PROMPT_TEMPLATES,
}
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def get_class_names(*, dataset_name, source="dataset_info", canonicalize=True):
  """Returns class names for `dataset_name` from `source`.

  Args:
    dataset_name: TFDS dataset name whose class names are requested.
    source: Either a string or a sequence of class-name strings. As a string
      it is "dataset_info:<feature>" (read names from the TFDS feature) or a
      key into the module-level `_CLASS_NAMES` table (e.g. "clip").
      NOTE(review): the default "dataset_info" (without ":<feature>") falls
      through to the `_CLASS_NAMES` lookup and would raise KeyError — callers
      appear to always pass an explicit source; confirm before relying on the
      default.
    canonicalize: Whether to canonicalize each name (lowercase, punctuation
      stripped except ",").

  Returns:
    A list of class-name strings.
  """
  if isinstance(source, str):
    if source.startswith("dataset_info:"):
      name = source[len("dataset_info:"):]
      class_names = tfds.builder(dataset_name).info.features[name].names
    else:
      class_names = _CLASS_NAMES[dataset_name][source]
  else:
    # An explicit sequence of class names was provided.
    assert isinstance(source, Sequence) and all(
        map(lambda s: isinstance(s, str), source)), source
    class_names = source
  if canonicalize:
    # "," is preserved because some class names list synonyms (e.g. CLIP's).
    class_names = [
        canonicalize_text(name, keep_punctuation_exact_string=",")
        for name in class_names
    ]
  logging.info("Using %d class_names: %s", len(class_names), class_names)
  return class_names
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def get_prompt_templates(prompt_templates_name,
                         *,
                         canonicalize=True):
  """Returns the named prompt-template set from `_PROMPT_TEMPLATES`.

  Args:
    prompt_templates_name: Key into the module-level `_PROMPT_TEMPLATES`
      table (e.g. "clip_paper", "clip_best", "class_name_only").
    canonicalize: Whether to canonicalize each template (lowercase,
      punctuation stripped except the "{}" placeholder).

  Returns:
    A list of prompt-template strings, each containing a "{}" placeholder.
  """
  prompts_templates = _PROMPT_TEMPLATES[prompt_templates_name]
  if canonicalize:
    # "{}" must survive canonicalization: it is the class-name placeholder.
    prompts_templates = [
        canonicalize_text(name, keep_punctuation_exact_string="{}")
        for name in prompts_templates
    ]
  logging.info("Using %d prompts_templates: %s", len(prompts_templates),
               prompts_templates)
  return prompts_templates
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def canonicalize_text(text, *, keep_punctuation_exact_string=None):
  """Returns canonicalized `text` (lowercase and punctuation removed).

  Underscores become spaces, punctuation is stripped, the result is
  lowercased, and runs of whitespace collapse to single spaces.

  Args:
    text: string to be canonicalized.
    keep_punctuation_exact_string: If provided, then this exact string is
      kept. For example providing '{}' will keep any occurrences of '{}'
      (but will still remove '{' and '}' that appear separately).

  Returns:
    The canonicalized string.
  """
  text = text.replace("_", " ")
  if keep_punctuation_exact_string:
    # Split on the protected string, strip punctuation from each part, then
    # re-join with the protected string so only isolated punctuation is lost.
    text = keep_punctuation_exact_string.join(
        part.translate(str.maketrans("", "", string.punctuation))
        for part in text.split(keep_punctuation_exact_string))
  else:
    text = text.translate(str.maketrans("", "", string.punctuation))
  text = text.lower()
  text = re.sub(r"\s+", " ", text)
  return text.strip()
|
Tipsomaly/model/big_vision/evaluators/proj/image_text/prompt_engineering_constants.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Constants used by the module `prompt_engineering` in the same directory."""
|
| 16 |
+
|
| 17 |
+
# The 80 prompt templates from the CLIP paper's ImageNet prompt-ensembling
# recipe, plus the bare '{}' placeholder.
CLIP_PAPER_PROMPT_TEMPLATES = [
    'a bad photo of a {}.',
    'a photo of many {}.',
    'a sculpture of a {}.',
    'a photo of the hard to see {}.',
    'a low resolution photo of the {}.',
    'a rendering of a {}.',
    'graffiti of a {}.',
    'a bad photo of the {}.',
    'a cropped photo of the {}.',
    'a tattoo of a {}.',
    'the embroidered {}.',
    'a photo of a hard to see {}.',
    'a bright photo of a {}.',
    'a photo of a clean {}.',
    'a photo of a dirty {}.',
    'a dark photo of the {}.',
    'a drawing of a {}.',
    'a photo of my {}.',
    'the plastic {}.',
    'a photo of the cool {}.',
    'a close-up photo of a {}.',
    'a black and white photo of the {}.',
    'a painting of the {}.',
    'a painting of a {}.',
    'a pixelated photo of the {}.',
    'a sculpture of the {}.',
    'a bright photo of the {}.',
    'a cropped photo of a {}.',
    'a plastic {}.',
    'a photo of the dirty {}.',
    'a jpeg corrupted photo of a {}.',
    'a blurry photo of the {}.',
    'a photo of the {}.',
    'a good photo of the {}.',
    'a rendering of the {}.',
    'a {} in a video game.',
    'a photo of one {}.',
    'a doodle of a {}.',
    'a close-up photo of the {}.',
    'a photo of a {}.',
    'the origami {}.',
    'the {} in a video game.',
    'a sketch of a {}.',
    'a doodle of the {}.',
    'a origami {}.',
    'a low resolution photo of a {}.',
    'the toy {}.',
    'a rendition of the {}.',
    'a photo of the clean {}.',
    'a photo of a large {}.',
    'a rendition of a {}.',
    'a photo of a nice {}.',
    'a photo of a weird {}.',
    'a blurry photo of a {}.',
    'a cartoon {}.',
    'art of a {}.',
    'a sketch of the {}.',
    'a embroidered {}.',
    'a pixelated photo of a {}.',
    'itap of the {}.',
    'a jpeg corrupted photo of the {}.',
    'a good photo of a {}.',
    'a plushie {}.',
    'a photo of the nice {}.',
    'a photo of the small {}.',
    'a photo of the weird {}.',
    'the cartoon {}.',
    'art of the {}.',
    'a drawing of the {}.',
    'a photo of the large {}.',
    'a black and white photo of a {}.',
    'the plushie {}.',
    'a dark photo of a {}.',
    'itap of a {}.',
    'graffiti of the {}.',
    'a toy {}.',
    'itap of my {}.',
    'a photo of a cool {}.',
    'a photo of a small {}.',
    'a tattoo of the {}.',
    '{}',
]

# The reduced 7-template subset reported by CLIP as performing best, plus the
# bare '{}' placeholder.
CLIP_BEST_PROMPT_TEMPLATES = [
    'itap of a {}.',
    'a bad photo of the {}.',
    'a origami {}.',
    'a photo of the large {}.',
    'a {} in a video game.',
    'art of the {}.',
    'a photo of the small {}.',
    '{}',
]
|
Tipsomaly/model/big_vision/evaluators/proj/image_text/prompt_engineering_test.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Tests for prompt_engineering."""
|
| 16 |
+
|
| 17 |
+
from absl.testing import absltest
|
| 18 |
+
from big_vision.evaluators.proj.image_text import prompt_engineering
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class PromptEngineeringTest(absltest.TestCase):
  """Unit tests for `prompt_engineering.canonicalize_text`."""

  def test_canonicalize_text(self):
    self.assertEqual(
        prompt_engineering.canonicalize_text("test_test"), "test test")
    self.assertEqual(
        prompt_engineering.canonicalize_text("test___test"), "test test")
    self.assertEqual(prompt_engineering.canonicalize_text("test"), "test")
    self.assertEqual(prompt_engineering.canonicalize_text("test."), "test")
    self.assertEqual(prompt_engineering.canonicalize_text(" test "), "test")
    self.assertEqual(
        prompt_engineering.canonicalize_text("test\ntest"), "test test")
    self.assertEqual(
        prompt_engineering.canonicalize_text("test test"), "test test")
    self.assertEqual(prompt_engineering.canonicalize_text("test {}"), "test")
    self.assertEqual(
        prompt_engineering.canonicalize_text(
            "test {}", keep_punctuation_exact_string="{}"), "test {}")
    self.assertEqual(
        prompt_engineering.canonicalize_text(
            " test {}...", keep_punctuation_exact_string="{}"), "test {}")
    self.assertEqual(
        prompt_engineering.canonicalize_text(
            "test {} {} {}", keep_punctuation_exact_string="{}"),
        "test {} {} {}")


if __name__ == "__main__":
  absltest.main()
|
Tipsomaly/model/big_vision/evaluators/proj/image_text/retrieval.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Multi-host image->text and text->image retrieval evaluation.
|
| 16 |
+
|
| 17 |
+
Example how to add to config:
|
| 18 |
+
|
| 19 |
+
config.evals {}
|
| 20 |
+
config.evals.retrieval = dict(log_steps=1200, type='proj.image_text.retrieval')
|
| 21 |
+
config.evals.retrieval.dataset = 'coco_captions'
|
| 22 |
+
config.evals.retrieval.txt_name = ('captions', 'text')
|
| 23 |
+
# Note that initial "decode|" is not needed.
|
| 24 |
+
config.evals.retrieval.pp_img = 'resize(224)|value_range(-1,1)'
|
| 25 |
+
# Raw text strings use key "texts" in feature dict. The evaluator expects
|
| 26 |
+
# tokenized text with key "labels".
|
| 27 |
+
config.evals.retrieval.pp_txt = (
|
| 28 |
+
'tokenize(max_len=16, eos="sticky", pad_value=1, inkey="texts", '
|
| 29 |
+
' outkey="labels")')
|
| 30 |
+
|
| 31 |
+
Example to support precomputed data:
|
| 32 |
+
See `big_vision/configs/proj/image_text/lit.py`.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
import functools
|
| 36 |
+
import operator
|
| 37 |
+
import time
|
| 38 |
+
|
| 39 |
+
from absl import logging
|
| 40 |
+
from big_vision import input_pipeline
|
| 41 |
+
from big_vision.evaluators.proj.image_text import image_text_retrieval
|
| 42 |
+
import big_vision.pp.builder as pp_builder
|
| 43 |
+
import jax
|
| 44 |
+
import jax.numpy as jnp
|
| 45 |
+
from jax.sharding import NamedSharding
|
| 46 |
+
from jax.sharding import PartitionSpec as P
|
| 47 |
+
import numpy as np
|
| 48 |
+
import tensorflow as tf
|
| 49 |
+
import tensorflow_datasets as tfds
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# Temporary global flag to facilitate backwards compatibility. Will be removed
|
| 53 |
+
# by the end of year 2023.
|
| 54 |
+
API = "jit"
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _with_infinite_padding(dataset):
  """Adds "infinite padding" to the dataset.

  Every real example gets `mask=True`; the dataset is then extended with an
  endless stream of all-zero filler elements carrying `mask=False`, so that
  every host can keep stepping in lockstep past the end of its shard.

  Args:
    dataset: A `tf.data.Dataset` of feature dicts.

  Returns:
    The masked dataset concatenated with infinitely repeated filler elements.
  """
  # One zero-valued element with the same structure/dtypes as real examples.
  filler_element = tf.nest.map_structure(
      lambda spec: tf.zeros(spec.shape, spec.dtype)[None], dataset.element_spec)
  filler_element["mask"] = [False]
  filler_dataset = tf.data.Dataset.from_tensor_slices(filler_element)
  dataset = dataset.map(
      lambda features: dict(mask=True, **features),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return dataset.concatenate(filler_dataset.repeat(None))
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# This is needed so retrieval_test can replace dataset info.
|
| 70 |
+
def _get_dataset_info(builder):
|
| 71 |
+
return builder.info
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def prepare_datasets(
    dataset, *, pp_img, pp_txt, txt_name, offset=0, cache_final=False
):
  """Returns unbatched `ds_images, ds_texts` datasets.

  Args:
    dataset: An image-text `tf.data.Dataset` that is expected to contain the
      following features: "image" (dtype=uint8, shape=[None, None, 3]),
      `txt_name` (dtype=string, shape=[None]).
    pp_img: String defining pre-processing for images. The pre-processing can
      expect the following features to be prepared: "image", "id". The
      pre-processing should convert the "image" (dtype=uint8,
      shape=[None, None, 3]) to "image" (dtype=float32, shape=[sz, sz, 3]).
    pp_txt: String defining pre-processing for text. The pre-processing can
      expect the following features to be prepared: "texts", "id",
      "caption_id". The pre-processing should convert the "texts"
      (dtype=string, shape=[]) into a tokenized "labels" (dtype=int32,
      shape=[max_len]).
    txt_name: Name of the text feature to unroll in the original `dataset`.
      Can be a simple string feature name, or an iterable of strings to
      specify a nested feature (e.g. for "coco_captions", this would be
      `('captions', 'text')`).
    offset: Offset that should be added to enumerated examples to generate
      IDs. In a multi-host setup, this is typically set to a value large
      enough to make all IDs distinct.
    cache_final: Whether the dataset should be cached.

  Returns:
    Image and text datasets.
  """

  def get_feature_value(data, feature_name):
    # Walk a (possibly nested) feature path like ('captions', 'text').
    if isinstance(feature_name, str):
      feature_name = [feature_name]
    return functools.reduce(operator.getitem, feature_name, data)

  def get_captions(idx, features):
    """Returns a dataset with unrolled "caption" for every example."""
    texts = get_feature_value(features, txt_name)
    texts = tf.experimental.numpy.atleast_1d(texts)  # For single-text GT.
    texts_n = tf.shape(texts)[0]
    # One element per caption; "id" ties each caption back to its image.
    return tf.data.Dataset.from_tensor_slices({
        "id": tf.tile([idx + offset], [texts_n]),
        "caption_i": tf.stack(tf.range(texts_n)),
        "texts": tf.stack(texts),
    })

  def add_id(idx, features):
    return {**features, "id": idx + offset}

  ds_images = dataset.enumerate().map(add_id).map(
      pp_builder.get_preprocess_fn(f"{pp_img}|keep('id', 'image')"))
  ds_texts = dataset.enumerate().flat_map(get_captions).map(
      pp_builder.get_preprocess_fn(
          f"{pp_txt}|keep('id', 'caption_i', 'labels')"))
  if cache_final:
    ds_images, ds_texts = ds_images.cache(), ds_texts.cache()
  return ds_images, ds_texts
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def _split_and_batch(dataset_name, batch_size, split, get_ds, data_dir=None):
  """Splits dataset, calls `get_ds` and returns padded + batched datasets."""
  num_devices = jax.device_count()
  assert not batch_size % num_devices, (
      f"batch_size={batch_size} % jax.device_count()={num_devices}")
  builder = tfds.builder(dataset_name, data_dir=data_dir)
  split_info = _get_dataset_info(builder).splits[split]
  # Each process reads its own shard; the per-process offset keeps example
  # IDs globally unique across hosts.
  per_process_ds = builder.as_dataset(split=tfds.split_for_jax_process(split))
  ds_images, ds_texts = get_ds(
      per_process_ds, offset=jax.process_index() * split_info.num_examples)

  def pad_and_batch(ds):
    # Infinite padding lets every host iterate the same number of steps.
    return _with_infinite_padding(ds).batch(batch_size)

  return pad_and_batch(ds_images), pad_and_batch(ds_texts)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class Evaluator:
  """Image/text retrieval evaluator."""

  def __init__(self,
               predict_fn,
               *,
               dataset,
               pp_img,
               pp_txt,
               txt_name,
               batch_size,
               devices,
               data_dir=None,
               split="test",
               cache_final=True):
    """Initializes a new zero-shot image/text retrieval evaluator.

    See `prepare_datasets()` for details on how the dataset is pre-processed.

    Args:
      predict_fn: Prediction function with signature
        `zimg, ztxt, out = predict_fn(params, images, texts)`
      dataset: The TFDS dataset name of the eval data.
      pp_img: Preprocessing string for images. Preprocessed features should
        contain key "image" with value that can be batched and is suitable for
        `predict_fn(images)` input.
      pp_txt: Preprocessing string for texts. Can expect "texts" key as an
        input (shape=[], dtype=string), and is expected to produce "labels"
        key that is suitable for `predict_fn(texts)` input.
      txt_name: The name of the feature of captions (can be a tuple, in which
        case the items are used as a lookup path in a nested feature
        dictionary). Expected shape=[None], dtype=string.
      batch_size: Global batch size.
      devices: List of devices to run inference on.
      data_dir: Optional dir to load the TFDS dataset from.
      split: The split of the eval data.
      cache_final: Whether the preprocessed dataset should be cached.
    """
    self.ds_images, self.ds_texts = _split_and_batch(
        dataset,
        batch_size,
        split,
        functools.partial(
            prepare_datasets,
            pp_img=pp_img,
            pp_txt=pp_txt,
            txt_name=txt_name,
            cache_final=cache_final,
        ),
        data_dir=data_dir,
    )
    self._axis_name = "batch"

    self.devices = devices
    mesh = jax.sharding.Mesh(devices, ("devices",))

    # Wrappers that pull out just the image / text tower of `predict_fn`.
    def embed_images(train_state, images):
      zimg, _, _ = predict_fn(train_state, {"image": images})
      return zimg

    def embed_texts(train_state, texts):
      _, ztxt, _ = predict_fn(train_state, {"labels": texts})
      return ztxt

    # All jitted functions below return fully-replicated (P()) outputs so the
    # host can `device_get` complete arrays.
    self._embed_images_p = jax.jit(embed_images,
                                   out_shardings=NamedSharding(mesh, P()))
    self._embed_texts_p = jax.jit(embed_texts,
                                  out_shardings=NamedSharding(mesh, P()))
    # Identity with replicated output: effectively an all-gather of a
    # device-sharded batch feature.
    self._all_gather_p = jax.jit(
        lambda x: x, out_shardings=NamedSharding(mesh, P()))
    self._count_p = jax.jit(jnp.sum, out_shardings=NamedSharding(mesh, P()))
    # Tracks which embed functions already ran once (i.e. were compiled),
    # only used for log timing messages.
    self._compiled = set()

  def _embed(self, name, train_state, ds, embed_fn, id_names):
    """Embeds features name `name` using `embed_fn`.

    Args:
      name: Feature name to be embedded.
      train_state: train_state for the predict_fn.
      embed_fn: A jitted function that returns the embeddings.
      ds: The dataset.
      id_names: An iterable of feature names that should be collected.

    Returns:
      A dictionary with "embeddings" and `id_names` as keys, each filtered
      down to real (non-padding) examples.
    """
    ns = []  # Per-step count of real (mask==1) examples.
    embeddings = []
    ids = {id_name: [] for id_name in list(id_names) + ["mask"]}

    t0 = time.time()

    ds_b = input_pipeline.start_global(ds, self.devices)
    for batch in ds_b:
      ns.append(jax.device_get(self._count_p(batch["mask"])))

      # Due to infinite padding, this loop will never end. We will stop once
      # all processes only process padded data. We don't check the latest
      # DeviceArray `ns[-1]` because we want to keep our computation async for
      # efficiency reasons.
      if len(ns) >= 2 and ns[-2] == 0:
        break

      embs = embed_fn(train_state, batch[name])
      if embed_fn not in self._compiled:
        logging.info("Compiled %s embeddings in %.3fs", name, time.time() - t0)
        t0 = time.time()
        self._compiled.add(embed_fn)

      embeddings.append(jax.device_get(embs))
      for id_name in ids:
        ids[id_name].append(jax.device_get(self._all_gather_p(batch[id_name])))

    # Only access DeviceArray at end of loop for better efficiency.
    ns = np.array(ns)
    embeddings = np.concatenate(embeddings)
    ids = {k: np.concatenate(v) for k, v in ids.items()}
    # Boolean mask selecting real examples; drops the padding tail.
    masks = ids.pop("mask").astype(bool)
    logging.info("Processed %s in %d steps - ...%s", name, len(ns), ns[-10:])
    n = ns.sum()
    logging.info("Totalling %d %s in %.3fs", n, name, time.time() - t0)
    return {
        "embeddings": embeddings[masks],
        **{k: v[masks] for k, v in ids.items()},
    }

  def evaluate(self, train_state):
    """Returns evaluation results (embeddings + retrieval metrics)."""
    images = self._embed("image", train_state, self.ds_images,
                         self._embed_images_p, ("id",))
    texts = self._embed("labels", train_state, self.ds_texts,
                        self._embed_texts_p, ("id", "caption_i"))
    # Shapes: (nimg, emb) * (emb, ntxt) -> (nimg, ntxt)
    similarities = np.dot(images["embeddings"], texts["embeddings"].T)

    t0 = time.time()
    # Map every text back to the row index of its source image.
    id2img = {id_: i for i, id_ in enumerate(images["id"])}
    text_image_correspondence = [id2img[id_] for id_ in texts["id"]]
    # The retrieval helpers expect distances, hence the negated similarities.
    img2txt = image_text_retrieval.image_to_text_retrieval_eval(
        -similarities, text_image_correspondence)
    txt2img = image_text_retrieval.text_to_image_retrieval_eval(
        -similarities, text_image_correspondence)
    logging.info("Computed retrieval metrics in %.3fs", time.time() - t0)

    return dict(
        images=images,
        texts=texts,
        img2txt=img2txt,
        txt2img=txt2img,
    )

  def run(self, train_state):
    """Returns metrics as (name, value) pairs."""
    results = self.evaluate(train_state)
    return [(f"{direction}_{k.lower()}", v)
            for direction in ("img2txt", "txt2img")
            for k, v in results[direction].items()]
|
Tipsomaly/model/big_vision/evaluators/proj/image_text/retrieval_test.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Tests for retrieval."""
|
| 16 |
+
|
| 17 |
+
from unittest import mock
|
| 18 |
+
|
| 19 |
+
from big_vision.evaluators.proj.image_text import retrieval
|
| 20 |
+
from big_vision.pp import ops_general # pylint: disable=unused-import
|
| 21 |
+
from big_vision.pp import ops_image # pylint: disable=unused-import
|
| 22 |
+
from big_vision.pp import registry
|
| 23 |
+
import chex
|
| 24 |
+
import flax.linen as nn
|
| 25 |
+
import jax
|
| 26 |
+
import jax.numpy as jnp
|
| 27 |
+
import tensorflow as tf
|
| 28 |
+
import tensorflow_datasets as tfds
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _get_test_texts2labels():
  """Returns a pp op that parses the string "texts" feature into "labels"."""

  def _texts_to_labels(features):
    # e.g. "11" -> 11.0; lets the test use the caption text as a numeric id.
    features["labels"] = tf.strings.to_number(features["texts"])
    return features

  return _texts_to_labels
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _get_copy_from(**key_map):
|
| 41 |
+
|
| 42 |
+
def copy_from(d):
|
| 43 |
+
d = dict(d)
|
| 44 |
+
for k1, k2 in key_map.items():
|
| 45 |
+
d[k1] = d[k2]
|
| 46 |
+
return d
|
| 47 |
+
|
| 48 |
+
return copy_from
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class _Model(nn.Module):
  """Tiny two-tower stand-in model for retrieval tests.

  Embeds a scalar per example such that two embeddings are most similar when
  they come from the same underlying scalar value.
  """

  @nn.compact
  def __call__(self, image, texts):
    # Dummy parameter so the module has something to initialize.
    self.param("x", lambda _: 0.)

    def z(x):
      # Returns None for a None tower input (e.g. image-only batches).
      if x is not None:
        batch_size = len(x)
        # Note that the returned vector is most similar with other vectors
        # generated from the same underlying `x[:]`.
        x = jnp.concatenate([100 * jnp.ones([batch_size, 1]), x[:, None]],
                            axis=1)
        return x / jnp.linalg.norm(x, axis=1)[:, None]

    return z(image), z(texts), None
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def setUpModule():
  # Emulate a multi-device setup (8 virtual CPU devices) so the jit/sharding
  # code paths in the evaluator are exercised.
  chex.set_n_cpu_devices(8)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class RetrievalTest(tf.test.TestCase):
  """Tests for `retrieval.prepare_datasets` and `retrieval.Evaluator`."""

  def test_prepare_datasets(self):
    """Checks unrolling of multi-caption examples into (id, caption_i) rows."""

    def generator():
      # Two examples with 2 and 3 captions respectively.
      yield {
          "image": tf.ones([5, 5, 3], tf.float32),
          "captions": {
              "text": tf.constant(["11", "12"])
          }
      }
      yield {
          "image": tf.ones([4, 4, 3], tf.float32),
          "captions": {
              "text": tf.constant(["21", "22", "23"])
          }
      }

    ds = tf.data.Dataset.from_generator(
        generator,
        output_signature={
            "image": tf.TensorSpec(shape=[None, None, 3], dtype=tf.float32),
            "captions": {
                "text": tf.TensorSpec(shape=[None], dtype=tf.string),
            },
        })
    with registry.temporary_ops(test_texts2labels=_get_test_texts2labels):
      ds_img, ds_txt = retrieval.prepare_datasets(
          ds,
          pp_img="resize(2)",
          pp_txt="test_texts2labels()",
          txt_name=("captions", "text"),
      )
    it_img = iter(ds_img)
    it_txt = iter(ds_txt)
    # Image dataset: one row per example, resized to 2x2.
    batch = next(it_img)
    self.assertAllEqual(batch["id"], 0)
    self.assertAllEqual(batch["image"], tf.ones([2, 2, 3]))
    batch = next(it_img)
    self.assertAllEqual(batch["id"], 1)
    self.assertAllEqual(batch["image"], tf.ones([2, 2, 3]))
    # Text dataset: one row per caption, with (id, caption_i) bookkeeping.
    batch = next(it_txt)
    self.assertAllEqual(batch["id"], 0)
    self.assertAllEqual(batch["caption_i"], 0)
    self.assertAllEqual(batch["labels"], 11.0)
    batch = next(it_txt)
    self.assertAllEqual(batch["id"], 0)
    self.assertAllEqual(batch["caption_i"], 1)
    self.assertAllEqual(batch["labels"], 12.0)
    batch = next(it_txt)
    self.assertAllEqual(batch["id"], 1)
    self.assertAllEqual(batch["caption_i"], 0)
    self.assertAllEqual(batch["labels"], 21.0)
    batch = next(it_txt)
    self.assertAllEqual(batch["id"], 1)
    self.assertAllEqual(batch["caption_i"], 1)
    self.assertAllEqual(batch["labels"], 22.0)
    batch = next(it_txt)
    self.assertAllEqual(batch["id"], 1)
    self.assertAllEqual(batch["caption_i"], 2)
    self.assertAllEqual(batch["labels"], 23.0)

  def test_evaluate(self):
    """End-to-end run on mocked TFDS data, including last-batch padding."""
    per_device_batch_size = 2
    batch_size = per_device_batch_size * jax.device_count()
    # +1 forces a padded final batch, exercising the padding/mask logic.
    num_examples = 1 * batch_size + 1
    splits = {
        "test":
            tfds.core.SplitInfo(
                name="test", shard_lengths=[num_examples], num_bytes=0)
    }

    model = _Model()
    params = model.init(jax.random.PRNGKey(0), None, None)["params"]

    with tfds.testing.mock_data(num_examples=num_examples):
      info_mock = mock.Mock()
      info_mock.splits = splits
      with mock.patch.object(retrieval, "_get_dataset_info",
                             lambda _: info_mock):
        with registry.temporary_ops(copy_from=_get_copy_from):
          # `copy_from` makes both towers embed the example id, so matching
          # image/text pairs are guaranteed to be most similar.
          evaluator = retrieval.Evaluator(
              lambda p, b: model.apply({"params": p},
                                       b.get("image", None),
                                       b.get("labels", None)),
              dataset="coco_captions",
              batch_size=batch_size,
              devices=jax.devices(),
              txt_name=("captions", "text"),
              pp_img="copy_from(image='id')",
              pp_txt="copy_from(labels='id')",
          )
          results = evaluator.evaluate(params)

    # Assert all examples were processed.
    self.assertLen(results["images"]["embeddings"], num_examples)
    self.assertLen(results["images"]["id"], num_examples)
    # Assert no padding was processed (expects exactly one (=first) image.id=0
    self.assertEqual((results["images"]["id"] == 0).sum(), 1)
    # Expect perfect ITR with above _Model()...
    self.assertEqual(results["img2txt"]["Recall@1"], 1.0)
    self.assertEqual(results["txt2img"]["Recall@5"], 1.0)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
# Run the test suite when executed directly.
if __name__ == "__main__":
  tf.test.main()
|
Tipsomaly/model/big_vision/evaluators/proj/paligemma/perplexity.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Evaluator for perplexity of a model."""
|
| 16 |
+
import functools
|
| 17 |
+
|
| 18 |
+
from big_vision.evaluators import mean
|
| 19 |
+
import big_vision.utils as u
|
| 20 |
+
import jax.numpy as jnp
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Temporary global flag to facilitate backwards compatability. Will be removed
|
| 24 |
+
# by the end of year 2023.
|
| 25 |
+
API = 'jit'
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Cache the function such that it won't always recompile (in mean evaluator).
|
| 29 |
+
# Cache the function such that it won't always recompile (in mean evaluator).
@functools.cache
def perplexity(
    predict_fn, key='labels', shift_labels=True, pad_token=None):
  """Returns a function computing per-example xent sums and averages."""

  def _perplexity_fn(train_state, batch, **kw):
    # Forward pass; only the logits are used here.
    logits, _ = predict_fn(train_state, batch, **kw)

    labels = batch[key]
    weights = batch.get('mask_loss', jnp.ones_like(labels))
    if pad_token is not None:
      # Zero out padding positions so they contribute nothing to the loss.
      weights = weights * (labels != pad_token).astype(jnp.float32)
    if shift_labels:
      # Next-token prediction: drop the first label/weight column.
      # NOTE(review): logits are not sliced here; assumes `predict_fn`
      # returns logits already aligned with the shifted labels — confirm.
      labels, weights = labels[:, 1:], weights[:, 1:]

    losses = u.weighted_softmax_xent(
        logits=logits, labels=labels, weights=weights,
        reduction=False, normalize=False)
    # Tiny floor avoids division by zero for fully-masked rows.
    denom = jnp.clip(weights.sum(axis=1), 2e-38)
    return {'sum': losses, 'avg': losses / denom}

  return _perplexity_fn
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class Evaluator(mean.Evaluator):
  """Perplexity evaluator: averages the per-example xent metrics."""

  def __init__(self, predict_fn, *a,
               key='labels', shift_labels=False, pad_token=None, **kw):
    # Default to no prefetching — a more memory-saving choice for eval.
    if 'prefetch' not in kw:
      kw['prefetch'] = 0
    metric_fn = perplexity(predict_fn, key, shift_labels, pad_token)
    super().__init__(metric_fn, *a, **kw)
|
Tipsomaly/model/big_vision/evaluators/proj/paligemma/transfers/chartqa.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Evaluator for ChartQA variants."""
|
| 16 |
+
|
| 17 |
+
import functools
|
| 18 |
+
|
| 19 |
+
import big_vision.evaluators.common as c
|
| 20 |
+
import big_vision.pp.tokenizer
|
| 21 |
+
import big_vision.utils as u
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Temporary global flag to facilitate backwards compatability. Will be removed
|
| 25 |
+
# by the end of year 2023.
|
| 26 |
+
API = "jit"
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class Evaluator:
  """Evaluator for simple VQA tasks (exact-match + ChartQA relaxed accuracy).

  Decodes a free-form answer per example, compares it to the single ground
  truth in `batch["answer"]`, and writes all predictions to a json file.
  """

  def __init__(
      self, predict_fn, tokenizer, to_lower=False,
      outfile="{workdir}/{split}.json",
      out_question_key="question_id", out_answer_key="answer",
      *, data, devices, **kw):
    # `keep_on_cpu` holds the string-typed ground truth/metadata on the host
    # instead of shipping it to the accelerators.
    self.get_data_iter, self.steps = c.eval_input_pipeline(
        keep_on_cpu={"answer", "question_id"}, data=data, devices=devices, **kw)

    self.outfile = c.resolve_outfile(outfile, split=data.get("split"))
    self.out_question_key = out_question_key
    self.out_answer_key = out_answer_key

    # We'll need the tokenizer to detokenize the model outputs later.
    self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer)
    # Applied to both prediction and ground truth before comparison.
    self.postproc = (lambda s: s.lower()) if to_lower else lambda s: s
    self.decode = functools.partial(
        predict_fn, devices=devices, eos_token=self.tok.eos_token)

  def run(self, train_state):
    """Does one evaluation run, yields metrics."""

    accuracies = []
    relaxed_accuracies = []
    json_out = []
    for _, batch in zip(range(self.steps), self.get_data_iter()):
      # (batch, seqlen) array of decoded generated tokens.
      tokens = self.decode(train_state, batch)

      # (local_batch,) that indicates padding examples (0) vs real examples (1).
      tokens = u.get_local_slice_from_fsarray(tokens)
      ex_masks = u.get_local_slice_from_fsarray(batch["_mask"])

      # Turn predictions into texts and then scores, one by one.
      for i in range(len(tokens)):
        if ex_masks[i] == 0:  # Skip last-batch padding examples
          continue

        answer = self.postproc(self.tok.to_str(tokens[i], stop_at_eos=True))

        gt = self.postproc(batch["answer"][i])
        accuracies.append(float(answer == gt))
        relaxed_accuracies.append(_relaxed_match(gt, answer))
        json_out.append({
            self.out_question_key: batch["question_id"][i].item(),
            self.out_answer_key: answer,
            "gt": gt,
            "relaxed_match": relaxed_accuracies[-1],
        })

    # At this point `accuracies` is a list of per-example scores. However,
    # remember that each host holds a different subset of the examples! So if
    # we were to just return the mean accuracy here, we would effectively only
    # have evaluated on the main host's (who writes metrics) subset!
    # So now, we need to compute global means.
    # There is one more caveat: `process_sum` needs the summands on each host
    # to have the same size. So we either need to include dummy values for
    # the padding examples (last batch, annoying), or we only sum scalars as in
    # sufficient statistics, which we do here.
    sum_accs, sum_relaxed_accs, num = c.process_sum(
        [sum(accuracies), sum(relaxed_accuracies), len(accuracies)])

    # Yielding metric_name, value means logging the metric.
    yield "acc", sum_accs / num
    yield "relaxed_acc", sum_relaxed_accs / num
    yield "num", num  # Just for sanity checks.
    c.multiprocess_write_json(self.outfile, json_out)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _to_float(text: str) -> float | None:
|
| 101 |
+
try:
|
| 102 |
+
if text.endswith("%"):
|
| 103 |
+
# Convert percentages to floats.
|
| 104 |
+
return float(text.rstrip("%")) / 100.0
|
| 105 |
+
else:
|
| 106 |
+
return float(text)
|
| 107 |
+
except ValueError:
|
| 108 |
+
return None
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _relaxed_match(target: str,
|
| 112 |
+
prediction: str,
|
| 113 |
+
max_relative_error: float = 0.05) -> bool:
|
| 114 |
+
"""Calculates relaxed correctness.
|
| 115 |
+
|
| 116 |
+
The correctness tolerates certain error ratio defined by max_relative_error.
|
| 117 |
+
See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1:
|
| 118 |
+
“Following Methani et al. (2020), we use a relaxed accuracy measure for the
|
| 119 |
+
numeric answers to allow a minor inaccuracy that may result from the automatic
|
| 120 |
+
data extraction process. We consider an answer to be correct if it is within
|
| 121 |
+
5% of the gold answer. For non-numeric answers, we still need an exact match
|
| 122 |
+
to consider an answer to be correct.”
|
| 123 |
+
|
| 124 |
+
Args:
|
| 125 |
+
target: Target string.
|
| 126 |
+
prediction: Predicted string.
|
| 127 |
+
max_relative_error: Maximum relative error.
|
| 128 |
+
|
| 129 |
+
Returns:
|
| 130 |
+
Whether the prediction was correct given the specified tolerance.
|
| 131 |
+
"""
|
| 132 |
+
prediction_float = _to_float(prediction)
|
| 133 |
+
target_float = _to_float(target)
|
| 134 |
+
# When the target is 0 is always required an exact match.
|
| 135 |
+
if prediction_float is not None and target_float:
|
| 136 |
+
relative_error = abs(prediction_float - target_float) / abs(target_float)
|
| 137 |
+
return relative_error <= max_relative_error
|
| 138 |
+
else:
|
| 139 |
+
return prediction == target
|
Tipsomaly/model/big_vision/evaluators/proj/paligemma/transfers/pope.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Evaluator for the POPE dataset (https://github.com/RUCAIBox/POPE).
|
| 16 |
+
|
| 17 |
+
POPE is a binary classification dataset with ground-truth answers being either
|
| 18 |
+
'yes' or 'no'.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import functools
|
| 22 |
+
|
| 23 |
+
import big_vision.datasets.core
|
| 24 |
+
import big_vision.evaluators.common as c
|
| 25 |
+
import big_vision.input_pipeline
|
| 26 |
+
import big_vision.pp.builder
|
| 27 |
+
import big_vision.pp.tokenizer
|
| 28 |
+
import big_vision.utils as u
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Temporary global flag to facilitate backwards compatability. Will be removed
|
| 32 |
+
# by the end of year 2023.
|
| 33 |
+
API = "jit"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class Evaluator:
  """Evaluator for the POPE task.

  This evaluator expects the batch to contain a field `question_id` and a field
  `answer` for single ground truth or `answers` for multiple ground truths.

  The field names used when writing the json result can be controlled with
  `out_question_key` and `out_answer_key`.
  """

  def __init__(
      self,
      predict_fn,
      data,
      pp_fn,
      tokenizer,
      batch_size,
      *,
      devices,
      outfile="{workdir}/{split}.json",
      out_question_key="question_id",
      out_answer_key="answer"
  ):

    self.outfile = c.resolve_outfile(outfile, split=data.get("split"))
    self.out_question_key = out_question_key
    self.out_answer_key = out_answer_key
    # This will mostly look the same across all evaluators, preparing data:
    data = big_vision.datasets.core.get(**data)
    pp_fn = big_vision.pp.builder.get_preprocess_fn(pp_fn)
    self.ds, self.steps = big_vision.input_pipeline.make_for_inference(
        data.get_tfdata(ordered=True),
        pp_fn,
        batch_size,
        num_ex_per_process=data.num_examples_per_process(),
    )
    # The `keep_on_cpu=` argument lists the data keys that, if they exist, we
    # do NOT want to ship to the TPUs and instead just keep in host memory.
    # Typically ground-truth and metadata, that is often of string type.
    self.data_iter = big_vision.input_pipeline.start_global(
        self.ds, devices, keep_on_cpu={"answer", "question_id"}
    )
    # We'll need the tokenizer to detokenize the model outputs later.
    self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer)
    self.decode = functools.partial(
        predict_fn, devices=devices, eos_token=self.tok.eos_token
    )

  def run(self, train_state):
    """Does one evaluation run, yields metrics."""

    accuracies = []
    # Fraction of answers that are a literal "yes"/"no" (well-formedness).
    valid = []
    json_out = []
    for _, batch in zip(range(self.steps), self.data_iter):
      # (batch, seqlen) array of decoded generated tokens.
      tokens = self.decode(train_state, batch)

      # (local_batch,) that indicates padding examples (0) vs real examples (1).
      tokens = u.get_local_slice_from_fsarray(tokens)
      ex_masks = u.get_local_slice_from_fsarray(batch["_mask"])

      # Turn predictions into texts and then scores, one by one.
      for i in range(len(tokens)):
        if ex_masks[i] == 0:  # Skip last-batch padding examples
          continue

        answer = self.tok.to_str(tokens[i], stop_at_eos=True).lower()
        gt = batch["answer"][i]
        accuracies.append(float(answer == gt))
        valid.append(float(answer in ("yes", "no")))

        json_out.append(
            {
                self.out_question_key: batch["question_id"][i].item(),
                self.out_answer_key: answer,
            }
        )

    # At this point `accuracies` is a list of per-example scores. However,
    # remember that each host holds a different subset of the examples! So if
    # we were to just return the mean accuracy here, we would effectively only
    # have evaluated on the main host's (who writes metrics) subset!
    # So now, we need to compute global means.
    # There is one more caveat: `process_sum` needs the summands on each host
    # to have the same size. So we either need to include dummy values for
    # the padding examples (last batch, annoying), or we only sum scalars as in
    # sufficient statistics, which we do here.
    sum_accs, sum_valid, num = c.process_sum([
        sum(accuracies),
        sum(valid),
        len(accuracies),
    ])

    if num:
      yield "acc", sum_accs / num
      yield "valid_percent", sum_valid / num
      yield "num", num

    c.multiprocess_write_json(self.outfile, json_out)
|
Tipsomaly/model/big_vision/evaluators/proj/uvim/coco_panoptic.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""COCO17 panoptic evaluation."""
|
| 16 |
+
import functools
|
| 17 |
+
from functools import partial
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
import tempfile
|
| 21 |
+
import time
|
| 22 |
+
import zipfile
|
| 23 |
+
|
| 24 |
+
from absl import logging
|
| 25 |
+
from big_vision.evaluators.proj.uvim import common
|
| 26 |
+
import big_vision.pp.builder as pp_builder
|
| 27 |
+
import jax
|
| 28 |
+
import numpy as np
|
| 29 |
+
import panopticapi_converters.twochannels2panoptic_coco_format as converter
|
| 30 |
+
from panopticapi.evaluation import pq_compute
|
| 31 |
+
import tensorflow as tf
|
| 32 |
+
import tensorflow_datasets as tfds
|
| 33 |
+
|
| 34 |
+
from tensorflow.io import gfile
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
ROOT = os.environ.get('COCO_DATA_DIR', '.')
|
| 38 |
+
|
| 39 |
+
PANOPTIC_COCO_CATS_FILE = f'{ROOT}/panoptic_coco_categories.json'
|
| 40 |
+
PANOPTIC_2017 = {
|
| 41 |
+
'train': f'{ROOT}/panoptic_train2017.json',
|
| 42 |
+
'validation': f'{ROOT}/panoptic_val2017.json',
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
PANOPTIC_GT_ZIP = {
|
| 46 |
+
'train': f'{ROOT}/panoptic_train2017.zip',
|
| 47 |
+
'validation': f'{ROOT}/panoptic_val2017.zip',
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class Evaluator:
  """Panoptic segmentation evaluator: calls official COCO API.

  `predict_fn` accepts arbitrary dictionaries of parameters and data, where
  the data dictionary is produced by the `pp` op. It is expected to output a
  2-channel mask, where the first channel encodes semantics, and the second
  channel encodes instance ids.
  """

  def __init__(self,
               predict_fn,
               pp_fn,
               batch_size,
               dataset='coco/2017_panoptic',
               dataset_dir=None,
               split='validation',
               predict_kwargs=None):
    # Prepare to run predict on all processes and gather predictions on all
    # devices. Note: if needed consider only gather across processes.
    def predict(params, batch):
      res = {
          'image/id': batch['image/id'],
          'mask': batch['mask'],
          'y': predict_fn(params, batch['input'], **(predict_kwargs or {})),
      }
      # Collective gather: every device ends up with the full global batch.
      return jax.lax.all_gather(res, axis_name='data', axis=0)

    self.predict_fn = jax.pmap(predict, axis_name='data')

    # Prepare data for each process and pad with zeros so all processes have the
    # same number of batches.
    def preprocess(example):
      return {
          'image/id': example['image/id'],
          # 'mask' is 1 for real examples; the zero-padding appended by
          # common.get_jax_process_dataset yields mask == 0.
          'mask': tf.constant(1),
          'input': pp_builder.get_preprocess_fn(pp_fn)(example),
      }

    self.data = common.get_jax_process_dataset(
        dataset, split, dataset_dir=dataset_dir,
        global_batch_size=batch_size,
        pp_fn=preprocess)

    # Only process 0 runs conversion to png and calls into coco api.
    if jax.process_index() == 0:
      self.result_dir = tempfile.TemporaryDirectory()
      (self.gt_folder, self.gt_json, self.categories_json,
       self.remap, self.size_map) = _prepare_ground_truth(
           dataset, split, dataset_dir)

  def _compute_png_predictions(self, params):
    """Computes predictions and converts them to png to optimize memory use.

    All processes participate in inference (pmap is collective), but only
    process 0 decodes outputs and writes PNGs; other processes return None.
    """
    count = 0
    logging.info('Panoptic eval: running inference.')
    for batch in self.data.as_numpy_iterator():
      out = self.predict_fn(params, batch)

      if jax.process_index():
        # Non-zero hosts only feed the collective computation.
        continue

      # After all_gather every device holds identical data; take device 0.
      out = jax.device_get(jax.tree_map(lambda x: x[0], out))
      mask = out['mask']
      # Drop the zero-padded examples (mask == 0).
      pan_recs = out['y'][mask != 0]
      ids = out['image/id'][mask != 0]

      for pan_rec, image_id in zip(pan_recs, ids):
        # Channel 0: semantic class; channel 1: instance id.
        sem = pan_rec[..., 0]
        ins = pan_rec[..., 1]

        # Remap model/tfds class ids to official COCO category ids.
        sem_remapped = np.array(sem)
        for v in np.unique(sem):
          sem_remapped[sem == v] = self.remap[v]
        sem = sem_remapped

        # Third (blue) channel is unused by the 2-channel COCO format.
        pan_mask = np.stack([sem, ins, np.zeros_like(sem)], axis=-1)
        # Resize back to the original image resolution expected by the API.
        pan_mask = _resize_nearest(pan_mask, self.size_map[image_id])
        pan_mask_png = tf.io.encode_png(pan_mask.astype('uint8')).numpy()

        fname = f'{self.result_dir.name}/{image_id:012d}.png'
        with open(fname, 'wb') as f:
          f.write(pan_mask_png)
        count += 1

      logging.log_every_n_seconds(
          logging.INFO, 'Panoptic eval: processed %i examples so far.', 30,
          count)

    if jax.process_index():
      return None

    logging.info('Panoptic eval: inference done. Processed %d examples.', count)
    return self.result_dir

  def run(self, params):
    """Run eval.

    Generator yielding (metric_name, value) pairs; empty on non-zero hosts.
    """
    # Note result_dir is constant, but files inside are mutated.
    result_dir = self._compute_png_predictions(params)

    if not result_dir:
      return

    with tempfile.TemporaryDirectory() as pred_folder, \
        tempfile.NamedTemporaryFile(mode='w') as pred_json:

      # Convert 2-channel PNGs into the official COCO panoptic format.
      logging.info('Panoptic eval: running conversion.')
      converter.converter(
          source_folder=result_dir.name,
          images_json_file=self.gt_json,
          categories_json_file=self.categories_json,
          segmentations_folder=pred_folder,
          predictions_json_file=pred_json.name)
      logging.info('Panoptic eval: conversion done.')

      logging.info('Panoptic eval: running metrics computation.')
      res = pq_compute(gt_json_file=self.gt_json,
                       gt_folder=self.gt_folder,
                       pred_json_file=pred_json.name,
                       pred_folder=pred_folder)
      logging.info('Panoptic eval: metrics computation done.')

      # Panoptic/recognition/segmentation quality over All/Stuff/Things.
      for k in ['All', 'Stuff', 'Things']:
        for m in ['pq', 'rq', 'sq']:
          yield f'{k}_{m}', res[k][m]
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def _prepare_ground_truth(dataset, split, data_dir):
  """Prepare ground truth from tf.data.Dataset.

  Dispatches between the fast path (official COCO-2017 zip archives) and the
  generic path that materializes ground truth from the TFDS dataset itself.
  """
  use_official_zips = dataset == 'coco/2017_panoptic' and data_dir is None
  if use_official_zips:
    return _prepare_ground_truth_from_zipfiles(split)
  return _prepare_ground_truth_from_dataset(dataset, split, data_dir)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
@functools.lru_cache(maxsize=None)
def _prepare_ground_truth_from_dataset(dataset, split, data_dir):
  """Prepare ground truth from a tf.data.Dataset.

  Materializes per-image panoptic ground-truth PNGs plus the COCO-style
  `annotations.json` that `pq_compute` expects.

  Args:
    dataset: TFDS dataset name.
    split: TFDS split name.
    data_dir: directory for TFDS to find the prepared data (may be None).

  Returns:
    Tuple `(gt_folder, gt_json, categories_json, remap, size_map)` where
    `remap` maps tfds class ids to COCO category ids and `size_map` maps
    image id to (height, width).
  """
  dataset = tfds.builder(dataset, data_dir=data_dir).as_dataset(split=split)

  categories_json = _make_local_copy(PANOPTIC_COCO_CATS_FILE)
  with gfile.GFile(categories_json, 'rb') as f:
    categories = json.loads(f.read())

  # Build map from tfds class ids to COCO class ids. Id 0 (void) maps to
  # itself; tfds label i corresponds to entry i of the categories file.
  # Fix: the original re-opened `categories_json` here with an unused file
  # handle even though `categories` was already parsed above.
  remap = {0: 0}
  remap.update({(i + 1): x['id'] for i, x in enumerate(categories)})

  gt_folder = tempfile.mkdtemp()
  gfile.makedirs(gt_folder)
  size_map = {}
  annotations = []
  images = []
  for example in dataset:
    image_id = int(example['image/id'])
    panoptic_image = example['panoptic_image']
    ann_ids = example['panoptic_objects']['id']
    ann_labels = example['panoptic_objects']['label']
    ann_iscrowd = example['panoptic_objects']['is_crowd']
    ann_area = example['panoptic_objects']['area']

    # One ground-truth PNG per image, named by zero-padded image id.
    fname = f'{image_id:012d}.png'
    with gfile.GFile(os.path.join(gt_folder, fname), 'wb') as f:
      f.write(tf.io.encode_png(panoptic_image).numpy())

    size_map[image_id] = (panoptic_image.shape[0], panoptic_image.shape[1])

    segments_info = []
    for i in range(len(ann_ids)):
      segments_info.append({
          'id': int(ann_ids[i]),
          # +1 shift: tfds labels are 0-based relative to the categories file.
          'category_id': remap[int(ann_labels[i] + 1)],
          'iscrowd': int(ann_iscrowd[i]),
          'area': int(ann_area[i]),
      })

    annotations.append({
        'file_name': str(fname),
        'image_id': int(image_id),
        'segments_info': segments_info
    })
    images.append({
        'id': image_id,
        'file_name': f'{image_id:012d}.jpg',
    })

  # Write annotations.json needed for pq_compute.
  gt_json = os.path.join(gt_folder, 'annotations.json')
  with gfile.GFile(gt_json, 'wb') as f:
    f.write(json.dumps({
        'images': images,
        'annotations': annotations,
        'categories': categories,
    }))

  return gt_folder, gt_json, categories_json, remap, size_map
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def _prepare_ground_truth_from_zipfiles(split):
  """Prepare ground truth from coco zip files.

  Downloads the official COCO-2017 panoptic archives, restricts the
  annotation json to the images present in the requested (sub)split, and
  returns everything `pq_compute` needs.

  Returns:
    Tuple (gt_folder, filtered_gt_json, categories_json, remap, size_map).
  """
  # 'validation[:10]'-style subsplits share the prefix's archives.
  split_prefix = split.split('[')[0]
  if split_prefix not in ('train', 'validation'):
    raise ValueError(f'Split {split} not supported')

  # The following 4 calls are cached. This allows to save significant time
  # in use cases like sweeping predict_fn hparams on the same run.
  gt_json = _make_local_copy(PANOPTIC_2017[split_prefix])
  gt_folder = _make_local_unzip_copy(PANOPTIC_GT_ZIP[split_prefix])
  categories_json = _make_local_copy(PANOPTIC_COCO_CATS_FILE)
  image_ids = _list_image_ids('coco/2017_panoptic', split)

  # The zip extracts into a split-named subfolder.
  gt_folder = os.path.join(
      gt_folder, 'panoptic_val2017'
      if split_prefix == 'validation' else 'panoptic_train2017')

  # Build map from tfds class ids to COCO class ids.
  remap = {0: 0}
  with gfile.GFile(categories_json, 'r') as f:
    remap = {**remap, **{(i + 1): x['id'] for i, x in enumerate(json.load(f))}}

  # Filters gt_json to contain only annotations for images in dataset.
  with gfile.GFile(gt_json) as f:
    data = json.load(f)
  logging.info(
      'Panoptic eval: pre-filter %d annotations.',
      len(data['annotations'])
  )
  data['images'] = [x for x in data['images'] if x['id'] in image_ids]
  data['annotations'] = [
      x for x in data['annotations'] if x['image_id'] in image_ids
  ]
  logging.info(
      'Panoptic eval: post-filter %d annotations.',
      len(data['annotations'])
  )
  # delete=False: the file must outlive this function (path is returned and
  # cached by the caller); it is never cleaned up explicitly.
  filtered_gt_json = tempfile.NamedTemporaryFile(delete=False).name
  with open(filtered_gt_json, 'w') as f:
    json.dump(data, f)

  # Precompute images sizes.
  size_map = {x['id']: (x['height'], x['width']) for x in data['images']}

  return gt_folder, filtered_gt_json, categories_json, remap, size_map
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
@functools.lru_cache(maxsize=None)
def _list_image_ids(dataset, split):
  """Return the frozen set of `image/id` values in `dataset`'s `split`."""
  ids_ds = tfds.load(dataset, split=split)
  ids_ds = ids_ds.map(lambda ex: ex['image/id'])
  return frozenset(ids_ds.as_numpy_iterator())
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
@functools.lru_cache(maxsize=None)
def _make_local_copy(fname) -> str:
  """Copy a (possibly remote) file to local disk; return the local path.

  Cached so repeated eval runs reuse the same local copy.
  """
  t0 = time.monotonic()
  tmp = tempfile.NamedTemporaryFile(delete=False)
  gfile.copy(fname, tmp.name, overwrite=True)
  elapsed = time.monotonic() - t0
  logging.info('Copy %s in %d seconds.', fname, elapsed)
  return tmp.name
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
@functools.lru_cache(maxsize=None)
def _make_local_unzip_copy(fname) -> str:
  """Download zip archive `fname` and extract it into a fresh local folder."""
  t0 = time.monotonic()
  out_dir = tempfile.mkdtemp()
  with tempfile.NamedTemporaryFile() as local_zip:
    gfile.copy(fname, local_zip.name, overwrite=True)
    with zipfile.ZipFile(local_zip.name, 'r') as archive:
      archive.extractall(out_dir)
  logging.info('Copy %s in %d seconds.', fname, time.monotonic() - t0)
  return out_dir
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
@partial(jax.jit, static_argnums=(1,), backend='cpu')
|
| 323 |
+
def _resize_nearest(image, shape):
|
| 324 |
+
return jax.image.resize(image, shape + image.shape[-1:], 'nearest')
|
Tipsomaly/model/big_vision/evaluators/proj/uvim/coltran_fid.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Evaluation producing ColTran FID-5K metric."""
|
| 16 |
+
|
| 17 |
+
import functools
|
| 18 |
+
import os
|
| 19 |
+
|
| 20 |
+
from absl import logging
|
| 21 |
+
import einops
|
| 22 |
+
import jax
|
| 23 |
+
import numpy as np
|
| 24 |
+
import tensorflow as tf
|
| 25 |
+
import tensorflow_datasets as tfds
|
| 26 |
+
import tensorflow_gan as tfgan
|
| 27 |
+
import tensorflow_hub as tfhub
|
| 28 |
+
|
| 29 |
+
from tensorflow.io import gfile
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
ROOT = os.environ.get("FID_DATA_DIR", ".")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _preprocess(image, resolution=512):
  """ColTran dataset preprocessing.

  Center-crops to a square of the shorter side, area-resizes to
  resolution x resolution, then rounds back to integers.

  See,
  github.com/google-research/google-research/blob/master/coltran/datasets.py#L44

  Args:
    image: ImageNet example from TFDS.
    resolution: Integer representing output size.

  Returns:
    An int32 image of size (resolution, resolution, 3).
  """
  dims = tf.shape(image)
  short_side = tf.minimum(dims[0], dims[1])
  square = tf.image.resize_with_crop_or_pad(
      image, target_height=short_side, target_width=short_side)
  resized = tf.image.resize(square, size=(resolution, resolution),
                            method="area", antialias=True)
  return tf.cast(tf.round(resized), dtype=tf.int32)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _normalize(x):
  """Coltran normalization to expected range for Inception module.

  Args:
    x: Image with values in [0,255].

  Returns:
    Image with values in [-1,1].
  """
  # note: 128.0 is the divisor used in ColTran.
  return tf.cast(x, tf.float32) / 128.0 - 1.0
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class Evaluator:
  """ColTran FID-5K Evaluator.

  This Evaluator aims to mirror the evaluation pipeline used by Kumar et.al.
  in Colorization Transformer (https://arxiv.org/abs/2102.04432).

  To be clear: much of this code is direct snippets from ColTran code.

  See,
  github.com/google-research/google-research/blob/master/coltran/datasets.py#L44

  The ColTran pipeline has numerous stages, where serialized data is passed
  between binaries via file, etc... While we don't physically write the same
  files, we simulate the effects of the serialization (e.g., quantization).
  """

  def __init__(self,
               predict_fn,
               batch_size,  # ignored
               device_batch_size=5,
               coltran_seed=1,
               predict_kwargs=None):
    """Create Evaluator.

    Args:
      predict_fn: Colorization prediction function. Expects grayscale images
        of size (512, 512, 3) in keys `image` and `image_ctx` with values in
        the range [-1,1]. Outputs `color` image in range [-1,1].
      batch_size: ignored.
      device_batch_size: number of images per batch, per device.
      coltran_seed: used to specify the block of 5_000 images used to generate
        the reference pool. Value of `1` matches default ColTran code.
      predict_kwargs: arguments passed to `predict_fn`.
    """
    del batch_size

    self.num_devices = jax.local_device_count()
    self.device_batch_size = device_batch_size
    logging.log(logging.INFO, "Colorizing with batch size %i on %i devices.",
                self.device_batch_size, self.num_devices)
    # The 5_000 eval images must be evenly divisible across global batches.
    assert 5_000 % (self.device_batch_size * self.num_devices) == 0

    predict = functools.partial(predict_fn, **(predict_kwargs or {}))
    self.predict_fn = jax.pmap(predict)

    # Inception pooling features are the inputs to the FID computation.
    module = tfhub.load(tfgan.eval.INCEPTION_TFHUB)
    def _pools(x):
      return np.squeeze(module(x)[tfgan.eval.INCEPTION_FINAL_POOL].numpy())

    self.inception_pool = _pools

    # Setup the colorization dataset.
    # TRICKY: ColTran FID-5k uses the first 5_000 images returned as read by
    # default from tensorflow_datasets (that is: with shard interleaving).
    # In particular note that it is different than the set of images returned
    # by "validation[:5000]".
    def _eval_data_preprocess(example):
      # Colorization happens at 512x512 resolution.
      image = _preprocess(example["image"], resolution=512)
      image = _normalize(image)
      # 3-channel grayscale so the colorizer input has RGB shape.
      grayscale = tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image))
      return {
          "image": image,
          "grayscale": grayscale,
          "file_name": example["file_name"]
      }

    ds = tfds.load("imagenet2012", split="validation")
    ds = ds.map(_eval_data_preprocess)
    ds = ds.take(5_000)
    # Double-batch: inner per-device batch, outer device axis for pmap.
    ds = ds.batch(self.device_batch_size)
    ds = ds.batch(self.num_devices)
    self.eval_data = ds.cache().prefetch(tf.data.AUTOTUNE)

    # Setup the reference dataset.
    def _reference_data_preprocess(example):
      # ColTran eval operates on 256x256.
      image = _preprocess(example["image"], resolution=256)
      image = _normalize(image)
      return {"image": image, "file_name": example["file_name"]}

    ds = tfds.load("imagenet2012", split="validation")
    ds = ds.map(_reference_data_preprocess)
    # Skip the images used in colorization.
    ds = ds.skip(5_000)
    # ColTran eval w/ seed=1 effectively uses 10_000:15_000 to
    # calculate reference.
    ds = ds.skip(coltran_seed * 5_000)
    ds = ds.take(5_000)
    ds = ds.batch(device_batch_size)
    self.reference_data = ds.cache().prefetch(tf.data.AUTOTUNE)

    def _get_file(name):
      return os.path.join(ROOT, name)

    # Expected file-name sets, used below to verify the pipeline saw exactly
    # the intended 5k + 5k images (guards against tfds ordering changes).
    with gfile.GFile(_get_file("eval_file_names.txt")) as f:
      self.eval_file_names = frozenset(f.read().splitlines())

    with gfile.GFile(_get_file("reference_file_names.txt")) as f:
      self.reference_file_names = frozenset(f.read().splitlines())

  def run(self, params):
    """Run eval. Generator yielding a single ("FID_5k", value) pair."""

    if jax.process_index():  # Host0 does all work.
      return

    color_pools = []
    color_file_names = set()
    for i, batch in enumerate(self.eval_data.as_numpy_iterator()):
      predict_batch = {
          "labels": batch["image"],
          "image": batch["grayscale"],
          "image_ctx": batch["grayscale"],
      }
      y = self.predict_fn(params, predict_batch)
      y = y["color"]
      # Flatten the (device, batch) axes back into one batch axis.
      y = einops.rearrange(y, "d b h w c -> (d b) h w c")

      # Return to the ColTran eval size of 256x256.
      y = tf.image.resize(y, (256, 256), "area")

      # Mimic effect of serializing image as integers and map back to [-1, 1].
      y = np.clip(np.round((y + 1.) * 128.), 0, 255)
      y = _normalize(y)

      color_pools.append(self.inception_pool(y))

      file_names = einops.rearrange(batch["file_name"], "d b -> (d b)")
      color_file_names.update([f.decode() for f in file_names])

      logging.log_every_n_seconds(
          logging.INFO,
          "ColTran FID eval: processed %i colorized examples so far.", 30,
          (i + 1) * self.device_batch_size * self.num_devices)

    reference_pools = []
    reference_file_names = set()
    for i, batch in enumerate(self.reference_data.as_numpy_iterator()):
      image = batch["image"]
      assert np.array_equal(image.shape, (self.device_batch_size, 256, 256, 3))
      reference_pools.append(self.inception_pool(image))
      reference_file_names.update([f.decode() for f in batch["file_name"]])

      logging.log_every_n_seconds(
          logging.INFO,
          "ColTran FID eval: processed %i reference examples so far.", 30,
          (i + 1) * self.device_batch_size)

    # Hard-fail if the processed sets differ from the expected file lists.
    if color_file_names != self.eval_file_names:
      raise ValueError("unknown: {}\nmissing: {}".format(
          color_file_names - self.eval_file_names,
          self.eval_file_names - color_file_names))

    if reference_file_names != self.reference_file_names:
      raise ValueError("unknown: {}\nmissing: {}".format(
          reference_file_names - self.reference_file_names,
          self.reference_file_names - reference_file_names))

    color = np.concatenate(color_pools, axis=0)
    reference = np.concatenate(reference_pools, axis=0)

    if color.shape[0] != 5_000:
      raise ValueError(color.shape)

    if reference.shape[0] != 5_000:
      raise ValueError(reference.shape)

    yield "FID_5k", tfgan.eval.frechet_classifier_distance_from_activations(
        color, reference)
|
Tipsomaly/model/big_vision/evaluators/proj/uvim/coltran_fid_data/eval_file_names.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Tipsomaly/model/big_vision/evaluators/proj/uvim/coltran_fid_data/reference_file_names.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Tipsomaly/model/big_vision/evaluators/proj/uvim/common.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Common utilities used in evaluators."""
|
| 16 |
+
import math
|
| 17 |
+
import jax
|
| 18 |
+
import tensorflow as tf
|
| 19 |
+
import tensorflow_datasets as tfds
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_jax_process_dataset(dataset, split, global_batch_size, pp_fn,
                            dataset_dir=None, cache=True, add_tfds_id=False):
  """Returns dataset to be processed by current jax host.

  The dataset is sharded and padded with zeros such that all processes
  have equal number of batches. The first 2 dimensions of the dataset
  elements are: [local_device_count, device_batch_size].

  Padded examples are all-zero; callers can detect them via a mask field
  produced by `pp_fn` (real examples nonzero, padding zero).

  Args:
    dataset: dataset name.
    split: dataset split.
    global_batch_size: batch size to be process per iteration on the dataset.
    pp_fn: preprocessing function to apply per example.
    dataset_dir: path for tfds to find the prepared data.
    cache: whether to cache the dataset after batching.
    add_tfds_id: whether to add the unique `tfds_id` string to each example.
  """
  assert global_batch_size % jax.device_count() == 0
  # Global number of steps is derived from the full-split cardinality so that
  # every process iterates the same number of batches.
  total_examples = tfds.load(
      dataset, split=split, data_dir=dataset_dir).cardinality()
  num_batches = math.ceil(total_examples / global_batch_size)

  # Each process reads only its own shard of the split.
  process_split = tfds.even_splits(
      split, n=jax.process_count(), drop_remainder=False)[jax.process_index()]
  data = tfds.load(
      dataset,
      split=process_split,
      data_dir=dataset_dir,
      read_config=tfds.ReadConfig(add_tfds_id=add_tfds_id)).map(pp_fn)
  # Infinite stream of all-zero examples with the same structure/dtypes,
  # used to pad the tail so every process yields exactly num_batches batches.
  pad_data = tf.data.Dataset.from_tensors(
      jax.tree_map(lambda x: tf.zeros(x.shape, x.dtype), data.element_spec)
  ).repeat()

  data = data.concatenate(pad_data)
  data = data.batch(global_batch_size // jax.device_count())
  data = data.batch(jax.local_device_count())
  # take() truncates the otherwise-infinite padded stream.
  data = data.take(num_batches)
  if cache:
    # Eval datasets are often used many times and caching the dataset after
    # batching allows one to have the buffers ready to be used and not have
    # to wait for preprocessing to be done over and over.
    data = data.cache()
  return data
|
Tipsomaly/model/big_vision/evaluators/proj/uvim/compute_mean.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Evaluator for computing mean of per-example metrics."""
|
| 16 |
+
import functools
|
| 17 |
+
from typing import Mapping
|
| 18 |
+
|
| 19 |
+
from big_vision import input_pipeline
|
| 20 |
+
from big_vision.datasets import core as ds_core
|
| 21 |
+
from big_vision.pp import builder as pp_builder
|
| 22 |
+
|
| 23 |
+
import jax
|
| 24 |
+
import jax.numpy as jnp
|
| 25 |
+
import numpy as np
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Note: global to avoid jax re-compiling across different evaluator instances.
|
| 29 |
+
@functools.partial(jax.pmap, static_broadcasted_argnums=0, axis_name='batch')
|
| 30 |
+
def _run_predict_fn(predict_fn, params, batch):
|
| 31 |
+
"""Sum per-example metrics weighted by `_mask`."""
|
| 32 |
+
mask = batch['_mask']
|
| 33 |
+
metrics = predict_fn(params, batch)
|
| 34 |
+
# Sanity check output format of predict_fn.
|
| 35 |
+
assert isinstance(metrics, Mapping), 'predict_fn must return a dict'
|
| 36 |
+
for y in jax.tree_leaves(metrics):
|
| 37 |
+
if y.shape != mask.shape:
|
| 38 |
+
raise ValueError(
|
| 39 |
+
f'Expected per-example metrics of shape {mask.shape} found '
|
| 40 |
+
f'{jax.tree_map(lambda x: x.shape, metrics)}.')
|
| 41 |
+
metrics = {**metrics, '_mask': mask}
|
| 42 |
+
metrics = jax.tree_map(lambda x: jnp.inner(x, mask), metrics)
|
| 43 |
+
return jax.lax.psum(metrics, axis_name='batch')
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class Evaluator:
  """Report the mean of per-example metrics computed by predict_fn.

  `predict_fn(params, batch)` must return a dict from metric name to
  per-example metrics of shape [batch_size].
  """

  def __init__(self, predict_fn, data, pp_fn, batch_size,
               cache_final=True, cache_raw=False, prefetch=1):
    """Builds the (cached) inference input pipeline and stores predict_fn."""
    source = ds_core.get(**data)
    self.dataset, self.steps = input_pipeline.make_for_inference(
        source.get_tfdata(ordered=True), batch_size=batch_size,
        num_ex_per_process=source.num_examples_per_process(),
        preprocess_fn=pp_builder.get_preprocess_fn(pp_fn),
        cache_final=cache_final, cache_raw=cache_raw)
    self.data_iter = input_pipeline.start_input_pipeline(self.dataset, prefetch)
    self.predict_fn = predict_fn

  def run(self, params):
    """Computes all metrics; yields (metric_name, mean_value) pairs."""
    # Dispatch every batch first without blocking — pmap is asynchronous, so
    # this keeps the devices busy while we keep feeding them.
    per_batch = [
        _run_predict_fn(self.predict_fn, params, batch)
        for _, batch in zip(range(self.steps), self.data_iter)
    ]

    # Transfer the (replicated) sums from device 0 to the host; this blocks.
    per_batch = jax.device_get(
        jax.tree_util.tree_map(lambda x: x[0], per_batch))

    # Accumulate over batches, then normalize by the number of real examples
    # (the summed '_mask'), excluding any padding.
    totals = jax.tree_util.tree_map(lambda *x: np.sum(x), *per_batch)
    num_examples = totals.pop('_mask')
    for name, total in totals.items():
      yield (name, total / num_examples)
|
Tipsomaly/model/big_vision/evaluators/proj/uvim/nyu_depth.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Evaluation for NYU depth.
|
| 16 |
+
|
| 17 |
+
At evaluation time the ground truth is cropped and clipped. Values outside of
|
| 18 |
+
the test crop or clipping range are not included in eval calculations.
|
| 19 |
+
|
| 20 |
+
In this evaluator, it is assumed that the ground truth is already cropped, so the
|
| 21 |
+
entire image is evaluated. However, the evaluator does perform the clipping.
|
| 22 |
+
|
| 23 |
+
Reference implementations:
|
| 24 |
+
https://github.com/zhyever/Monocular-Depth-Estimation-Toolbox/blo(internal link)a0f341244260ff61541191a613dd74bc/depth/datasets/nyu.py
|
| 25 |
+
https://github.com/vinvino02/GLPDepth/blob/7f3c78df4ecd6e7c79fd0c4b73c95d61f4aa2121/code/utils/metrics.py
|
| 26 |
+
https://github.com/shariqfarooq123/AdaBins/blob/2fb686a66a304f0a719bc53d77412460af97fd61/evaluate.py
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
import functools
|
| 30 |
+
|
| 31 |
+
import big_vision.evaluators.proj.uvim.common as common
|
| 32 |
+
import big_vision.pp.builder as pp_builder
|
| 33 |
+
import jax
|
| 34 |
+
import jax.numpy as jnp
|
| 35 |
+
import numpy as np
|
| 36 |
+
import tensorflow as tf
|
| 37 |
+
|
| 38 |
+
EVAL_CROP_H = 426
|
| 39 |
+
EVAL_CROP_W = 560
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class Evaluator:
  """Evaluator for NYU depth."""

  def __init__(self,
               predict_fn,
               pp_fn,
               batch_size,
               dataset,
               split,
               min_depth=1e-3,
               max_depth=10,
               dataset_dir=None,
               predict_kwargs=None):
    """Initializes the evaluator.

    Args:
      predict_fn: Called as `predict_fn(params, batch, **predict_kwargs)`;
        must return a dict with a "depth" entry (per-pixel depth prediction).
      pp_fn: Preprocessing spec passed to `pp_builder.get_preprocess_fn`.
      batch_size: Global batch size across all processes.
      dataset: Dataset name forwarded to `common.get_jax_process_dataset`.
      split: Dataset split to evaluate.
      min_depth: Ground-truth values <= this are excluded from the metrics.
      max_depth: Ground-truth values >= this are excluded from the metrics.
      dataset_dir: Optional dataset directory override.
      predict_kwargs: Optional extra keyword args for `predict_fn`.
    """
    self.min_depth = min_depth
    self.max_depth = max_depth

    def predict(params, batch):
      pred = predict_fn(params, batch, **(predict_kwargs or {}))

      # all_gather replicates the full global batch on every device, so
      # process 0 can later read everything from a single device.
      return jax.lax.all_gather({
          "mask": batch["mask"],
          "gt": jnp.squeeze(batch["ground_truth"], axis=-1),
          "y": pred["depth"],
      }, axis_name="data", axis=0)

    self.predict_fn = jax.pmap(predict, axis_name="data")

    # Prepare data for each process and pad with zeros so all processes have the
    # same number of batches.
    def preprocess(example):
      # "mask" == 1 marks real examples; padded examples end up with mask 0.
      return {
          "mask": tf.constant(1),
          **pp_builder.get_preprocess_fn(pp_fn)(example),
      }

    self.process_batch_size = batch_size // jax.process_count()

    self.data = common.get_jax_process_dataset(
        dataset=dataset,
        dataset_dir=dataset_dir,
        split=split,
        global_batch_size=batch_size,
        pp_fn=preprocess)

  def run(self, params):
    """Run eval; yields (metric_name, value) pairs on process 0 only."""
    # Assumes that the ground truth is processed by the eval crop.
    # NOTE(review): presumably pp_fn already crops ground truth to
    # (EVAL_CROP_H, EVAL_CROP_W) — confirm against the preprocessing spec.
    eval_mask = np.ones((EVAL_CROP_H, EVAL_CROP_W), dtype=np.bool_)
    rmses = []
    abs_res = []
    abs_logs = []
    d1s = []
    d2s = []
    d3s = []
    for batch in self.data.as_numpy_iterator():
      # Outputs is a dict with values shaped (gather/same, devices, batch, ...)
      out = self.predict_fn(params, batch)

      if jax.process_index():  # Host0 gets all preds and does eval.
        continue

      # First, we remove the "gather" dim and transfer the result to host,
      # leading to numpy arrays of (devices, device_batch, ...)
      out = jax.tree_map(lambda x: jax.device_get(x[0]), out)
      # Then the bool-indexing with mask resulting in flat (global_batch, ...)
      out = jax.tree_map(lambda x: x[out["mask"] == 1], out)  # pylint:disable=cell-var-from-loop

      for gt, pred in zip(out["gt"], out["y"]):
        # Score at the fixed eval resolution; only pixels inside the clipping
        # range [min_depth, max_depth] contribute to the metrics.
        pred = _resize_nearest(pred, (EVAL_CROP_H, EVAL_CROP_W))
        valid_mask = np.logical_and(gt > self.min_depth, gt < self.max_depth)
        valid_mask = np.logical_and(valid_mask, eval_mask)

        rmses.append(_compute_rmse(gt[valid_mask], pred[valid_mask]))
        abs_res.append(_compute_abs_re(gt[valid_mask], pred[valid_mask]))
        abs_logs.append(_compute_abs_log(gt[valid_mask], pred[valid_mask]))
        d1s.append(_compute_delta(gt[valid_mask], pred[valid_mask], order=1))
        d2s.append(_compute_delta(gt[valid_mask], pred[valid_mask], order=2))
        d3s.append(_compute_delta(gt[valid_mask], pred[valid_mask], order=3))

    if jax.process_index():  # Host0 gets all preds and does eval.
      return

    yield "RMSE", np.mean(rmses)
    yield "abs_RE", np.mean(abs_res)
    yield "log10", np.mean(abs_logs)
    yield "delta1", np.mean(d1s)
    yield "delta2", np.mean(d2s)
    yield "delta3", np.mean(d3s)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@functools.partial(jax.jit, static_argnums=(1,), backend="cpu")
|
| 133 |
+
def _resize_nearest(image, shape):
|
| 134 |
+
return jax.image.resize(image, shape, "nearest")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def _compute_rmse(gt, pred):
|
| 138 |
+
diff = gt - pred
|
| 139 |
+
return np.sqrt(np.mean(np.power(diff, 2)))
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def _compute_abs_re(gt, pred):
|
| 143 |
+
diff = np.abs(gt - pred)
|
| 144 |
+
return np.mean(diff / gt)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def _compute_abs_log(gt, pred):
|
| 148 |
+
diff = np.abs(np.log10(gt) - np.log10(pred))
|
| 149 |
+
return np.mean(diff)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def _compute_delta(gt, pred, order):
|
| 153 |
+
rel_diff = np.maximum(gt / pred, pred / gt)
|
| 154 |
+
return np.sum(rel_diff < 1.25**order) / rel_diff.size
|