AlirezaSalehi99 commited on
Commit
ebe754f
·
verified ·
1 Parent(s): 0c3fcb9

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +6 -0
  2. Tipsomaly/model/big_vision/configs/proj/givt/givt_overview.png +3 -0
  3. Tipsomaly/model/big_vision/configs/proj/jetformer/jetformer_overview.png +3 -0
  4. Tipsomaly/model/big_vision/configs/proj/paligemma/paligemma.png +3 -0
  5. Tipsomaly/model/big_vision/datasets/countbenchqa/data/countbench_paired_questions.json +1 -0
  6. Tipsomaly/model/big_vision/datasets/imagenet/__pycache__/class_names.cpython-311.pyc +3 -0
  7. Tipsomaly/model/big_vision/datasets/imagenet/__pycache__/class_names.cpython-312.pyc +3 -0
  8. Tipsomaly/model/big_vision/datasets/imagenet/__pycache__/class_names.cpython-39.pyc +3 -0
  9. Tipsomaly/model/big_vision/datasets/nocaps/nocaps.py +160 -0
  10. Tipsomaly/model/big_vision/datasets/refcoco/refcoco.py +448 -0
  11. Tipsomaly/model/big_vision/datasets/rsvqa_hr/rsvqa_hr.py +193 -0
  12. Tipsomaly/model/big_vision/datasets/rsvqa_lr/rsvqa_lr.py +198 -0
  13. Tipsomaly/model/big_vision/datasets/scicap/scicap.py +205 -0
  14. Tipsomaly/model/big_vision/datasets/science_qa/science_qa.py +156 -0
  15. Tipsomaly/model/big_vision/datasets/screen2words/screen2words.py +120 -0
  16. Tipsomaly/model/big_vision/datasets/stvqa/stvqa.py +134 -0
  17. Tipsomaly/model/big_vision/datasets/tallyqa/tallyqa.py +146 -0
  18. Tipsomaly/model/big_vision/datasets/textcaps/textcaps.py +152 -0
  19. Tipsomaly/model/big_vision/datasets/textvqa/textvqa.py +186 -0
  20. Tipsomaly/model/big_vision/datasets/vizwizvqa/vizwizvqa.py +128 -0
  21. Tipsomaly/model/big_vision/datasets/vqa/vqa.py +147 -0
  22. Tipsomaly/model/big_vision/datasets/widgetcap/widgetcap.py +151 -0
  23. Tipsomaly/model/big_vision/datasets/xgqa/xgqa.py +145 -0
  24. Tipsomaly/model/big_vision/datasets/xm3600/xm3600.py +136 -0
  25. Tipsomaly/model/big_vision/evaluators/proj/cappa/perplexity.py +50 -0
  26. Tipsomaly/model/big_vision/evaluators/proj/cappa/scoring_classifier.py +63 -0
  27. Tipsomaly/model/big_vision/evaluators/proj/distill/distance.py +151 -0
  28. Tipsomaly/model/big_vision/evaluators/proj/givt/coco_panoptic.py +401 -0
  29. Tipsomaly/model/big_vision/evaluators/proj/givt/nyu_depth.py +191 -0
  30. Tipsomaly/model/big_vision/evaluators/proj/givt/save_predictions.py +118 -0
  31. Tipsomaly/model/big_vision/evaluators/proj/image_text/contrastive.py +99 -0
  32. Tipsomaly/model/big_vision/evaluators/proj/image_text/discriminative_classifier.py +440 -0
  33. Tipsomaly/model/big_vision/evaluators/proj/image_text/discriminative_classifier_test.py +237 -0
  34. Tipsomaly/model/big_vision/evaluators/proj/image_text/image_text_retrieval.py +85 -0
  35. Tipsomaly/model/big_vision/evaluators/proj/image_text/image_text_retrieval_test.py +86 -0
  36. Tipsomaly/model/big_vision/evaluators/proj/image_text/prompt_engineering.py +112 -0
  37. Tipsomaly/model/big_vision/evaluators/proj/image_text/prompt_engineering_constants.py +110 -0
  38. Tipsomaly/model/big_vision/evaluators/proj/image_text/prompt_engineering_test.py +48 -0
  39. Tipsomaly/model/big_vision/evaluators/proj/image_text/retrieval.py +306 -0
  40. Tipsomaly/model/big_vision/evaluators/proj/image_text/retrieval_test.py +178 -0
  41. Tipsomaly/model/big_vision/evaluators/proj/paligemma/perplexity.py +63 -0
  42. Tipsomaly/model/big_vision/evaluators/proj/paligemma/transfers/chartqa.py +139 -0
  43. Tipsomaly/model/big_vision/evaluators/proj/paligemma/transfers/pope.py +135 -0
  44. Tipsomaly/model/big_vision/evaluators/proj/uvim/coco_panoptic.py +324 -0
  45. Tipsomaly/model/big_vision/evaluators/proj/uvim/coltran_fid.py +242 -0
  46. Tipsomaly/model/big_vision/evaluators/proj/uvim/coltran_fid_data/eval_file_names.txt +0 -0
  47. Tipsomaly/model/big_vision/evaluators/proj/uvim/coltran_fid_data/reference_file_names.txt +0 -0
  48. Tipsomaly/model/big_vision/evaluators/proj/uvim/common.py +64 -0
  49. Tipsomaly/model/big_vision/evaluators/proj/uvim/compute_mean.py +79 -0
  50. Tipsomaly/model/big_vision/evaluators/proj/uvim/nyu_depth.py +154 -0
.gitattributes CHANGED
@@ -36,3 +36,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
36
  Tipsomaly/imgs/Models_Architecture_page-0001.jpg filter=lfs diff=lfs merge=lfs -text
37
  Tipsomaly/imgs/results-table.png filter=lfs diff=lfs merge=lfs -text
38
  Tipsomaly/imgs/Qualitative_results_page-0001.jpg filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
36
  Tipsomaly/imgs/Models_Architecture_page-0001.jpg filter=lfs diff=lfs merge=lfs -text
37
  Tipsomaly/imgs/results-table.png filter=lfs diff=lfs merge=lfs -text
38
  Tipsomaly/imgs/Qualitative_results_page-0001.jpg filter=lfs diff=lfs merge=lfs -text
39
+ Tipsomaly/model/big_vision/datasets/imagenet/__pycache__/class_names.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
40
+ Tipsomaly/model/big_vision/datasets/imagenet/__pycache__/class_names.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
41
+ Tipsomaly/model/big_vision/datasets/imagenet/__pycache__/class_names.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
42
+ Tipsomaly/model/big_vision/configs/proj/givt/givt_overview.png filter=lfs diff=lfs merge=lfs -text
43
+ Tipsomaly/model/big_vision/configs/proj/jetformer/jetformer_overview.png filter=lfs diff=lfs merge=lfs -text
44
+ Tipsomaly/model/big_vision/configs/proj/paligemma/paligemma.png filter=lfs diff=lfs merge=lfs -text
Tipsomaly/model/big_vision/configs/proj/givt/givt_overview.png ADDED

Git LFS Details

  • SHA256: 5c9ab0e72ea3dab6d2997cfa20074bae3dad71d1c7a9d0fff15fb4d453fa8a85
  • Pointer size: 131 Bytes
  • Size of remote file: 803 kB
Tipsomaly/model/big_vision/configs/proj/jetformer/jetformer_overview.png ADDED

Git LFS Details

  • SHA256: 5f24740d8fb552cda73e0e06874aa888fc60a4b47b76e3870a85eccc2c716699
  • Pointer size: 131 Bytes
  • Size of remote file: 553 kB
Tipsomaly/model/big_vision/configs/proj/paligemma/paligemma.png ADDED

Git LFS Details

  • SHA256: e3258bcf799cf0bbf7f5181e53c85339630b60476f8af9a95ef6432d0ad56f54
  • Pointer size: 131 Bytes
  • Size of remote file: 367 kB
Tipsomaly/model/big_vision/datasets/countbenchqa/data/countbench_paired_questions.json ADDED
@@ -0,0 +1 @@
 
 
1
+ [{"question": "How many headsets are there in the image?"}, {"question": "How many light bulbs are there in the image?"}, {"question": "How many prints are there in the image?"}, {"question": "How many arrows are there in the image?"}, {"question": "How many spoons are there in the image?"}, {"question": "How many girls are there in the image?"}, {"question": "How many parrots are there in the image?"}, {"question": "How many coloring pages are there in the image?"}, {"question": "How many food containers are there in the image?"}, {"question": "How many birdhouse patterns are there in the image?"}, {"question": "How many sofas are there in the image?"}, {"question": "How many waterlilies are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many eggs are there in the image?"}, {"question": "How many golfers are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many outfits are there in the image?"}, {"question": "How many pigs are there in the image?"}, {"question": "How many cars are there in the image?"}, {"question": "How many aum symbols are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many buttons are there in the image?"}, {"question": "How many rackets are there in the image?"}, {"question": "How many pots are there in the image?"}, {"question": "How many stars are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many silhouettes are there in the image?"}, {"question": "How many kids are there in the image?"}, {"question": "How many moth silhouettes are there in the image?"}, {"question": "How many pumpkin candles are there in the image?"}, {"question": "How many essential oils are there in the image?"}, {"question": "How many stencils are there in the 
image?"}, {"question": "How many text boxes are there in the image?"}, {"question": "How many basketball players are there in the image?"}, {"question": "How many photos are there in the image?"}, {"question": "How many forks are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many bags are there in the image?"}, {"question": "How many couples are there in the image?"}, {"question": "How many weights are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many fish are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many clocks are there in the image?"}, {"question": "How many banners are there in the image?"}, {"question": "How many chests are there in the image?"}, {"question": "How many stars are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many globe icons are there in the image?"}, {"question": "How many posters are there in the image?"}, {"question": "How many socks are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many snails are there in the image?"}, {"question": "How many crochet potholders are there in the image?"}, {"question": "How many christmas cards are there in the image?"}, {"question": "How many double beds are there in the image?"}, {"question": "How many baseball players are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many symbols are there in the image?"}, {"question": "How many baseball players are there in the image?"}, {"question": "How many cars are there in the image?"}, {"question": "How many trees are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many tree trunk cuts are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many chairs are there in the 
image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many banners are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many individual earrings are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many tomatoes are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many prints are there in the image?"}, {"question": "How many ice creams are there in the image?"}, {"question": "How many plates are there in the image?"}, {"question": "How many sumo wrestlers are there in the image?"}, {"question": "How many compositions are there in the image?"}, {"question": "How many PVC vinyls are there in the image?"}, {"question": "How many photos of fruit are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many mirrors are there in the image?"}, {"question": "How many groomsmen are there in the image?"}, {"question": "How many posters are there in the image?"}, {"question": "How many pumpkins are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many kittens are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many boots are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many figurines are there in the image?"}, {"question": "How many moais are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many cats are there in the image?"}, {"question": "How many wallpaper variants are there in the image?"}, {"question": "How many nail polishes are there in the image?"}, {"question": "How many bumble bees are there in the image?"}, {"question": "How many tickets are 
there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many wine bottles are there in the image?"}, {"question": "How many silhouettes of couples are there in the image?"}, {"question": "How many owls are there in the image?"}, {"question": "How many candles are there in the image?"}, {"question": "How many flower pots are there in the image?"}, {"question": "How many coins are there in the image?"}, {"question": "How many placemats are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many cars are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many candles are there in the image?"}, {"question": "How many pairs of socks are there in the image?"}, {"question": "How many blinds are there in the image?"}, {"question": "How many floral patterns are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many bicycles are there in the image?"}, {"question": "How many dwarfs are there in the image?"}, {"question": "How many stickers are there in the image?"}, {"question": "How many symbols are there in the image?"}, {"question": "How many stamps are there in the image?"}, {"question": "How many pumpkins are there in the image?"}, {"question": "How many game cartridges are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many glasses are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many coins are there in the image?"}, {"question": "How many gears are there in the image?"}, {"question": "How many flowers are there in the image?"}, {"question": "How many crip packages are there in the image?"}, {"question": "How many bridesmaids are there in the image?"}, {"question": "How many 
apples are there in the image?"}, {"question": "How many bowls are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many photographs are there in the image?"}, {"question": "How many warning signs are there in the image?"}, {"question": "How many children are there in the image?"}, {"question": "How many cyclists are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many giraffes are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many male nurses are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many couples are there in the image?"}, {"question": "How many cars are there in the image?"}, {"question": "How many apples are there in the image?"}, {"question": "How many smartphones are there in the image?"}, {"question": "How many roses are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many greeting cards are there in the image?"}, {"question": "How many guitars are there in the image?"}, {"question": "How many ironman suits are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many CDs are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many symbols are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many pot holders are there in the image?"}, {"question": "How many stamps are there in the image?"}, {"question": "How many bookmarks are there in the image?"}, {"question": "How many portraits are there in the image?"}, {"question": "How many girls are there in the image?"}, {"question": "How many labels are there in the image?"}, {"question": "How many mandalas are there in the image?"}, {"question": "How many silhouettes are 
there in the image?"}, {"question": "How many peacocks are there in the image?"}, {"question": "How many roses are there in the image?"}, {"question": "How many cars are there in the image?"}, {"question": "How many children are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many symbols are there in the image?"}, {"question": "How many canvases are there in the image?"}, {"question": "How many cards are there in the image?"}, {"question": "How many bell pepper halves are there in the image?"}, {"question": "How many pigs are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many newspapers are there in the image?"}, {"question": "How many leggings are there in the image?"}, {"question": "How many medals are there in the image?"}, {"question": "How many patterns are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many doors are there in the image?"}, {"question": "How many pairs of socks are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many backgrounds are there in the image?"}, {"question": "How many images of dogs are there in the image?"}, {"question": "How many broadheads are there in the image?"}, {"question": "How many figurines are there in the image?"}, {"question": "How many candles are there in the image?"}, {"question": "How many petals does each flower have in this image?"}, {"question": "How many crayons are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many caps are there in the image?"}, {"question": "How many children are there in the image?"}, {"question": "How many frames are there in the image?"}, {"question": "How many animals are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many cats are there in the image?"}, {"question": "How many sconces are 
there in the image?"}, {"question": "How many spoons are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many planes are there in the image?"}, {"question": "How many cats are there in the image?"}, {"question": "How many sketches are there in the image?"}, {"question": "How many trees are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many figurines are there in the image?"}, {"question": "How many napkins are there in the image?"}, {"question": "How many dogs are there in the image?"}, {"question": "How many kids are there in the image?"}, {"question": "How many hearts are there in the image?"}, {"question": "How many apples are there in the image?"}, {"question": "How many post-its are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many books are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many sofas are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many girls are there in the image?"}, {"question": "How many sketches are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many moons are there in the image?"}, {"question": "How many labels are there in the image?"}, {"question": "How many cylinders are there in the image?"}, {"question": "How many silhouettes of couples are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many dogs are there in the image?"}, {"question": "How many people on stage are there in the image?"}, {"question": "How many lambs are there in the image?"}, {"question": "How many violins are there in the image?"}, {"question": "How many 
armchairs are there in the image?"}, {"question": "How many silhouettes are there in the image?"}, {"question": "How many eggs are there in the image?"}, {"question": "How many photos are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many figurines are there in the image?"}, {"question": "How many moais are there in the image?"}, {"question": "How many symbols are there in the image?"}, {"question": "How many glasses are there in the image?"}, {"question": "How many sinks are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many buildings are there in the image?"}, {"question": "How many flyers are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many pizza slices are there in the image?"}, {"question": "How many stamps are there in the image?"}, {"question": "How many stickers are there in the image?"}, {"question": "How many gifts are there in the image?"}, {"question": "How many bowls are there in the image?"}, {"question": "How many onesies are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many posters are there in the image?"}, {"question": "How many cards are there in the image?"}, {"question": "How many portraits are there in the image?"}, {"question": "How many poinsettias are there in the image?"}, {"question": "How many chicken thighs are there in the image?"}, {"question": "How many glass windows are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many baskets are there in the image?"}, {"question": "How many tulips are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many stuffed animals are there in the image?"}, {"question": "How many keychains are there in the image?"}, {"question": "How many photos are there in the image?"}, {"question": "How many people 
are there in the image?"}, {"question": "How many adults are there in the image?"}, {"question": "How many tulips are there in the image?"}, {"question": "How many frames are there in the image?"}, {"question": "How many samosas are there in the image?"}, {"question": "How many strawberries are there in the image?"}, {"question": "How many cocktails are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many wineglasses are there in the image?"}, {"question": "How many goblets are there in the image?"}, {"question": "How many prints are there in the image?"}, {"question": "How many flowers are there in the image?"}, {"question": "How many zebras are there in the image?"}, {"question": "How many paint brushes are there in the image?"}, {"question": "How many prints are there in the image?"}, {"question": "How many figurines are there in the image?"}, {"question": "How many glasses are there in the image?"}, {"question": "How many sunglasses are there in the image?"}, {"question": "How many banners are there in the image?"}, {"question": "How many bottle caps are there in the image?"}, {"question": "How many prints are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many eggs are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many books are there in the image?"}, {"question": "How many animals are there in the image?"}, {"question": "How many eyeshadows are there in the image?"}, {"question": "How many keychains are there in the image?"}, {"question": "How many pairs of earrings are there in the image?"}, {"question": "How many canisters are there in the image?"}, {"question": "How many bags are there in the image?"}, {"question": "How many baking trays are there in the image?"}, {"question": "How many diamonds are there in the image?"}, {"question": "How many portraits are there in the image?"}, {"question": "How many 
framed images are there in the image?"}, {"question": "How many flags are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many framed pictures are there in the image?"}, {"question": "How many eggs are there in the image?"}, {"question": "How many croissants are there in the image?"}, {"question": "How many Manikins are there in the image?"}, {"question": "How many symbols are there in the image?"}, {"question": "How many labels are there in the image?"}, {"question": "How many people in the foreground are there in the image?"}, {"question": "How many armchairs are there in the image?"}, {"question": "How many cups are there in the image?"}, {"question": "How many helicopters are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many children are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many buttons are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many Cookies are there in the image?"}, {"question": "How many sculptures are there in the image?"}, {"question": "How many school uniforms are there in the image?"}, {"question": "How many sculptures are there in the image?"}, {"question": "How many eggs are there in the image?"}, {"question": "How many packages are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many stories does this cottage have?"}, {"question": "How many gift cards are there in the image?"}, {"question": "How many eggs are there in the image?"}, {"question": "How many beers are there in the image?"}, {"question": "How many children are there in the image?"}, {"question": "How many planes are there in the image?"}, {"question": "How many cactus pots are there in the image?"}, {"question": "How many smartphones are there in the image?"}, {"question": "How many picture frames are there in the 
image?"}, {"question": "How many elephants are there in the image?"}, {"question": "How many guitars are there in the image?"}, {"question": "How many samurai are there in the image?"}, {"question": "How many ghosts are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many candles are there in the image?"}, {"question": "How many vases are there in the image?"}, {"question": "How many sets of headphones are there in the image?"}, {"question": "How many pandas are there in the image?"}, {"question": "How many books are there in the image?"}, {"question": "How many stickers are there in the image?"}, {"question": "How many rings are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many dragon balls are there in the image?"}, {"question": "How many tanks are there in the image?"}, {"question": "How many students are there in the image?"}, {"question": "How many cups are there in the image?"}, {"question": "How many cubs are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many cushions are there in the image?"}, {"question": "How many shoes are there in the image?"}, {"question": "How many beers are there in the image?"}, {"question": "How many wine glasses are there in the image?"}, {"question": "How many cards are there in the image?"}, {"question": "How many boots are there in the image?"}, {"question": "How many stickers are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many butterflies are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many butterflies are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many hexagons are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How 
many football players are there in the image?"}, {"question": "How many gifts are there in the image?"}, {"question": "How many light bulbs are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many candles are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many silhouettes are there in the image?"}, {"question": "How many pots are there in the image?"}, {"question": "How many pumpkins are there in the image?"}, {"question": "How many owls are there in the image?"}, {"question": "How many doctors are there in the image?"}, {"question": "How many pigs are there in the image?"}, {"question": "How many pillars are there in the image?"}, {"question": "How many symbols are there in the image?"}, {"question": "How many pumpkins are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many eggs are there in the image?"}, {"question": "How many roses are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many illustrations are there in the image?"}, {"question": "How many photos are there in the image?"}, {"question": "How many butterflies are there in the image?"}, {"question": "How many spoons are there in the image?"}, {"question": "How many potato spreads are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many wall arts are there in the image?"}, {"question": "How many covers are there in the image?"}, {"question": "How many people are there 
in the image?"}, {"question": "How many soldiers are there in the image?"}, {"question": "How many posters are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many patterns are there in the image?"}, {"question": "How many dogs are there in the image?"}, {"question": "How many pennies are there in the image?"}, {"question": "How many windows are there in the image?"}, {"question": "How many eggs are there in the image?"}, {"question": "How many geese are there in the image?"}, {"question": "How many tickets are there in the image?"}, {"question": "How many figurines are there in the image?"}, {"question": "How many cases are there in the image?"}, {"question": "How many people are in the foreground of this image?"}, {"question": "How many glasses are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many cups are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many kittens are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many scones are there in the image?"}, {"question": "How many schoolgirls are there in the image?"}, {"question": "How many padlocks are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many windows are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many banners are there in the image?"}, {"question": "How many dogs are there in the image?"}, {"question": "How many groomsmen are there in the image?"}, {"question": "How many children are there in the image?"}, {"question": "How many glasses are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many coins are there in the image?"}, {"question": "How many trees are there in the image?"}, {"question": "How 
many turtles are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many frames are there in the image?"}, {"question": "How many firescreens are there in the image?"}, {"question": "How many bowls are there in the image?"}, {"question": "How many stickers are there in the image?"}, {"question": "How many archaic mirrors are there in the image?"}, {"question": "How many dogs are there in the image?"}, {"question": "How many candles are there in the image?"}, {"question": "How many paint brushes are there in the image?"}, {"question": "How many forks are there in the image?"}, {"question": "How many figurines are there in the image?"}, {"question": "How many owls are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many silhouettes are there in the image?"}, {"question": "How many fried eggs are there in the image?"}, {"question": "How many guitars are there in the image?"}, {"question": "How many signs are there in the image?"}, {"question": "How many watches are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many children are there in the image?"}, {"question": "How many kittens are there in the image?"}, {"question": "How many kids are there in the image?"}, {"question": "How many kittens are there in the image?"}, {"question": "How many toys are there in the image?"}, {"question": "How many pairs of socks are there in the image?"}, {"question": "How many pairs of socks are there in the image?"}, {"question": "How many men are there in the image?"}, {"question": "How many stars are there in the image?"}, {"question": "How many quarters are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many diamonds are there in the image?"}, {"question": "How many moais are there in the image?"}, {"question": "How many banners are 
there in the image?"}, {"question": "How many silhouettes are there in the image?"}, {"question": "How many coins are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many banners are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many flags are there in the image?"}, {"question": "How many frames are there in the image?"}, {"question": "How many contestants are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many globes are there in the image?"}, {"question": "How many animals are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many buttons are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many piles of candy are there in the image?"}, {"question": "How many pictures are there in the image?"}, {"question": "How many plates are there in the image?"}, {"question": "How many calendars are there in the image?"}, {"question": "How many oranges are there in the image?"}, {"question": "How many puppies are there in the image?"}, {"question": "How many buffalos are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many christmas balls are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many armchairs are there in the image?"}, {"question": "How many frames are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many soldiers are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many cards are there in the image?"}, {"question": "How many colored tiles are there in the 
image?"}, {"question": "How many trapezoids are there in the image?"}, {"question": "How many pastries are there in the image?"}, {"question": "How many plants are there in the image?"}, {"question": "How many place mats are there in the image?"}, {"question": "How many symbols are there in the image?"}, {"question": "How many mazes are there in the image?"}, {"question": "How many tea bags are there in the image?"}, {"question": "How many photos are there in the image?"}, {"question": "How many children are there in the image?"}, {"question": "How many starfish are there in the image?"}, {"question": "How many mugs are there in the image?"}, {"question": "How many bottles are there in the image?"}, {"question": "How many eggs are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many balls are there in the image?"}, {"question": "How many paper bags are there in the image?"}, {"question": "How many garage doors are there in the image?"}, {"question": "How many silhouettes are there in the image?"}, {"question": "How many couples are there in the image?"}, {"question": "How many dice are there in the image?"}, {"question": "How many women are there in the image?"}, {"question": "How many icons are there in the image?"}, {"question": "How many wooden spoons are there in the image?"}, {"question": "How many people are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many chairs are there in the image?"}, {"question": "How many beds are there in the image?"}, {"question": "How many chairs are there in the image?"}]
Tipsomaly/model/big_vision/datasets/imagenet/__pycache__/class_names.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60133d938ba72366456dec978d305ff978b143577b047fe46aee4b50b104eef6
3
+ size 272122
Tipsomaly/model/big_vision/datasets/imagenet/__pycache__/class_names.cpython-312.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a565009d6d842a49ccfc7f31d9945d1916fa9669913290d3483a4b9d360acbd7
3
+ size 272065
Tipsomaly/model/big_vision/datasets/imagenet/__pycache__/class_names.cpython-39.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fba8a44d65d0f986bf5562ae0290a3cec1db9af239f9c0f91876099c173b760b
3
+ size 272046
Tipsomaly/model/big_vision/datasets/nocaps/nocaps.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # pylint: disable=line-too-long
16
+ r"""Implements nocaps val/test set in TFDS structure.
17
+
18
+ It's small data, so simple to run locally. First, copy the data to local disk:
19
+
20
+ mkdir -p /tmp/data/nocaps_data
21
+ cd /tmp/data/nocaps_data
22
+ wget https://s3.amazonaws.com/open-images-dataset/tar/test.tar.gz
23
+ wget https://s3.amazonaws.com/open-images-dataset/tar/validation.tar.gz
24
+ curl -O https://nocaps.s3.amazonaws.com/nocaps_val_4500_captions.json
25
+ curl -O https://s3.amazonaws.com/nocaps/nocaps_test_image_info.json
26
+
27
+ mkdir -p /tmp/data/nocaps_data/Images
28
+ tar -xf validation.tar.gz -C Images
29
+ rm validation.tar.gz
30
+ tar -xf test.tar.gz -C Images
31
+ rm test.tar.gz
32
+
33
+ Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util):
34
+
35
+ cd big_vision/datasets
36
+ env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=nocaps
37
+
38
+ Example to load:
39
+
40
+ import tensorflow_datasets as tfds
41
+ dataset = tfds.load('nocaps', split='val', data_dir='/tmp/tfds')
42
+ """
43
+ import collections
44
+ import json
45
+ import os
46
+
47
+ from absl import logging
48
+ import numpy as np
49
+ import tensorflow as tf
50
+ import tensorflow_datasets as tfds
51
+
52
+
53
+ _DESCRIPTION = """Nocaps dataset."""
54
+
55
+ _CITATION = (
56
+ '@inproceedings{agrawal2019nocaps,'
57
+ 'title={nocaps: novel object captioning at scale},'
58
+ 'author={Agrawal, Harsh and Desai, Karan and Wang, Yufei and Chen, Xinlei'
59
+ 'and Jain, Rishabh and Johnson, Mark and Batra, Dhruv and Parikh, Devi'
60
+ 'and Lee, Stefan and Anderson, Peter},'
61
+ 'booktitle={ICCV},'
62
+ 'pages={8948--8957},'
63
+ 'year={2019}}')
64
+
65
+ # When running locally (recommended), copy files as above an use these:
66
+ _FILEPATH = '/tmp/data/nocaps_data/Images/'
67
+ _VAL_FILES = '/tmp/data/nocaps_data/nocaps_val_4500_captions.json'
68
+ _TEST_FILES = '/tmp/data/nocaps_data/nocaps_test_image_info.json'
69
+
70
+
71
+ class NoCaps(tfds.core.GeneratorBasedBuilder):
72
+ """DatasetBuilder for nocaps dataset."""
73
+
74
+ VERSION = tfds.core.Version('1.0.0')
75
+ RELEASE_NOTES = {
76
+ '1.0.0': 'Initial release.',
77
+ }
78
+
79
+ def _info(self) -> tfds.core.DatasetInfo:
80
+ """Returns the dataset metadata.
81
+
82
+ (tfds.core.DatasetInfo object)
83
+ These are the features of your dataset like images, labels, etc.
84
+ """
85
+
86
+ return tfds.core.DatasetInfo(
87
+ builder=self,
88
+ description=_DESCRIPTION,
89
+ features=tfds.features.FeaturesDict({
90
+ 'image/id': tf.int64,
91
+ 'image_filepath': tfds.features.Text(),
92
+ 'url': tfds.features.Text(),
93
+ 'image': tfds.features.Image(encoding_format='jpeg'),
94
+ 'texts': tfds.features.Sequence(tfds.features.Text()),
95
+ }),
96
+ # If there's a common (input, target) tuple from the
97
+ # features, specify them here. They'll be used if
98
+ # `as_supervised=True` in `builder.as_dataset`.
99
+ supervised_keys=None, # Set to `None` to disable
100
+ homepage='https://nocaps.org/',
101
+ citation=_CITATION,
102
+ )
103
+
104
+ def _split_generators(self, dl_manager: tfds.download.DownloadManager):
105
+ """Returns SplitGenerators."""
106
+ def group_by_id(data, image_dir):
107
+ id2caps = collections.defaultdict(list)
108
+ for ex in data.get('annotations', []):
109
+ id2caps[ex['image_id']].append(ex['caption'])
110
+
111
+ id_to_example = {}
112
+ for ex in data['images']:
113
+ id_to_example[ex['id']] = {
114
+ 'image/id': ex['id'],
115
+ 'image_filepath': os.path.join(
116
+ _FILEPATH, image_dir, ex['file_name']),
117
+ 'url': ex['coco_url'],
118
+ 'image': os.path.join(_FILEPATH, image_dir, ex['file_name']),
119
+ 'texts': id2caps[ex['id']] if ex['id'] in id2caps else ['N/A'],
120
+ }
121
+ return id_to_example
122
+
123
+ # Returns the Dict[split names, Iterator[Key, Example]]
124
+ with open(_VAL_FILES) as f:
125
+ val_data = group_by_id(json.load(f), 'validation')
126
+ with open(_TEST_FILES) as f:
127
+ test_data = group_by_id(json.load(f), 'test')
128
+ return {
129
+ 'val': self._generate_examples(val_data),
130
+ 'test': self._generate_examples(test_data),
131
+ }
132
+
133
+ def _generate_examples(self, data):
134
+ """Generate a tf.Example object.
135
+
136
+ This contains the image, objects, attributes, regions and relationships.
137
+
138
+ Args:
139
+ data: a dictionary with the image/id.
140
+
141
+ Yields:
142
+ (key, example) tuples from dataset. The example has format specified in
143
+ the above DatasetInfo.
144
+ """
145
+ for k, v in data.items():
146
+ try:
147
+ # Jpeg decode test to check early errors. The decoded images are not
148
+ # used, instead we rely on the default tfds.features.Image function.
149
+ unused_image = tf.io.read_file(v['image_filepath'])
150
+ unused_image = np.array(tf.image.decode_jpeg(unused_image))
151
+ except tf.errors.InvalidArgumentError:
152
+ # Unable to read image, skip this image and output download link.
153
+ logging.error('Unable to decode: curl -O %s', v['url'])
154
+ continue
155
+ except tf.errors.NotFoundError:
156
+ # Unable to read image, skip this image and output download link.
157
+ logging.error('File not found: curl -O %s', v['url'])
158
+ continue
159
+
160
+ yield k, v
Tipsomaly/model/big_vision/datasets/refcoco/refcoco.py ADDED
@@ -0,0 +1,448 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # pylint: disable=line-too-long
16
+ r"""Unbatch RefCOCO, RefCOCO+, RefCOCOg datasets in TFDS structure."""
17
+
18
+ # Based on tensorflow_datasets/datasets/ref_coco
19
+
20
+ import io
21
+ import os
22
+ import pickle
23
+
24
+ import numpy as np
25
+ import PIL.Image
26
+ import pycocotools.coco
27
+ import tensorflow_datasets as tfds
28
+
29
+ _ROOT_PATH = '/tmp/data/'
30
+
31
+
32
+ class RefCocoConfig(tfds.core.BuilderConfig):
33
+ """Config to specify each RefCoco variant."""
34
+
35
+ def __init__(self, dataset, dataset_partition, **kwargs):
36
+ name = f'{dataset}_{dataset_partition}'
37
+ super(RefCocoConfig, self).__init__(name=name, **kwargs)
38
+ self.dataset = dataset
39
+ self.dataset_partition = dataset_partition
40
+
41
+
42
+ _DESCRIPTION = """RefCOCO, RefCOCO+, RefCOCOg datasets.
43
+
44
+ Images, boxes and segmentations are from the original COCO dataset
45
+ (Lin et al, ECCV 2014). The referential segmentations are from two different
46
+ sources:
47
+
48
+ 1) RefCOCOg (Mao et al, CVPR 2016):
49
+ - https://github.com/mjhucla/Google_Refexp_toolbox
50
+ - This is the split used in the "refcocog_google" dataset. Note that this
51
+ split has overlapping images in train/validation. The same split is also
52
+ provided in 2).
53
+
54
+ 2) Source of RefCOCO and RefCOCO+ (Yu et al, ECCV 2016):
55
+ - https://github.com/lichengunc/refer
56
+ - Apache License 2.0
57
+ - Provides all the splits used for generation of these datasets, including the
58
+ "refcocog_google" split that is identical with the split from 1).
59
+
60
+ For convenience, we provide an additional dataset "refcocox_combined" that
61
+ combines the datasets "refcoco_unc", "refcocoplus_unc", and "refcocog_umd",
62
+ unifying "testA" and "testB" into a single "test" split, and removing any images
63
+ from "train" that appear either in "validation" or "test".
64
+
65
+ Also for convenience, every split is unrolled twice (at the "objects" level and
66
+ at the "object/refs" level) and saved as "{split}_flat".
67
+ """
68
+
69
+ # pylint: disable=line-too-long
70
+ _CITATION = r"""
71
+ @inproceedings{DBLP:conf/cvpr/MaoHTCY016,
72
+ author = {Junhua Mao and
73
+ Jonathan Huang and
74
+ Alexander Toshev and
75
+ Oana Camburu and
76
+ Alan L. Yuille and
77
+ Kevin Murphy},
78
+ title = {Generation and Comprehension of Unambiguous Object Descriptions},
79
+ booktitle = {2016 {IEEE} Conference on Computer Vision and Pattern Recognition,
80
+ {CVPR} 2016, Las Vegas, NV, USA, June 27-30, 2016},
81
+ pages = {11--20},
82
+ publisher = {{IEEE} Computer Society},
83
+ year = {2016},
84
+ url = {https://doi.org/10.1109/CVPR.2016.9},
85
+ doi = {10.1109/CVPR.2016.9},
86
+ timestamp = {Fri, 24 Mar 2023 00:02:52 +0100},
87
+ biburl = {https://dblp.org/rec/conf/cvpr/MaoHTCY016.bib},
88
+ bibsource = {dblp computer science bibliography, https://dblp.org}
89
+ }
90
+
91
+ @inproceedings{DBLP:conf/eccv/YuPYBB16,
92
+ author = {Licheng Yu and
93
+ Patrick Poirson and
94
+ Shan Yang and
95
+ Alexander C. Berg and
96
+ Tamara L. Berg},
97
+ editor = {Bastian Leibe and
98
+ Jiri Matas and
99
+ Nicu Sebe and
100
+ Max Welling},
101
+ title = {Modeling Context in Referring Expressions},
102
+ booktitle = {Computer Vision - {ECCV} 2016 - 14th European Conference, Amsterdam,
103
+ The Netherlands, October 11-14, 2016, Proceedings, Part {II}},
104
+ series = {Lecture Notes in Computer Science},
105
+ volume = {9906},
106
+ pages = {69--85},
107
+ publisher = {Springer},
108
+ year = {2016},
109
+ url = {https://doi.org/10.1007/978-3-319-46475-6\_5},
110
+ doi = {10.1007/978-3-319-46475-6\_5},
111
+ timestamp = {Wed, 07 Dec 2022 23:10:23 +0100},
112
+ biburl = {https://dblp.org/rec/conf/eccv/YuPYBB16.bib},
113
+ bibsource = {dblp computer science bibliography, https://dblp.org}
114
+ }
115
+
116
+ @article{DBLP:journals/corr/LinMBHPRDZ14,
117
+ author = {Tsung{-}Yi Lin and
118
+ Michael Maire and
119
+ Serge J. Belongie and
120
+ Lubomir D. Bourdev and
121
+ Ross B. Girshick and
122
+ James Hays and
123
+ Pietro Perona and
124
+ Deva Ramanan and
125
+ Piotr Doll{\'{a}}r and
126
+ C. Lawrence Zitnick},
127
+ title = {Microsoft {COCO:} Common Objects in Context},
128
+ journal = {CoRR},
129
+ volume = {abs/1405.0312},
130
+ year = {2014},
131
+ url = {http://arxiv.org/abs/1405.0312},
132
+ archivePrefix = {arXiv},
133
+ eprint = {1405.0312},
134
+ timestamp = {Mon, 13 Aug 2018 16:48:13 +0200},
135
+ biburl = {https://dblp.org/rec/bib/journals/corr/LinMBHPRDZ14},
136
+ bibsource = {dblp computer science bibliography, https://dblp.org}
137
+ }
138
+ """
139
+
140
+ # coco_data = json.load(open('annotations/instances_train2017.json'))
141
+ # [l['name'] for l in coco_data['licenses']]
142
+ LICENSES = [
143
+ 'Attribution-NonCommercial-ShareAlike License',
144
+ 'Attribution-NonCommercial License',
145
+ 'Attribution-NonCommercial-NoDerivs License',
146
+ 'Attribution License',
147
+ 'Attribution-ShareAlike License',
148
+ 'Attribution-NoDerivs License',
149
+ 'No known copyright restrictions',
150
+ 'United States Government Work',
151
+ ]
152
+ # _licenses_map = {l['id']: i for i, l in enumerate(coco_data['licenses'])}
153
+ _licenses_map = {1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 8: 7}
154
+
155
+ # pyformat: disable
156
+ # [c['name'] for c in coco_data['categories']]
157
+ CATEGORIES = [
158
+ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
159
+ 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
160
+ 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
161
+ 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
162
+ 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
163
+ 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
164
+ 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
165
+ 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
166
+ 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
167
+ 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
168
+ 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
169
+ 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
170
+ 'hair drier', 'toothbrush',
171
+ ]
172
+ # sorted(set(c['supercategory'] for c in coco_data['categories']))
173
+ SUPERCATEGORIES = [
174
+ 'accessory', 'animal', 'appliance', 'electronic', 'food', 'furniture',
175
+ 'indoor', 'kitchen', 'outdoor', 'person', 'sports', 'vehicle',
176
+ ]
177
+ # pyformat: enable
178
+
179
+
180
+ # Will be exported into directory `$TFDS_DATA_DIR/ref_coco_bv`
181
+ # If the class name was `RefCOCO` then it would be exported into
182
+ # `$TFDS_DATA_DIR/ref_coco`, which would collide with the default TFDS dataset
183
+ # also named `ref_coco` (which has precedence over `data_dir` builder arg).
184
+ class RefCocoBv(tfds.core.GeneratorBasedBuilder):
185
+ """DatasetBuilder for RefCoco datasets."""
186
+
187
+ VERSION = tfds.core.Version('1.4.0')
188
+ RELEASE_NOTES = {
189
+ '1.4.0': 'Added flat versions of all dataset splits.',
190
+ '1.3.0': 'Added "refcocox_combined" dataset.',
191
+ '1.2.0': 'Added "train_flat" splits.',
192
+ '1.1.0': 'Added more features (mask etc), nested "refs" in "objects".',
193
+ '1.0.0': 'Initial release.',
194
+ }
195
+
196
+ MANUAL_DOWNLOAD_INSTRUCTIONS = """
197
+ 1. Install https://pypi.org/project/pycocotools/.
198
+
199
+ 2. Download data (requires ~20G for COCO images):
200
+
201
+ (mkdir -p /tmp/tfds/downloads/manual &&
202
+ cd /tmp/tfds/downloads/manual &&
203
+ wget http://images.cocodataset.org/zips/train2017.zip &&
204
+ wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip &&
205
+ wget https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco.zip &&
206
+ wget https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco+.zip &&
207
+ wget https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcocog.zip &&
208
+ for zip in *.zip; do unzip $zip; done
209
+ )
210
+
211
+ 3. Run the generation script with `TFDS_DATA_DIR=/tmp/tfds`
212
+ """
213
+
214
+ BUILDER_CONFIGS = [
215
+ RefCocoConfig(dataset='refcoco', dataset_partition='unc'),
216
+ RefCocoConfig(dataset='refcoco', dataset_partition='google'),
217
+ RefCocoConfig(dataset='refcocoplus', dataset_partition='unc'),
218
+ RefCocoConfig(dataset='refcocog', dataset_partition='google'),
219
+ RefCocoConfig(dataset='refcocog', dataset_partition='umd'),
220
+ RefCocoConfig(dataset='refcocox', dataset_partition='combined'),
221
+ ]
222
+
223
+ def _info(self) -> tfds.core.DatasetInfo:
224
+ return tfds.core.DatasetInfo(
225
+ builder=self,
226
+ features=tfds.features.FeaturesDict({
227
+ 'id': tfds.features.Scalar(np.int32),
228
+ 'image': tfds.features.Image(encoding_format='jpeg'),
229
+ 'height': tfds.features.Scalar(np.int32),
230
+ 'width': tfds.features.Scalar(np.int32),
231
+ 'license': tfds.features.ClassLabel(names=LICENSES),
232
+ 'file_name': tfds.features.Text(),
233
+ 'flickr_url': tfds.features.Text(),
234
+ 'coco_url': tfds.features.Text(),
235
+ 'objects': tfds.features.Sequence({
236
+ 'id': tfds.features.Scalar(np.int64),
237
+ 'area': tfds.features.Scalar(np.float32),
238
+ 'bbox': tfds.features.BBoxFeature(),
239
+ 'mask': tfds.features.Image(encoding_format='png'),
240
+ 'category': tfds.features.ClassLabel(names=CATEGORIES),
241
+ 'supercategory': tfds.features.ClassLabel(
242
+ names=SUPERCATEGORIES
243
+ ),
244
+ 'iscrowd': tfds.features.Scalar(np.bool_),
245
+ # refcoco, refcoco+, refcocog features:
246
+ 'refs': tfds.features.Sequence({
247
+ 'id': tfds.features.Scalar(np.int32),
248
+ 'sentence': tfds.features.Text(),
249
+ }),
250
+ }),
251
+ }),
252
+ supervised_keys=None, # Set to `None` to disable
253
+ citation=_CITATION,
254
+ description=_DESCRIPTION,
255
+ )
256
+
257
+ def _split_generators(self, dl_manager: tfds.download.DownloadManager):
258
+ allowed_splits = {
259
+ ('refcoco', 'google'): [
260
+ tfds.Split.TRAIN,
261
+ tfds.Split.VALIDATION,
262
+ tfds.Split.TEST,
263
+ ],
264
+ ('refcoco', 'unc'): [
265
+ tfds.Split.TRAIN,
266
+ tfds.Split.VALIDATION,
267
+ 'testA',
268
+ 'testB',
269
+ ],
270
+ ('refcocoplus', 'unc'): [
271
+ tfds.Split.TRAIN,
272
+ tfds.Split.VALIDATION,
273
+ 'testA',
274
+ 'testB',
275
+ ],
276
+ # Verified manually that image and annotation IDs match the ones in
277
+ # https://storage.googleapis.com/refexp/google_refexp_dataset_release.zip
278
+ ('refcocog', 'google'): [
279
+ tfds.Split.TRAIN,
280
+ tfds.Split.VALIDATION,
281
+ ],
282
+ ('refcocog', 'umd'): [
283
+ tfds.Split.TRAIN,
284
+ tfds.Split.VALIDATION,
285
+ tfds.Split.TEST,
286
+ ],
287
+ ('refcocox', 'combined'): [
288
+ tfds.Split.TRAIN,
289
+ tfds.Split.VALIDATION,
290
+ tfds.Split.TEST,
291
+ ],
292
+ }
293
+ bc = self.builder_config
294
+ splits = allowed_splits[(bc.dataset, bc.dataset_partition)]
295
+
296
+ data_dir = dl_manager.manual_dir
297
+ for url, components in (
298
+ # pylint: disable=line-too-long
299
+ # pyformat: disable
300
+ ('http://images.cocodataset.org/zips/train2017.zip', ('train2017', '000000147328.jpg')),
301
+ ('http://images.cocodataset.org/annotations/annotations_trainval2017.zip', ('annotations', 'instances_train2017.json')),
302
+ ('https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco.zip', ('refcoco', 'refs(unc).p')),
303
+ ('https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco+.zip', ('refcoco+', 'refs(unc).p')),
304
+ ('https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcocog.zip', ('refcocog', 'refs(umd).p')),
305
+ # pyformat: enable
306
+ # pylint: enable=line-too-long
307
+ ):
308
+ path = os.path.exists(os.path.join(data_dir, *components))
309
+ if not path:
310
+ raise FileNotFoundError(
311
+ f'Could not find {path}: please download {url} and unzip into'
312
+ f' {data_dir}'
313
+ )
314
+
315
+ coco = pycocotools.coco.COCO(
316
+ os.path.join(data_dir, 'annotations', 'instances_train2017.json')
317
+ )
318
+
319
+ return {
320
+ split + suffix: self._generate_examples(
321
+ coco, data_dir, bc.dataset, bc.dataset_partition, split + suffix,
322
+ )
323
+ for split in splits
324
+ for suffix in ('', '_flat')
325
+ }
326
+
327
+ # Builder must overwrite all abstract methods.
328
+ def _generate_examples(
329
+ self, coco, data_dir, dataset, dataset_partition, split):
330
+ return _generate_examples(coco, data_dir, dataset, dataset_partition, split)
331
+
332
+
333
+ def _get_ids(data_dir, dataset, dataset_partition, split):
334
+ """Returns `img_ids, ann_to_refs` for specified dataset/partition/split."""
335
+
336
+ def load(dataset, dataset_partition):
337
+ fname = f'refs({dataset_partition}).p'
338
+ path = os.path.join(data_dir, dataset, fname)
339
+ refcoco = pickle.load(open(path, 'rb'))
340
+ return refcoco
341
+
342
+ if split == tfds.Split.VALIDATION:
343
+ split = 'val'
344
+
345
+ if (dataset, dataset_partition) == ('refcocox', 'combined'):
346
+ refcoco = (
347
+ load('refcocog', 'umd')
348
+ + load('refcoco', 'unc')
349
+ + load('refcoco+', 'unc')
350
+ )
351
+ if split == 'test':
352
+ splits = ('test', 'testA', 'testB')
353
+ else:
354
+ splits = (split,)
355
+
356
+ exclude_img_ids = set()
357
+ if split == 'train':
358
+ # Exclude all images with val/test annotations from train set.
359
+ exclude_img_ids = {
360
+ r['image_id'] for r in refcoco if r['split'] != 'train'
361
+ }
362
+ refcoco = [
363
+ r
364
+ for r in refcoco
365
+ if r['split'] in splits and r['image_id'] not in exclude_img_ids
366
+ ]
367
+
368
+ else:
369
+ if dataset == 'refcocoplus':
370
+ dataset = 'refcoco+'
371
+ refcoco = load(dataset, dataset_partition)
372
+ refcoco = [r for r in refcoco if r['split'] == split]
373
+
374
+ img_ids = {r['image_id'] for r in refcoco}
375
+ ann_to_refs = {}
376
+ for r in refcoco:
377
+ for sent in r['sentences']:
378
+ ann_to_refs.setdefault(r['ann_id'], []).append(dict(
379
+ id=sent['sent_id'],
380
+ sentence=sent['sent']
381
+ ))
382
+
383
+ return img_ids, ann_to_refs
384
+
385
+
386
+ def _generate_examples(coco, data_dir, dataset, dataset_partition, split):
387
+ """Generates examples for a given split."""
388
+
389
+ flat = '_flat' in split
390
+ split = split.replace('_flat', '')
391
+ img_ids, ann_to_refs = _get_ids(data_dir, dataset, dataset_partition, split)
392
+
393
+ for img_id in coco.getImgIds():
394
+
395
+ if img_id not in img_ids:
396
+ continue
397
+ img, = coco.loadImgs([img_id])
398
+
399
+ example = {
400
+ 'id': img_id,
401
+ 'image': os.path.join(data_dir, 'train2017', img['file_name']),
402
+ 'height': img['height'],
403
+ 'width': img['width'],
404
+ 'license': LICENSES[_licenses_map[img['license']]],
405
+ 'file_name': img['file_name'],
406
+ 'flickr_url': img['flickr_url'],
407
+ 'coco_url': img['coco_url'],
408
+ 'objects': [],
409
+ }
410
+ for ann in coco.loadAnns(coco.getAnnIds(img_id)):
411
+ refs = ann_to_refs.get(ann['id'])
412
+ if not refs:
413
+ continue
414
+ cat, = coco.loadCats([ann['category_id']])
415
+ mask = coco.annToMask(ann).astype(np.bool_)
416
+ mask_buf = io.BytesIO()
417
+ PIL.Image.fromarray(mask).save(mask_buf, 'png')
418
+ mask_buf.seek(0)
419
+ object_ = {
420
+ 'id': ann['id'],
421
+ 'mask': mask_buf,
422
+ 'category': cat['name'],
423
+ 'supercategory': cat['supercategory'],
424
+ 'iscrowd': ann['iscrowd'],
425
+ 'area': ann['area'],
426
+ 'bbox': _convert_bbox(img, *ann['bbox']),
427
+ 'refs': refs,
428
+ }
429
+ if flat:
430
+ example['objects'] = [object_]
431
+ for ref_i, ref in enumerate(refs):
432
+ object_['refs'] = [ref]
433
+ mask_buf.seek(0)
434
+ yield f'{img_id}_{ann["id"]}_{ref_i}', example
435
+ else:
436
+ example['objects'].append(object_)
437
+
438
+ if not flat:
439
+ yield img_id, example
440
+
441
+
442
+ def _convert_bbox(img, x, y, w, h):
443
+ return tfds.features.BBox(
444
+ ymin=y / img['height'],
445
+ xmin=x / img['width'],
446
+ ymax=(y + h) / img['height'],
447
+ xmax=(x + w) / img['width'],
448
+ )
Tipsomaly/model/big_vision/datasets/rsvqa_hr/rsvqa_hr.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # pylint: disable=line-too-long
16
+ r"""Implements RSVQA-HR dataset in TFDS.
17
+
18
+ Remote sensing visual question answering task, using high-resolution airborne
19
+ image data at 15cm resolution per pixel.
20
+
21
+ It's small dataset at source (14G), so simple to run locally.
22
+ First, download and unzip the dataset from https://zenodo.org/records/6344367
23
+ and place it in /tmp/data/rsvqa_hr.
24
+
25
+ Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util):
26
+
27
+ cd third_party/py/big_vision/datasets
28
+ env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=rsvqa_hr
29
+
30
+ Example to load:
31
+
32
+ import tensorflow_datasets as tfds
33
+ dataset = tfds.load('rsvqa_hr', split='train', data_dir='/tmp/tfds')
34
+
35
+ Dataset splits (all):
36
+ train: 625,340 examples/questions
37
+ val: 102,843 examples/questions
38
+ test: 222,684 examples/questions
39
+ test_2: 105,647 examples/questions (other area, unknown instrument)
40
+ Non-numeric data splits (nonum):
41
+ train: 371,834 examples/questions
42
+ val: 60,405 examples/questions
43
+ test: 131,468 examples/questions
44
+ test_2: 62,554 examples/questions
45
+
46
+ Note: due to image duplication with each question, the dataset size is
47
+ significantly increased by the number of questions per image.
48
+
49
+ Recommended training splits:
50
+ train: train
51
+ minitrain: train[:5%]
52
+ eval: val
53
+ full_train: train+val
54
+ test: test
55
+
56
+ Image sizes: 512x512
57
+ Number of answers per question: 1
58
+ Question types distribution in train split:
59
+ - Area (area): 14.6% (integers, binned into {0m2, 1-10m2, 11-100m2, 101-1000m2, >1000m2})
60
+ - Comparison(comp): 33.5%
61
+ - Count (count): 26.0% (integers, not binned, maximum number of objects is 89)
62
+ - Presence (presence): 26.0%
63
+ """
64
+ import json
65
+ import os
66
+
67
+ import numpy as np
68
+ import tensorflow_datasets as tfds
69
+
70
+
71
+ _DESCRIPTION = """RSVQA-HR dataset."""
72
+
73
+ # pylint: disable=line-too-long
74
+ _CITATION = """
75
+ @article{Lobry_2020,
76
+ title={RSVQA: Visual Question Answering for Remote Sensing Data},
77
+ volume={58},
78
+ ISSN={1558-0644},
79
+ url={http://dx.doi.org/10.1109/TGRS.2020.2988782},
80
+ DOI={10.1109/tgrs.2020.2988782},
81
+ number={12},
82
+ journal={IEEE Transactions on Geoscience and Remote Sensing},
83
+ publisher={Institute of Electrical and Electronics Engineers (IEEE)},
84
+ author={Lobry, Sylvain and Marcos, Diego and Murray, Jesse and Tuia, Devis},
85
+ year={2020},
86
+ month=dec, pages={8555-8566} }
87
+ """
88
+ # pylint: enable=line-too-long
89
+
90
+ # When running locally (recommended), copy files as above and use these:
91
+ PATH = '/tmp/data/rsvqa_hr/'
92
+
93
+
94
class RsvqaHrConfig(tfds.core.BuilderConfig):
  """Builder config selecting either all questions or the non-numeric subset."""

  def __init__(self, nonum, **kwargs):
    # The config name encodes the variant: 'nonum' drops numeric questions.
    super().__init__(name='nonum' if nonum else 'all', **kwargs)
    self.nonum = nonum
101
+
102
+
103
class RsvqaHr(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for RSVQA-HR dataset."""

  VERSION = tfds.core.Version('1.0.2')
  RELEASE_NOTES = {
      '1.0.0': 'First release.',
      '1.0.1': 'Rename binned values.',
      '1.0.2': 'Removed explicit png image encoding.',
  }

  BUILDER_CONFIGS = [
      RsvqaHrConfig(nonum=False),
      RsvqaHrConfig(nonum=True),
  ]

  def _info(self):
    """Returns the dataset metadata."""

    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            'question_id': tfds.features.Scalar(np.int32),
            'filename': tfds.features.Text(),
            'image': tfds.features.Image(),
            'question': tfds.features.Text(),
            'question_type': tfds.features.Text(),
            'answers': tfds.features.Sequence(tfds.features.Text()),
            'raw_answers': tfds.features.Sequence(tfds.features.Text()),
        }),
        supervised_keys=None,
        homepage='https://rsvqa.sylvainlobry.com/',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns a mapping from split name to example generator."""
    return {
        split: self._generate_examples(split)
        for split in ('train', 'val', 'test', 'test_2')
    }

  def _generate_examples(self, split):
    """Yields (key, example) tuples for the given split.

    Questions and answers live in two parallel JSON files; entries are
    matched by position and cross-checked via their id fields.
    """
    if split == 'test_2':
      # The second test set (Philadelphia area) uses a different file prefix.
      split = 'test_phili'
    # Build paths with os.path.join components (was: join of a single
    # pre-concatenated string, which defeated the purpose of join).
    questions_path = os.path.join(PATH, f'USGS_split_{split}_questions.json')
    answers_path = os.path.join(PATH, f'USGS_split_{split}_answers.json')
    images_path = os.path.join(PATH, 'Data')

    with open(questions_path, 'r') as f:
      questions = json.load(f)['questions']
    with open(answers_path, 'r') as f:
      answers = json.load(f)['answers']

    for q, a in zip(questions, answers):
      # Sanity-check that the parallel question/answer lists stay in sync.
      assert q['active'] == a['active']
      if not q['active']:
        continue
      # The 'nonum' config drops numeric question types (area and count).
      if self.builder_config.nonum and q['type'] in ('area', 'count'):
        continue
      assert q['answers_ids'][0] == a['id']
      assert q['id'] == a['question_id']

      filename = f'{q["img_id"]}.png'
      yield q['id'], {
          'question_id': q['id'],
          'filename': filename,
          'image': os.path.join(images_path, filename),
          'question': q['question'],
          'question_type': q['type'],
          'answers': [bin_answer(a['answer'], q['type'])],
          'raw_answers': [a['answer']],
      }
177
+
178
+
179
def bin_answer(answer, question_type):
  """Bins 'area' answers into coarse ranges; other types pass through."""
  if question_type != 'area':
    return answer
  area = int(answer[:-2])  # Strip the trailing 'm2' unit.
  if area == 0:
    return '0 m2'
  if area <= 10:
    return 'between 1 m2 and 10 m2'
  if area <= 100:
    return 'between 11 m2 and 100 m2'
  if area <= 1000:
    return 'between 101 m2 and 1000 m2'
  return 'more than 1000 m2'
Tipsomaly/model/big_vision/datasets/rsvqa_lr/rsvqa_lr.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # pylint: disable=line-too-long
16
+ r"""Implements RSVQA-LR dataset in TFDS.
17
+
18
+ Remote sensing visual question answering task, using low-resolution satellite
19
+ (Sentinel-2) RGB channels data at 10m resolution per pixel.
20
+
21
+ It's small dataset at source (200M), so simple to run locally.
22
+ First, download and unzip the dataset from https://zenodo.org/records/6344334
23
+ and place it in /tmp/data/rsvqa_lr.
24
+
25
+ Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util):
26
+
27
+ cd third_party/py/big_vision/datasets
28
+ env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=rsvqa_lr
29
+
30
+ Example to load:
31
+
32
+ import tensorflow_datasets as tfds
33
+ dataset = tfds.load('rsvqa_lr', split='train', data_dir='/tmp/tfds')
34
+
35
+ Dataset splits:
36
+ train: 57223 examples/questions
37
+ val: 10005 examples/questions
38
+ test: 10004 examples/questions
39
+ And the same splits are available excluding numeric questions:
40
+ train_nonum: 39441 examples/questions
41
+ val_nonum: 6782 examples/questions
42
+ test_nonum: 6782 examples/questions
43
+
44
+ Note: due to image duplication with each question, the dataset size is
45
+ significantly increased by the number of questions per image.
46
+
47
+ Recommended training splits:
48
+ train: train
49
+ minitrain: train[:5%]
50
+ eval: val
51
+ full_train: train+val
52
+ test: test
53
+
54
+ Image sizes: 256x256
55
+ Number of answers per question: 1
56
+ Question types distribution in train split:
57
+ - Comparison(comp): 39.4%
58
+ - Count (count): 29.9% (integers, binned at evaluation into
59
+ {0, 1-10, 11-100, 101-1000, >1000})
60
+ - Presence (presence): 29.7%
61
+ - Rural/Urban (rural_urban): 1%
62
+ """
63
+ import io
64
+ import json
65
+ import os
66
+
67
+ import numpy as np
68
+ import tensorflow_datasets as tfds
69
+
70
+
71
+ _DESCRIPTION = """RSVQA-LR dataset."""
72
+
73
+ # pylint: disable=line-too-long
74
+ _CITATION = """
75
+ @article{Lobry_2020,
76
+ title={RSVQA: Visual Question Answering for Remote Sensing Data},
77
+ volume={58},
78
+ ISSN={1558-0644},
79
+ url={http://dx.doi.org/10.1109/TGRS.2020.2988782},
80
+ DOI={10.1109/tgrs.2020.2988782},
81
+ number={12},
82
+ journal={IEEE Transactions on Geoscience and Remote Sensing},
83
+ publisher={Institute of Electrical and Electronics Engineers (IEEE)},
84
+ author={Lobry, Sylvain and Marcos, Diego and Murray, Jesse and Tuia, Devis},
85
+ year={2020},
86
+ month=dec, pages={8555–8566} }
87
+ """
88
+ # pylint: enable=line-too-long
89
+
90
+ # When running locally (recommended), copy files as above and use these:
91
+ PATH = '/tmp/data/rsvqa_lr/'
92
+
93
+
94
class RsvqaLrConfig(tfds.core.BuilderConfig):
  """Builder config selecting either all questions or the non-numeric subset."""

  def __init__(self, nonum, **kwargs):
    # The config name encodes the variant: 'nonum' drops numeric questions.
    super().__init__(name='nonum' if nonum else 'all', **kwargs)
    self.nonum = nonum
101
+
102
+
103
class RsvqaLr(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for RSVQA-LR dataset."""

  VERSION = tfds.core.Version('1.0.2')
  RELEASE_NOTES = {
      '1.0.0': 'First release.',
      '1.0.1': 'Rename binned values.',
      '1.0.2': 'Removed explicit png image encoding.',
  }

  BUILDER_CONFIGS = [
      RsvqaLrConfig(nonum=False),
      RsvqaLrConfig(nonum=True),
  ]

  def _info(self):
    """Returns the dataset metadata."""

    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            'question_id': tfds.features.Scalar(np.int32),
            'filename': tfds.features.Text(),
            'image': tfds.features.Image(),
            'question': tfds.features.Text(),
            'question_type': tfds.features.Text(),
            'answers': tfds.features.Sequence(tfds.features.Text()),
            'raw_answers': tfds.features.Sequence(tfds.features.Text()),
        }),
        supervised_keys=None,
        homepage='https://rsvqa.sylvainlobry.com/',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns a mapping from split name to example generator."""
    return {
        split: self._generate_examples(split)
        for split in ('train', 'val', 'test')
    }

  def _generate_examples(self, split):
    """Yields (key, example) tuples for the given split.

    Questions and answers live in two parallel JSON files; entries are
    matched by position and cross-checked via their id fields.
    """
    # Build paths with os.path.join components (was: join of a single
    # pre-concatenated string, which defeated the purpose of join).
    questions_path = os.path.join(PATH, f'LR_split_{split}_questions.json')
    answers_path = os.path.join(PATH, f'LR_split_{split}_answers.json')
    images_path = os.path.join(PATH, 'Images_LR')

    with open(questions_path, 'r') as f:
      questions = json.load(f)['questions']
    with open(answers_path, 'r') as f:
      answers = json.load(f)['answers']

    for q, a in zip(questions, answers):
      # Sanity-check that the parallel question/answer lists stay in sync.
      assert q['active'] == a['active']
      if not q['active']:
        continue
      # The 'nonum' config drops the numeric (count) question type.
      if self.builder_config.nonum and q['type'] == 'count':
        continue
      assert q['answers_ids'] == [a['id']]
      assert q['id'] == a['question_id']

      filename = f'{q["img_id"]}.tif'
      # Sentinel-2 images ship as .tif; decode to a uint8 array for TFDS.
      img = read_tif(os.path.join(images_path, filename))
      yield q['id'], {
          'question_id': q['id'],
          'filename': filename,
          'image': img,
          'question': q['question'],
          'question_type': q['type'],
          'answers': [bin_answer(a['answer'], q['type'])],
          'raw_answers': [a['answer']],
      }
176
+
177
+
178
def bin_answer(answer, question_type):
  """Bins 'count' answers into coarse ranges; other types pass through."""
  if question_type != 'count':
    return answer
  count = int(answer)
  if count == 0:
    return '0'
  # Ascending (upper-bound, label) buckets; first match wins.
  for upper, label in ((10, 'between 1 and 10'),
                       (100, 'between 11 and 100'),
                       (1000, 'between 101 and 1000')):
    if count <= upper:
      return label
  return 'more than 1000'
193
+
194
+
195
def read_tif(path):
  """Decodes a .tif image file into a uint8 numpy array."""
  with open(path, 'rb') as f:
    raw = f.read()
  decoded = tfds.core.lazy_imports.tifffile.imread(io.BytesIO(raw))
  return decoded.astype(np.uint8)
Tipsomaly/model/big_vision/datasets/scicap/scicap.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # pylint: disable=line-too-long
16
+ r"""Creates TFDS dataset for SciCap.
17
+
18
+ Preparing the data:
19
+ 1) mkdir /tmp/data/scicap && cd /tmp/data/scicap
20
+ 2) wget 'https://www.dropbox.com/s/t1sjqesl0pynaxo/scicap_data.zip?dl=0'
21
+ 3) unzip -UU 'scicap_data.zip?dl=0' && rm 'scicap_data.zip?dl=0'
22
+
23
+ Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util):
24
+
25
+ cd big_vision/datasets
26
+ env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=scicap
27
+
28
+ Example to load:
29
+
30
+ import tensorflow_datasets as tfds
31
+ dataset = tfds.load('scicap', split='train', data_dir='/tmp/tfds')
32
+ """
33
+ # pylint: enable=line-too-long
34
+ import enum
35
+ import functools
36
+ import json
37
+ import os
38
+
39
+ import tensorflow_datasets as tfds
40
+
41
+
42
+ _DESCRIPTION = """SciCap dataset."""
43
+ _CITATION = """
44
+ @article{hsu2021scicap,
45
+ title={SciCap: Generating captions for scientific figures},
46
+ author={Hsu, Ting-Yao and Giles, C Lee and Huang, Ting-Hao'Kenneth'},
47
+ journal={arXiv preprint arXiv:2110.11624},
48
+ year={2021}
49
+ }
50
+ """
51
+
52
+ # When running locally (recommended), copy files as above and use these:
53
+ _SCICAP_DIR = "/tmp/data/scicap/scicap_data"
54
+
55
+
56
class ScicapSubset(enum.Enum):
  """Versions of the SciCap dataset."""
  # Captions consisting of exactly one sentence.
  SINGLE_SENTENCE = "single_sentence"
  # Only the first sentence of each caption.
  FIRST_SENTENCE = "first_sentence"
  # Captions with no more than 100 tokens.
  LEQ_100_TOKENS = "leq_100_tokens"
61
+
62
# Splits present in the source data.
_SPLITS_TO_GENERATE = ["train", "test", "val"]
# Maps (subset, include-subfigures?) to the directory that holds the per-split
# file-id lists (under "List-of-Files-for-Each-Experiments").
_CONFIG_TO_IDS_PATH = {
    (ScicapSubset.SINGLE_SENTENCE, True): "Single-Sentence-Caption/Yes-Subfig",
    (ScicapSubset.SINGLE_SENTENCE, False): "Single-Sentence-Caption/No-Subfig",
    (ScicapSubset.FIRST_SENTENCE, True): "First-Sentence/Yes-Subfig",
    (ScicapSubset.FIRST_SENTENCE, False): "First-Sentence/No-Subfig",
    (ScicapSubset.LEQ_100_TOKENS, True):
        "Caption-No-More-Than-100-Tokens/Yes-Subfig",
    (ScicapSubset.LEQ_100_TOKENS, False):
        "Caption-No-More-Than-100-Tokens/No-Subfig",
}
# Maps include-subfigures? to the corresponding image directory name.
_SUBFIG_TO_PATH = {
    True: "SciCap-Yes-Subfig-Img", False: "SciCap-No-Subfig-Img"
}
76
+
77
+
78
class ScicapConfig(tfds.core.BuilderConfig):
  """Configuration for SciCap caption length and subfigure inclusion."""

  def __init__(self, *, subset: ScicapSubset, subfig: bool, **kwargs):
    """Parameters specifying how the dataset will be processed.

    Args:
      subset: Subset of the SciCap data (see `ScicapSubset`).
      subfig: Whether figures containing subfigures are included.
      **kwargs: Forwarded to the `BuilderConfig` constructor.
    """
    super().__init__(**kwargs)
    self.subset = subset
    self.subfig = subfig
92
+
93
+
94
@functools.cache
def _read_annotations(split: str, image_id: str):
  """Reads the caption annotation JSON for a single figure (memoized)."""
  fname = os.path.join(
      _SCICAP_DIR, "SciCap-Caption-All", split, image_id + ".json")
  with open(fname, "r") as fin:
    return json.load(fin)
101
+
102
+
103
class Scicap(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for the SciCap dataset."""

  VERSION = tfds.core.Version("1.0.0")
  RELEASE_NOTES = {"1.0.0": "First release."}

  # One config per (caption subset, subfigure inclusion) combination.
  BUILDER_CONFIGS = [
      ScicapConfig(
          name="single_sentence_subfig_yes",
          description="Single sentence caption with subfigures allowed.",
          subset=ScicapSubset.SINGLE_SENTENCE,
          subfig=True
      ),
      ScicapConfig(
          name="single_sentence_subfig_no",
          description="Single sentence caption with subfigures not allowed.",
          subset=ScicapSubset.SINGLE_SENTENCE,
          subfig=False
      ),
      ScicapConfig(
          name="first_sentence_subfig_yes",
          description="First sentence of captions with subfigures allowed.",
          subset=ScicapSubset.FIRST_SENTENCE,
          subfig=True
      ),
      ScicapConfig(
          name="first_sentence_subfig_no",
          description="First sentence of captions with subfigures not allowed.",
          subset=ScicapSubset.FIRST_SENTENCE,
          subfig=False
      ),
      ScicapConfig(
          name="leq_100_tokens_subfig_yes",
          description="Captions with <= 100 tokens with subfigures allowed.",
          subset=ScicapSubset.LEQ_100_TOKENS,
          subfig=True
      ),
      ScicapConfig(
          name="leq_100_tokens_subfig_no",
          description=("Captions with <= 100 tokens with subfigures"
                       " not allowed."),
          subset=ScicapSubset.LEQ_100_TOKENS,
          subfig=False
      ),
  ]

  def _info(self):
    """Returns the metadata."""

    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            "image/id": tfds.features.Text(),
            "image/filename": tfds.features.Text(),
            "image": tfds.features.Image(encoding_format="png"),
            "caption/originally_extracted": tfds.features.Text(),
            "caption/lowercase_and_token_and_remove_figure_index":
                tfds.features.Text(),
            "caption/normalized/basic_num": tfds.features.Text(),
            "caption/normalized/advanced_equation_bracket":
                tfds.features.Text(),
        }),
        supervised_keys=None,
        homepage="https://github.com/tingyaohsu/SciCap",
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators."""
    return {split: self._generate_examples(split)
            for split in _SPLITS_TO_GENERATE}

  def _generate_examples(self, split: str):
    """Yields (key, example) tuples for the given split."""
    # Id-list directory for the active (subset, subfig) configuration.
    config_path = _CONFIG_TO_IDS_PATH[
        (self.builder_config.subset, self.builder_config.subfig)]
    image_path = os.path.join(
        _SCICAP_DIR, _SUBFIG_TO_PATH[self.builder_config.subfig], split)
    id_list_fname = os.path.join(
        _SCICAP_DIR, "List-of-Files-for-Each-Experiments",
        config_path, split, "file_idx.json")
    with open(id_list_fname, "r") as fin:
      split_images = json.load(fin)

    for fname in split_images:
      assert fname.endswith(".png")
      image_id = fname[:-len(".png")]
      annotations = _read_annotations(split, image_id)
      yield fname, {
          "image/id": image_id,
          "image/filename": fname,
          "image": os.path.join(image_path, fname),
          "caption/originally_extracted": annotations["0-originally-extracted"],
          "caption/lowercase_and_token_and_remove_figure_index":
              annotations["1-lowercase-and-token-and-remove-figure-index"][
                  "caption"],
          "caption/normalized/basic_num": annotations["2-normalized"][
              "2-1-basic-num"]["caption"],
          # Key spelling ("euqation") matches the upstream SciCap JSON files.
          "caption/normalized/advanced_equation_bracket":
              annotations["2-normalized"][
                  "2-2-advanced-euqation-bracket"]["caption"]
      }
Tipsomaly/model/big_vision/datasets/science_qa/science_qa.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # pylint: disable=line-too-long
16
+ r"""Implements ScienceQA train/val/test-set in TFDS structure.
17
+
18
+ First, download the science QA dataset from their website https://scienceqa.github.io/#download
19
+ - mkdir -p /tmp/data/ScienceQA_DATA
20
+ - From Google Drive: https://drive.google.com/corp/drive/folders/1w8imCXWYn2LxajmGeGH_g5DaL2rabHev
21
+ Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util):
22
+ - cd big_vision/datasets
23
+ - env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=science_qa
24
+
25
+ Example to load:
26
+
27
+ import tensorflow_datasets as tfds
28
+ dataset = tfds.load(
29
+ 'science_qa', split='train',
30
+ data_dir='/tmp/tfds')
31
+
32
+ """
33
+ import json
34
+ import os
35
+
36
+ import numpy as np
37
+ import tensorflow_datasets as tfds
38
+
39
+
40
+ _DESCRIPTION = """Sci QA test-set."""
41
+
42
+ # pylint: disable=line-too-long
43
+ _CITATION = """
44
+ @inproceedings{lu2022learn,
45
+ title={Learn to Explain: Multimodal Reasoning via Thought Chains for Science Question Answering},
46
+ author={Lu, Pan and Mishra, Swaroop and Xia, Tony and Qiu, Liang and Chang, Kai-Wei and Zhu, Song-Chun and Tafjord, Oyvind and Clark, Peter and Ashwin Kalyan},
47
+ booktitle={The 36th Conference on Neural Information Processing Systems (NeurIPS)},
48
+ year={2022}
49
+ }
50
+ """
51
+ # pylint: enable=line-too-long
52
+
53
+ # When running locally (recommended), copy files as above and use these:
54
+ _SCIQA_PATH = '/tmp/data/ScienceQA_DATA/'
55
+ # _IMAGE_COCO_PATH = '/tmp/data/val2014'
56
+
57
+ _ALPHABETS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
58
+
59
+
60
class ScienceQA(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for ScienceQA dataset."""

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {'1.0.0': 'First release.'}

  def _info(self):
    """Returns the metadata."""

    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            'question': tfds.features.Text(),
            'choices': tfds.features.Sequence(tfds.features.Text()),
            'answer': tfds.features.Scalar(np.int32),
            'hint': tfds.features.Text(),
            'task': tfds.features.Text(),
            'grade': tfds.features.Text(),
            'subject': tfds.features.Text(),
            'topic': tfds.features.Text(),
            'category': tfds.features.Text(),
            'skill': tfds.features.Text(),
            'lecture': tfds.features.Text(),
            'solution': tfds.features.Text(),
            'image': tfds.features.Image(encoding_format='png'),
            'indexed_choices': tfds.features.Text(),
            'indexed_answer': tfds.features.Text(),
        }),
        supervised_keys=None,
        homepage='https://github.com/lupantech/ScienceQA/tree/main',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators."""
    return {
        split: self._generate_examples(split)
        for split in ('train', 'test', 'val')
    }

  def _generate_examples(self, split):
    """Yields (key, example) tuples for the given split."""
    annot_fname = os.path.join(_SCIQA_PATH, 'problems.json')

    with open(annot_fname, 'r') as f:
      data = json.load(f)

    for k, v in data.items():
      if v['split'] != split:
        continue
      # ScienceQA also contains examples without an image. As this conversion
      # targets VQA tasks, such examples are dropped.
      # TODO: include the examples without image, and update the downstream
      # pipeline to skip them, instead of filtering at pre-processing time.
      image = v['image']
      if not image:
        continue
      image = os.path.join(_SCIQA_PATH, split, k, image)
      # Align with the original GitHub implementation: missing hint -> 'N/A'.
      hint = v['hint'] or 'N/A'
      choices = v['choices']
      answer = v['answer']
      # Render choices as '(A) foo, (B) bar, ...' and the answer as a letter.
      indexed_choices = ', '.join(
          f'({_ALPHABETS[i]}) {c}' for i, c in enumerate(choices)
      )
      indexed_answer = _ALPHABETS[int(answer)]
      yield int(k), {
          'question': v['question'],
          'choices': choices,
          'answer': answer,
          'hint': hint,
          'task': v['task'],
          'grade': v['grade'],
          'subject': v['subject'],
          'topic': v['topic'],
          'category': v['category'],
          'skill': v['skill'],
          'lecture': v['lecture'],
          'solution': v['solution'],
          'image': image,
          'indexed_choices': indexed_choices,
          'indexed_answer': indexed_answer,
      }
Tipsomaly/model/big_vision/datasets/screen2words/screen2words.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # pylint: disable=line-too-long
16
+ r"""Creates TFDS dataset for Screen2words.
17
+
18
+
19
+ Preparing the data:
20
+ 1) mkdir /tmp/data/rico && cd /tmp/data/rico
21
+ 2) wget https://storage.googleapis.com/crowdstf-rico-uiuc-4540/rico_dataset_v0.1/unique_uis.tar.gz
22
+ 3) tar xvfz unique_uis.tar.gz && rm unique_uis.tar.gz
23
+ 4) git clone https://github.com/google-research-datasets/screen2words.git
24
+
25
+ Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util):
26
+
27
+ cd big_vision/datasets
28
+ env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=screen2words
29
+
30
+ Example to load:
31
+
32
+ import tensorflow_datasets as tfds
33
+ dataset = tfds.load('screen2_words', split='train', data_dir='/tmp/tfds')
34
+ """
35
+ # pylint: enable=line-too-long
36
+ import collections
37
+ import csv
38
+ import os
39
+
40
+ import numpy as np
41
+ import tensorflow_datasets as tfds
42
+
43
+
44
+ _DESCRIPTION = """Screen2words dataset."""
45
+ _CITATION = """
46
+ @inproceedings{wang2021screen2words,
47
+ title={Screen2words: Automatic mobile UI summarization with multimodal
48
+ learning},
49
+ author={Wang, Bryan and
50
+ Li, Gang and
51
+ Zhou, Xin and
52
+ Chen, Zhourong and
53
+ Grossman, Tovi and
54
+ Li, Yang},
55
+ booktitle={The 34th Annual ACM Symposium on User Interface Software
56
+ and Technology},
57
+ pages={498--510},
58
+ year={2021}
59
+ }
60
+ """
61
+
62
+ # When running locally (recommended), copy files as above and use these:
63
+ _SCREEN2WORDS_DIR = "/tmp/data/rico/screen2words"
64
+ _RICO_DIR = "/tmp/data/rico/combined"
65
+
66
+
67
+ # (name, path) tuples for splits to be generated.
68
+ _SPLITS_TO_GENERATE = ["train", "dev", "test"]
69
+
70
+
71
class Screen2Words(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for the Screen2words dataset."""

  VERSION = tfds.core.Version("1.0.0")
  RELEASE_NOTES = {"1.0.0": "First release."}

  def _info(self):
    """Returns the metadata."""

    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            "image/id": tfds.features.Scalar(np.int32),
            "image/filename": tfds.features.Text(),
            "image": tfds.features.Image(encoding_format="jpeg"),
            "summary": tfds.features.Sequence(tfds.features.Text()),
        }),
        supervised_keys=None,
        homepage="https://github.com/google-research-datasets/screen2words",
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators."""
    return {split: self._generate_examples(split)
            for split in _SPLITS_TO_GENERATE}

  def _generate_examples(self, split: str):
    """Yields (key, example) tuples for the given split."""
    # Several crowd-sourced summaries exist per screen; group them by id.
    summaries_fname = os.path.join(_SCREEN2WORDS_DIR, "screen_summaries.csv")
    summaries = collections.defaultdict(list)
    with open(summaries_fname, "r") as fin:
      for entry in csv.DictReader(fin):
        summaries[int(entry["screenId"])].append(entry["summary"])

    id_list_fname = os.path.join(
        _SCREEN2WORDS_DIR, "split", f"{split}_screens.txt")
    # One RICO screen id per line; iterate lazily instead of readlines().
    with open(id_list_fname, "r") as fin:
      for line in fin:
        image_id = int(line.strip())
        yield image_id, {
            "image/id": image_id,
            "image/filename": f"{image_id}.jpg",
            "image": os.path.join(_RICO_DIR, f"{image_id}.jpg"),
            "summary": summaries[image_id],
        }
Tipsomaly/model/big_vision/datasets/stvqa/stvqa.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # pylint: disable=line-too-long
16
+ r"""Implements ST-VQA dataset in TFDS.
17
+
18
+ It's small data, so simple to run locally.
19
+ First, download and unzip the dataset from https://rrc.cvc.uab.es/?ch=11
20
+ and place it in /tmp/data/stvqa.
21
+
22
+ Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util):
23
+
24
+ cd third_party/py/big_vision/datasets
25
+ env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=stvqa
26
+
27
+ Example to load:
28
+
29
+ import tensorflow_datasets as tfds
30
+ dataset = tfds.load('stvqa', split='train', data_dir='/tmp/tfds')
31
+
32
+ Dataset splits:
33
+ train: 23446 examples/questions (subset of original train)
34
+ val: 2628 examples/questions (subset of original train)
35
+ test: 4070 examples/questions (no answers)
36
+
37
+ Note: original source data has no val/holdout split, and we therefore split the
38
+ original train split (26074 examples/questions) by ourselves into train & val
39
+ splits.
40
+
41
+ Recommended training splits:
42
+ train: train
43
+ minitrain: train[:5%]
44
+ eval: val
45
+ fulltrain: train+val
46
+ """
47
+ import json
48
+ import os
49
+
50
+ from big_vision.datasets.stvqa import val_ids
51
+ import numpy as np
52
+ import tensorflow_datasets as tfds
53
+
54
+ _VAL_IDS = val_ids.PSEUDO_VAL_IMAGE_PATHS
55
+
56
+ _DESCRIPTION = """ST-VQA dataset."""
57
+
58
+ # pylint: disable=line-too-long
59
+ _CITATION = """
60
+ @inproceedings{Biten_2019,
61
+ title={Scene Text Visual Question Answering},
62
+ url={http://dx.doi.org/10.1109/ICCV.2019.00439},
63
+ DOI={10.1109/iccv.2019.00439},
64
+ booktitle={2019 IEEE/CVF International Conference on Computer Vision (ICCV)},
65
+ publisher={IEEE},
66
+ author={Biten, Ali Furkan and Tito, Ruben and Mafla, Andres and Gomez, Lluis and Rusinol, Marcal and Jawahar, C.V. and Valveny, Ernest and Karatzas, Dimosthenis},
67
+ year={2019},
68
+ month=oct }
69
+ """
70
+ # pylint: enable=line-too-long
71
+
72
+ # When running locally (recommended), copy files as above an use these:
73
+ _STVQA_PATH = '/tmp/data/stvqa/'
74
+
75
+
76
class Stvqa(tfds.core.GeneratorBasedBuilder):
  """TFDS builder for the ST-VQA dataset."""

  VERSION = tfds.core.Version('1.2.0')
  RELEASE_NOTES = {
      '1.0.0': 'First release.',
      '1.1.0': 'Switch to COCO high-res images and lower-case answers.',
      '1.2.0': 'Rename pseudo splits and remove lower-case answers.',
  }

  def _info(self):
    """Returns the dataset metadata (features, homepage, citation)."""
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            'question_id': tfds.features.Scalar(np.int32),
            'filename': tfds.features.Text(),
            'image': tfds.features.Image(encoding_format='jpeg'),
            'question': tfds.features.Text(),
            'answers': tfds.features.Sequence(tfds.features.Text()),
        }),
        supervised_keys=None,
        homepage='https://rrc.cvc.uab.es/?ch=11',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators for the train/val/test pseudo splits."""
    return {name: self._generate_examples(name)
            for name in ('train', 'val', 'test')}

  def _generate_examples(self, split):
    """Yields (key, example) tuples for one pseudo split."""
    # Both 'train' and 'val' are carved out of the original train annotations;
    # only 'test' maps to the original test files.
    source = 'test' if split == 'test' else 'train'
    if source == 'test':
      image_dir = 'test_task3_images'
    else:
      image_dir = 'train_images'

    with open(os.path.join(_STVQA_PATH, f'{source}_task_3.json'), 'r') as f:
      annotations = json.load(f)

    for item in annotations['data']:
      in_val = item['file_path'] in _VAL_IDS
      # Pseudo-split filtering: 'val' keeps only _VAL_IDS images,
      # 'train' excludes them.
      if (split == 'val' and not in_val) or (split == 'train' and in_val):
        continue
      image_path = os.path.join(_STVQA_PATH, image_dir, item['file_path'])
      # Always use high-res COCO images from train2014 directory.
      if item['file_path'].startswith('coco-text'):
        image_path = image_path.replace(
            os.path.join(image_dir, 'coco-text'), 'train2014')
      yield item['question_id'], {
          'question_id': item['question_id'],
          'filename': item['file_path'],
          'image': image_path,
          'question': item['question'],
          'answers': item.get('answers', []),
      }
Tipsomaly/model/big_vision/datasets/tallyqa/tallyqa.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Import TallyQA into TFDS format. Uses Visual Genome and COCO images.
16
+
17
+ It's small data, so simple to run locally. First, download all the data:
18
+
19
+ mkdir /tmp/data/ ; cd /tmp/data
20
+ wget http://images.cocodataset.org/zips/{train2014,val2014}.zip
21
+ wget https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip
22
+ wget https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip
23
+ wget https://github.com/manoja328/tallyqa/blob/master/tallyqa.zip?raw=true
24
+ unzip *.zip
25
+
26
+ Then, update the PATHs below and run conversion locally like so (make sure to
27
+ install tensorflow-datasets for the `tfds` util):
28
+
29
+ cd big_vision/datasets
30
+ env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=tallyqa
31
+
32
+ Example to load:
33
+ import tensorflow_datasets as tfds
34
+ dataset = tfds.load('tallyqa', split='train', data_dir='/tmp/tfds')
35
+
36
+ The test split distinguishes between simple and complex questions. The train
37
+ split does not contain this information. We therefore set issimple to `-1` in
38
+ the train split to indicate it is not known.
39
+ """
40
+
41
+ import json
42
+
43
+ import numpy as np
44
+ import tensorflow_datasets as tfds
45
+
46
+
47
+ _TALLYQA_PATH = '/tmp/data/tallyQA/'
48
+ _VISUAL_GENOME_PATH = '/tmp/data/visual_genome/'
49
+
50
+ _COCO_PATH = '/tmp/data/coco/'
51
+
52
+
53
+ _DESCRIPTION = """
54
+ TallyQA: Answering Complex Counting Questions
55
+ Most counting questions in visual question answering (VQA) datasets are simple
56
+ and require no more than object detection. Here, we study algorithms for complex
57
+ counting questions that involve relationships between objects, attribute
58
+ identification, reasoning, and more. To do this, we created TallyQA, the world's
59
+ largest dataset for open-ended counting.
60
+ """
61
+
62
+ _CITATION = """
63
+ @inproceedings{acharya2019tallyqa,
64
+ title={TallyQA: Answering Complex Counting Questions},
65
+ author={Acharya, Manoj and Kafle, Kushal and Kanan, Christopher},
66
+ booktitle={AAAI},
67
+ year={2019}
68
+ }
69
+ """
70
+
71
+ _HOMEPAGE = 'https://github.com/manoja328/TallyQA_dataset'
72
+
73
+
74
class TallyQA(tfds.core.GeneratorBasedBuilder):
  """TFDS builder for the TallyQA counting-questions dataset."""

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {'1.0.0': 'Initial release.'}
  MANUAL_DOWNLOAD_INSTRUCTIONS = """
  There are three parts which should be downloaded:
  * TallyQA (train / test json files)
  * Visual Genome images (needed for train and test split)
  * COCO (2014) train / val images (only needed for train split)
  """

  def _info(self) -> tfds.core.DatasetInfo:
    """Returns the dataset metadata."""
    return tfds.core.DatasetInfo(
        builder=self,
        features=tfds.features.FeaturesDict({
            'image': tfds.features.Image(shape=(None, None, 3)),
            'image_id': tfds.features.Scalar(dtype=np.int32),
            'image_source': tfds.features.Text(),
            'question': tfds.features.Text(),
            'question_id': tfds.features.Scalar(dtype=np.int32),
            'answer': tfds.features.Scalar(dtype=np.int32),
            'issimple': tfds.features.Scalar(dtype=np.int32),
        }),
        description=_DESCRIPTION,
        supervised_keys=None,
        homepage=_HOMEPAGE,
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager) -> ...:
    """Defines the train and test splits."""
    del dl_manager  # All files are read from fixed local paths.
    return {name: self._generate_examples(split=name)
            for name in ('train', 'test')}

  def _generate_examples(self, split: str) -> ...:
    """Yields (key, example) tuples read from the TallyQA json files."""
    with open(f'{_TALLYQA_PATH}/{split}.json', 'r') as f:
      entries = json.load(f)

    for entry in entries:
      # Images come from two sources (Visual Genome and COCO); dispatch on
      # the relative-path prefix.
      rel_path = entry['image']
      if rel_path.startswith('VG_100K'):
        image_path = _VISUAL_GENOME_PATH + rel_path
      elif rel_path.startswith(('train2014', 'val2014')):
        image_path = _COCO_PATH + rel_path
      else:
        raise ValueError(f'Unknown image path: {rel_path}')

      example = {
          'image': image_path,
          'image_id': entry['image_id'],
          'image_source': entry['data_source'],
          'question': entry['question'],
          'question_id': entry['question_id'],
          'answer': int(entry['answer']),
          # 'issimple' is only annotated in the test split; -1 marks the
          # train split where it is unknown.
          'issimple': entry['issimple'] if split == 'test' else -1,
      }
      yield f'{example["image_id"]} / {example["question_id"]}', example
Tipsomaly/model/big_vision/datasets/textcaps/textcaps.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # pylint: disable=line-too-long
16
+ r"""Implements textcaps val-set in TFDS structure.
17
+
18
+ It's small data, so simple to run locally. First, copy the data to local disk:
19
+
20
+ mkdir -p /tmp/data/textcaps
21
+ cd /tmp/data/textcaps
22
+ curl -O https://dl.fbaipublicfiles.com/textvqa/data/textcaps/TextCaps_0.1_train.json
23
+ curl -O https://dl.fbaipublicfiles.com/textvqa/data/textcaps/TextCaps_0.1_val.json
24
+ curl -O https://dl.fbaipublicfiles.com/textvqa/data/textcaps/TextCaps_0.1_test.json
25
+ curl -O https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip
26
+ curl -O https://dl.fbaipublicfiles.com/textvqa/images/test_images.zip
27
+ unzip train_val_images.zip
28
+ rm train_val_images.zip
29
+ unzip test_images.zip
30
+ rm test_images.zip
31
+
32
+ Then, run conversion locally (make sure to install tensorflow-datasets for the
33
+ `tfds` util):
34
+
35
+ cd big_vision/datasets
36
+ env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=textcaps
37
+
38
+
39
+ Example to load:
40
+
41
+ import tensorflow_datasets as tfds
42
+ dataset = tfds.load('text_caps', split='val', data_dir='/tmp/tfds')
43
+ """
44
+ import collections
45
+ import json
46
+ import os
47
+
48
+ from absl import logging
49
+ import numpy as np
50
+ import tensorflow_datasets as tfds
51
+
52
+
53
+ _DESCRIPTION = """TextCaps dataset."""
54
+
55
+ # pylint: disable=line-too-long
56
+ _CITATION = (
57
+ '@inproceedings{sidorov2019textcaps,'
58
+ 'title={TextCaps: a Dataset for Image Captioningwith Reading Comprehension},'
59
+ 'author={Sidorov, Oleksii and Hu, Ronghang and Rohrbach, Marcus and Singh, Amanpreet},'
60
+ 'journal={European Conference on Computer Vision},'
61
+ 'year={2020}}')
62
+ # pylint: enable=line-too-long
63
+
64
+ # When running locally (recommended), copy files as above an use these:
65
+ _FILEPATH = '/tmp/data/textcaps/'
66
+ _TRAIN_FILES = '/tmp/data/textcaps/TextCaps_0.1_train.json'
67
+ _VAL_FILES = '/tmp/data/textcaps/TextCaps_0.1_val.json'
68
+ _TEST_FILES = '/tmp/data/textcaps/TextCaps_0.1_test.json'
69
+
70
+
71
class TextCaps(tfds.core.GeneratorBasedBuilder):
  """TFDS builder for the TextCaps captioning dataset."""

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {
      '1.0.0': 'Initial release.',
  }

  def _info(self) -> tfds.core.DatasetInfo:
    """Returns the dataset metadata (features, homepage, citation)."""
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            'image/id': tfds.features.Text(),
            'image_filepath': tfds.features.Text(),
            'image': tfds.features.Image(encoding_format='jpeg'),
            'texts': tfds.features.Sequence(tfds.features.Text()),
        }),
        supervised_keys=None,  # Set to `None` to disable.
        homepage='https://textvqa.org/textcaps/',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators, grouping all captions per image."""

    def group_by_id(data, image_dir):
      """Merges all caption entries of one image into a single example."""
      grouped = collections.defaultdict(list)
      for entry in data:
        grouped[entry['image_id']].append(entry)

      examples = {}
      for image_id, entries in grouped.items():
        # Every entry of a group must agree on its image id and name.
        unique_ids = {e['image_id'] for e in entries}
        unique_names = {e['image_name'] for e in entries}
        assert len(unique_ids) == 1
        assert len(unique_names) == 1
        # Test entries carry no 'caption_str'; keep only non-empty captions.
        captions = [e['caption_str'] for e in entries if e.get('caption_str')]
        filepath = os.path.join(
            _FILEPATH, image_dir, f'{next(iter(unique_names))}.jpg')
        examples[image_id] = {
            'image/id': image_id,
            'image_filepath': filepath,
            'image': filepath,
            'texts': captions,
        }
      return examples

    with open(_TRAIN_FILES) as f:
      train_data = group_by_id(json.load(f)['data'], 'train_images')
    with open(_VAL_FILES) as f:
      # Validation images live in the train_images folder.
      val_data = group_by_id(json.load(f)['data'], 'train_images')
    with open(_TEST_FILES) as f:
      test_data = group_by_id(json.load(f)['data'], 'test_images')
    return {
        'train': self._generate_examples(train_data),
        'val': self._generate_examples(val_data),
        'test': self._generate_examples(test_data),
    }

  def _generate_examples(self, data):
    """Yields (key, example) pairs from the pre-grouped example dict.

    Args:
      data: dict mapping image id to a fully-built example dict.

    Yields:
      (key, example) tuples with the format declared in DatasetInfo.
    """
    yield from data.items()
Tipsomaly/model/big_vision/datasets/textvqa/textvqa.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # pylint: disable=line-too-long
16
+ r"""Implements textvqa in TFDS structure.
17
+
18
+ It's small data, so simple to run locally. First, copy the data to local disk:
19
+
20
+ mkdir -p /tmp/data/textvqa
21
+ cd /tmp/data/textvqa
22
+ curl -O https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip
23
+ curl -O https://dl.fbaipublicfiles.com/textvqa/images/test_images.zip
24
+ curl -O https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_train.json
25
+ curl -O https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json
26
+ curl -O https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_test.json
27
+ # The Rosetta_OCR files are probably not needed.
28
+ # curl -O https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_Rosetta_OCR_v0.2_train.json
29
+ # curl -O https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_Rosetta_OCR_v0.2_val.json
30
+ # curl -O https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_Rosetta_OCR_v0.2_test.json
31
+ unzip train_val_images.zip
32
+ rm train_val_images.zip
33
+ unzip test_images.zip
34
+ rm test_images.zip
35
+ # Background: at https://textvqa.org/dataset/ it says:
36
+ # "Note: Some of the images in OpenImages are rotated,
37
+ # please make sure to check the Rotation field in the Image IDs files
38
+ # for train and test."
39
+ curl -O https://storage.googleapis.com/openimages/2018_04/train/train-images-boxable-with-rotation.csv
40
+ curl -O https://storage.googleapis.com/openimages/2018_04/test/test-images-with-rotation.csv
41
+ mv train-images-boxable-with-rotation.csv train_images/rotation.csv
42
+ mv test-images-with-rotation.csv test_images/rotation.csv
43
+
44
+ Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util):
45
+
46
+ cd big_vision/datasets
47
+ env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=textvqa
48
+
49
+ Example to load:
50
+
51
+ import tensorflow_datasets as tfds
52
+ dataset = tfds.load('textvqa', split='train', data_dir='/tmp/tfds')
53
+ """
54
+ import json
55
+ import os
56
+
57
+ from absl import logging
58
+ import numpy as np
59
+ import pandas as pd
60
+ import tensorflow as tf
61
+ import tensorflow_datasets as tfds
62
+
63
+
64
+ _DESCRIPTION = """TextVqa dataset."""
65
+
66
+ # pylint: disable=line-too-long
67
+ _CITATION = (
68
+ '@inproceedings{singh2019towards,'
69
+ 'title={Towards VQA Models That Can Read},'
70
+ 'author={Singh, Amanpreet and Natarjan, Vivek and Shah, Meet and Jiang, Yu and Chen, Xinlei and Parikh, Devi and Rohrbach, Marcus},'
71
+ 'booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},'
72
+ 'pages={8317-8326},'
73
+ 'year={2019}}'
74
+ )
75
+ # pylint: enable=line-too-long
76
+
77
+ # When running locally (recommended), copy files as above and use these:
78
+ _FILEPATH = '/tmp/data/textvqa/'
79
+ _TRAIN_FILES = '/tmp/data/textvqa/TextVQA_0.5.1_train.json'
80
+ _VAL_FILES = '/tmp/data/textvqa/TextVQA_0.5.1_val.json'
81
+ _TEST_FILES = '/tmp/data/textvqa/TextVQA_0.5.1_test.json'
82
+ _ROTATION_CSV = 'rotation.csv'
83
+
84
+
85
class TextVqa(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for the TextVQA dataset."""

  VERSION = tfds.core.Version('1.0.1')
  RELEASE_NOTES = {
      '1.0.0': 'Initial release.',
      '1.0.1': 'Undo rotation for known rotated images.',
  }

  def _info(self) -> tfds.core.DatasetInfo:
    """Returns the dataset metadata.

    (tfds.core.DatasetInfo object)
    These are the features of your dataset like images, labels, etc.
    """
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            'image/id': tfds.features.Scalar(np.int32),
            'image_filepath': tfds.features.Text(),
            'image': tfds.features.Image(encoding_format='jpeg'),
            'question_id': tfds.features.Scalar(np.int32),
            'question': tfds.features.Text(),
            'answers': tfds.features.Sequence(tfds.features.Text()),
        }),
        supervised_keys=None,  # Set to `None` to disable.
        homepage='https://textvqa.org/',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators for train/val/test."""

    def json_to_examples(data, image_dir):
      """Builds question-keyed examples, attaching OpenImages rotation info."""
      logging.info('Processing %d items in %s', len(data), image_dir)
      # Some OpenImages photos are stored rotated; the rotation csv tells us
      # by how much so we can undo it later in _generate_examples.
      rot = pd.read_csv(os.path.join(_FILEPATH, image_dir, _ROTATION_CSV))
      rotation_by_id = {}
      for row in rot.itertuples():
        # NaN means no rotation metadata; treat as unrotated.
        rotation = int(row.Rotation) if not np.isnan(row.Rotation) else 0
        rotation_by_id[row.ImageID] = rotation

      examples = {}
      for v in data:
        image_id = str(v['image_id'])
        image_filepath = os.path.join(_FILEPATH, image_dir, image_id + '.jpg')
        question_id = v['question_id']
        examples[question_id] = {
            'image/id': question_id,
            'image_filepath': image_filepath,
            'image': image_filepath,
            # NOTE(review): assumes every dataset image id appears in the
            # rotation csv — a missing id would raise KeyError here.
            'rotation': rotation_by_id[image_id],
            'question_id': question_id,
            'question': v['question'],
            'answers': v.get('answers', []),  # No answers in test set.
        }
      return examples

    # Returns the Dict[split names, Iterator[Key, Example]]
    with open(_TRAIN_FILES) as f:
      train_data = json_to_examples(json.load(f)['data'], 'train_images')
    with open(_VAL_FILES) as f:
      # Validation images are stored in the train_images folder.
      val_data = json_to_examples(json.load(f)['data'], 'train_images')
    with open(_TEST_FILES) as f:
      test_data = json_to_examples(json.load(f)['data'], 'test_images')
    return {
        'train': self._generate_examples(train_data),
        'val': self._generate_examples(val_data),
        'test': self._generate_examples(test_data),
    }

  def _generate_examples(self, data):
    """Generate a tf.Example object.

    Args:
      data: a dictionary mapping question id to an example dict that still
        carries a 'rotation' entry.

    Yields:
      (key, example) tuples from dataset. The example has format specified in
      the above DatasetInfo.
    """
    for k, v in data.items():
      # Read the image bytes via a context manager so each file handle is
      # closed immediately (the original `open(...).read()` leaked one
      # descriptor per example).
      with open(v['image_filepath'], 'rb') as f:
        image_bytes = f.read()
      # If the image is rotated, we undo the rotation here and re-encode.
      if v['rotation'] != 0:
        rotation = v['rotation']
        assert rotation % 90 == 0
        turns = rotation // 90
        image = tf.image.decode_jpeg(image_bytes)
        image_bytes = tf.io.encode_jpeg(
            tf.image.rot90(image, turns), quality=100
        ).numpy()
      # If no rotation was needed, we just pass along the unchanged bytes.
      v['image'] = image_bytes

      # Now all rotation should have been accounted for. And we don't want to
      # pass on the (now obsolete) rotation info as features.
      del v['rotation']

      yield k, v
Tipsomaly/model/big_vision/datasets/vizwizvqa/vizwizvqa.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # pylint: disable=line-too-long
16
+ r"""Implements VizWizVQA dataset in TFDS structure.
17
+
18
+ It's small data, so simple to run locally. First, copy the data to local disk:
19
+
20
+ mkdir -p /tmp/data/vizwizvqa
21
+
22
+ wget -O https://vizwiz.cs.colorado.edu/VizWiz_final/images/train.zip /tmp/data/vizwizvqa
23
+ wget -O https://vizwiz.cs.colorado.edu/VizWiz_final/images/val.zip /tmp/data/vizwizvqa
24
+ wget -O https://vizwiz.cs.colorado.edu/VizWiz_final/images/test.zip /tmp/data/vizwizvqa
25
+
26
+ Then, run conversion locally
27
+ (make sure to install tensorflow-datasets for the `tfds` util):
28
+
29
+ cd big_vision/datasets
30
+ env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=vizwizvqa
31
+
32
+ Example to load:
33
+
34
+ import tensorflow_datasets as tfds
35
+ dataset = tfds.load('vizwizvqa', split='train', data_dir='/tmp/tfds')
36
+ """
37
+ import json
38
+ import os
39
+
40
+ import numpy as np
41
+ import tensorflow_datasets as tfds
42
+
43
+
44
+ _DESCRIPTION = """VizWiz VQA Dataset."""
45
+
46
+ # pylint: disable=line-too-long
47
+ _CITATION = """
48
+ @inproceedings{gurari2018vizwiz,
49
+ title={Vizwiz grand challenge: Answering visual questions from blind people},
50
+ author={Gurari, Danna and Li, Qing and Stangl, Abigale J and Guo, Anhong and Lin, Chi and Grauman, Kristen and Luo, Jiebo and Bigham, Jeffrey P},
51
+ booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
52
+ pages={3608--3617},
53
+ year={2018}
54
+ }
55
+ }
56
+ """
57
+ # pylint: enable=line-too-long
58
+
59
+ # When running locally (recommended), copy files as above an use these:
60
+ _VIZWIZVQA_PATH = '/tmp/data/vizwizvqa/'
61
+
62
+
63
class VizWizVQA(tfds.core.GeneratorBasedBuilder):
  """TFDS builder for the VizWiz VQA dataset."""

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {'1.0.0': 'First release.'}

  def _info(self):
    """Returns the dataset metadata (features, homepage, citation)."""
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            'question': tfds.features.Text(),
            'image/filename': tfds.features.Text(),
            'image': tfds.features.Image(encoding_format='jpeg'),
            'answers': tfds.features.Sequence(tfds.features.Text()),
            # Each confidence is one of the strings "yes", "no", "maybe".
            'answer_confidences': tfds.features.Sequence(tfds.features.Text()),
            'answerable': tfds.features.Scalar(np.int32),
            'question_id': tfds.features.Scalar(np.int32),
        }),
        supervised_keys=None,
        homepage='https://vizwiz.org/tasks-and-datasets/vqa/',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators for val/train/test."""
    return {name: self._generate_examples(name)
            for name in ('val', 'train', 'test',)}

  def _generate_examples(self, split: str):
    """Yields (key, example) tuples for one split."""
    annotations_path = os.path.join(
        _VIZWIZVQA_PATH, 'annotations', f'{split}.json')
    with open(annotations_path, 'r') as f:
      records = json.load(f)

    for record in records:
      image_file = record['image']
      answers = []
      confidences = []
      answerable = -1
      # The test split ships without ground-truth answers; keep the defaults.
      if split != 'test':
        for ans in record['answers']:
          # A couple of answers in the train set are empty strings.
          if ans['answer']:
            answers.append(ans['answer'])
            confidences.append(ans['answer_confidence'])
        answerable = record['answerable']

      # The question id is the numeric suffix of the image filename
      # (extension stripped).
      question_id = int(image_file[:-4].split('_')[-1])

      yield image_file, {
          'question': record['question'],
          'image/filename': image_file,
          'question_id': question_id,
          'image': os.path.join(_VIZWIZVQA_PATH, split, image_file),
          'answers': answers,
          'answer_confidences': confidences,
          'answerable': answerable,
      }
Tipsomaly/model/big_vision/datasets/vqa/vqa.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # pylint: disable=line-too-long
16
+ r"""Import VQAv2 into TFDS format. Uses coco-2014 images.
17
+
18
+ It's small data, so simple to run locally. First, download all the data:
19
+
20
+ mkdir /tmp/data/ ; cd /tmp/data
21
+ wget http://images.cocodataset.org/zips/{train2014,val2014,test2015}.zip
22
+ wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_{Train,Val,Test}_mscoco.zip
23
+ wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_{Train,Val}_mscoco.zip
24
+ unzip '*.zip'
25
+
26
+ Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util):
27
+
28
+ cd big_vision/datasets
29
+ env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=vqa
30
+
31
+ It runs at around 750 examples/sec, so takes around 25min for the 1.2M questions.
32
+ Each question is an example; images are repeated, a bit wasteful, but disk is cheap.
33
+
34
+
35
+ Example to load:
36
+
37
+ import tensorflow_datasets as tfds
38
+ dataset = tfds.load('vqa', split='train', data_dir='/tmp/tfds')
39
+ """
40
+ import json
41
+ import os
42
+
43
+ import numpy as np
44
+ import tensorflow_datasets as tfds
45
+
46
+
47
+ _VQAV2_PATH = '/tmp/data'
48
+ _IMAGE_PATH = '/tmp/data'
49
+
50
+
51
+ _CITATION = (
52
+ '@InProceedings{balanced_vqa_v2,'
53
+ 'author = {Yash Goyal and Tejas Khot and '
54
+ 'Douglas Summers{-}Stay and Dhruv Batra and Devi Parikh},'
55
+ 'title = {Making the {V} in {VQA} Matter: Elevating the Role of Image'
56
+ 'Understanding in {V}isual {Q}uestion {A}nswering},'
57
+ 'booktitle = {Computer Vision and Pattern Recognition (CVPR)},'
58
+ 'year = {2017},}')
59
+
60
+
61
class Vqa(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for the VQAv2 dataset (one example per question)."""

  VERSION = tfds.core.Version('3.0.0')
  RELEASE_NOTES = {'3.0.0': 'Format as needed for PaliGemma'}

  def _info(self) -> tfds.core.DatasetInfo:
    """Returns the metadata."""
    features = tfds.features.FeaturesDict({
        'image/id': np.int32,
        'image/filename': tfds.features.Text(),
        'image': tfds.features.Image(encoding_format='jpeg'),
        'question_id': np.int32,
        'question_type': tfds.features.Text(),
        'question_text': tfds.features.Text(),
        'answer_type': tfds.features.Text(),
        'answers': tfds.features.Sequence(tfds.features.Text()),
        'answer_confidences': tfds.features.Sequence(
            tfds.features.ClassLabel(names=['no', 'maybe', 'yes'])),
        'top_answer': tfds.features.Text(),
    })
    return tfds.core.DatasetInfo(
        builder=self,
        description='The VQAv2 dataset.',
        features=features,
        homepage='https://visualqa.org/',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators."""
    return {
        'train': self._generate_examples('train2014'),
        'validation': self._generate_examples('val2014'),
        'test': self._generate_examples('test2015'),
        # test-dev questions reference images stored in the test2015 folder.
        'test-dev': self._generate_examples('test-dev2015', 'test2015'),
    }

  def _generate_examples(self, split, image_folder=None):
    """Yields (key, example) tuples for one split.

    Args:
      split: split name as used in the question/annotation file names.
      image_folder: name of the image directory; defaults to `split`.
    """
    image_folder = image_folder or split
    is_test = 'test' in split

    # The questions file has fields image_id, question, question_id.
    questions_path = os.path.join(
        _VQAV2_PATH, f'v2_OpenEnded_mscoco_{split}_questions.json')
    with open(questions_path) as f:
      questions = json.load(f)['questions']

    # The annotations file (train/val only) has fields: image_id,
    # question_id, answers, answer_type, question_type,
    # multiple_choice_answer.
    if not is_test:
      annotations_path = os.path.join(
          _VQAV2_PATH, f'v2_mscoco_{split}_annotations.json')
      with open(annotations_path) as f:
        annots = {a['question_id']: a for a in json.load(f)['annotations']}

    for q in questions:
      qid = q['question_id']
      image_id = q['image_id']
      fname = f'COCO_{image_folder}_{image_id:012d}.jpg'
      out = {
          'image/id': image_id,
          'question_id': qid,
          'question_text': q['question'],
          'image/filename': fname,
      }
      if not is_test:
        out['image'] = os.path.join(_IMAGE_PATH, image_folder, fname)
        ann = annots[qid]
        out['question_type'] = ann['question_type']
        out['answer_type'] = ann['answer_type']
        out['answers'] = [a['answer'] for a in ann['answers']]
        out['answer_confidences'] = [
            a['answer_confidence'] for a in ann['answers']]
        out['top_answer'] = ann['multiple_choice_answer']
      else:
        # For test images, a few are from the wrong year; skip any whose
        # file is missing on disk.
        path = os.path.join(_IMAGE_PATH, image_folder, fname)
        if not os.path.isfile(path):
          print(out['image/id'])
          continue
        out['image'] = path
        out['question_type'] = ''
        out['answer_type'] = ''
        out['answers'] = []
        out['answer_confidences'] = []
        out['top_answer'] = ''
      yield qid, out
Tipsomaly/model/big_vision/datasets/widgetcap/widgetcap.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # pylint: disable=line-too-long
16
+ r"""Import widgetcap into TFDS format.
17
+
18
+ Widget Captioning all requires images from the RICO dataset:
19
+ mkdir -p /tmp/data/rico_images ; cd /tmp/data/rico_images
20
+ wget
21
+ https://storage.googleapis.com/crowdstf-rico-uiuc-4540/rico_dataset_v0.1/unique_uis.tar.gz
22
+ tar xvfz unique_uis.tar.gz
23
+ rm unique_uis.tar.gz
24
+
25
+ Widget Captioning:
26
+ mkdir - /tmp/data/widget_captioning ; cd /tmp/data/widget_captioning
27
+ git clone https://github.com/google-research-datasets/widget-caption.git
28
+ cp widget-caption/widget_captions.csv ./
29
+ cp widget-caption/split/*.txt ./
30
+ rm -rf widget-caption
31
+
32
+ Then, run conversion locally (make sure to install tensorflow-datasets for the
33
+ `tfds` util):
34
+
35
+ cd big_vision/datasets
36
+ env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=widgetcap
37
+
38
+ Example to load:
39
+
40
+ import tensorflow_datasets as tfds
41
+ dataset_augmented = tfds.load('widgetcap', split='train',
42
+ data_dir='/tmp/tfds')
43
+ """
44
+ import csv
45
+ import json
46
+ import os
47
+
48
+ import numpy as np
49
+ from PIL import Image
50
+ import tensorflow_datasets as tfds
51
+
52
+ _DATASET_DIR = '/tmp/data/widget_captioning'
53
+ # Dataset property indicating the y-dim of the canvas
54
+ _RICO_CANVAS_Y = 2560
55
+ _IMAGE_DIR = '/tmp/data/rico_images/combined'
56
+
57
+ _CITATION = (
58
+ '@inproceedings{Li2020WidgetCG,title={Widget Captioning: Generating Natural'
59
+ ' Language Description for MobileUser Interface Elements},author={Y. Li and'
60
+ ' Gang Li and Luheng He and Jingjie Zheng and Hong Li and Zhiwei'
61
+ ' Guan},booktitle={Conference on Empirical Methods in Natural Language'
62
+ ' Processing},year={2020},}'
63
+ )
64
+
65
+
66
class Widgetcap(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for the widget captioning dataset."""

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {'1.0.0': 'Format as needed for PaliGemma'}

  def _info(self) -> tfds.core.DatasetInfo:
    """Returns the metadata."""
    return tfds.core.DatasetInfo(
        builder=self,
        description='The widgetcap dataset.',
        features=tfds.features.FeaturesDict({
            'image/id': tfds.features.Text(),
            'image/filename': tfds.features.Text(),
            'image': tfds.features.Image(encoding_format='jpeg'),
            'texts': tfds.features.Sequence(tfds.features.Text()),
            'bbox': tfds.features.BBoxFeature(),
            'screen_id': tfds.features.Text(),
            'node_id': tfds.features.Text(),
            'height': np.int32,
            'width': np.int32,
        }),
        homepage='https://github.com/google-research-datasets/widget-caption',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators."""
    return {
        'train': self._generate_examples('train'),
        'dev': self._generate_examples('dev'),
        'test': self._generate_examples('test'),
    }

  def _generate_examples(self, split):
    """Yields (key, example) tuples for all captioned widgets of `split`."""
    # Screen ids belonging to this split, one id per line.
    with open(os.path.join(_DATASET_DIR, split + '.txt')) as f:
      split_screen_ids = {line.strip() for line in f}

    with open(os.path.join(_DATASET_DIR, 'widget_captions.csv')) as f:
      for row in csv.DictReader(f):
        if row['screenId'] in split_screen_ids:
          id_, example = self._get_example(
              row['screenId'], row['nodeId'], row['captions']
          )
          yield id_, example

  def _get_node_box(self, screen_id, node_id, height):
    """Returns (xmin, ymin, xmax, ymax) of the node, scaled to the image.

    `node_id` is a dotted path of child indices; the leading component is
    the root and is skipped. RICO hierarchies use a canvas of height
    `_RICO_CANVAS_Y`; every coordinate is scaled by `height/_RICO_CANVAS_Y`
    (assumes the screenshot preserves the canvas aspect ratio so the same
    factor applies to x coordinates -- TODO confirm against the data).
    """
    index_list = [int(i) for i in node_id.split('.')[1:]]
    with open(os.path.join(_IMAGE_DIR, screen_id + '.json')) as f:
      view = json.load(f)
    curr_node = view['activity']['root']
    for index in index_list:
      curr_node = curr_node['children'][index]
    scale = height / _RICO_CANVAS_Y
    # Materialize as a tuple: the previous lazy `map` object could only be
    # consumed once; a tuple is re-iterable and easier to reason about.
    return tuple(b * scale for b in curr_node['bounds'])

  def _get_example(self, screen_id, node_id, captions):
    """Builds one (key, example) pair for a single widget."""
    image_path = os.path.join(_IMAGE_DIR, screen_id + '.jpg')
    # Use a context manager so the underlying file handle is closed right
    # away (the previous code leaked one handle per example).
    with Image.open(image_path) as image:
      width, height = image.size
    # Bounding box coordinates in pixel space of the screenshot.
    xmin, ymin, xmax, ymax = self._get_node_box(screen_id, node_id, height)

    image_id = f'{screen_id}_{node_id}'
    example = {
        'image/id': image_id,
        'image/filename': screen_id + '.jpg',
        'image': image_path,
        'texts': captions.split('|'),
        'bbox': tfds.features.BBox(
            ymin=ymin / height,
            xmin=xmin / width,
            ymax=ymax / height,
            xmax=xmax / width,
        ),
        'screen_id': screen_id,
        'node_id': node_id,
        'height': height,
        'width': width,
    }
    return image_id, example
Tipsomaly/model/big_vision/datasets/xgqa/xgqa.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # pylint: disable=line-too-long
16
+ """Generates xGQA in a TFDS-ready structure.
17
+
18
+ First, download the data:
19
+ mkdir -p /tmp/data/xgqa/annotations
20
+ wget https://raw.githubusercontent.com/e-bug/iglue/main/datasets/xGQA/annotations/zero_shot/testdev_balanced_questions_bn.json -P /tmp/data/xgqa/annotations
21
+ wget https://raw.githubusercontent.com/e-bug/iglue/main/datasets/xGQA/annotations/zero_shot/testdev_balanced_questions_de.json -P /tmp/data/xgqa/annotations
22
+ wget https://raw.githubusercontent.com/e-bug/iglue/main/datasets/xGQA/annotations/zero_shot/testdev_balanced_questions_en.json -P /tmp/data/xgqa/annotations
23
+ wget https://raw.githubusercontent.com/e-bug/iglue/main/datasets/xGQA/annotations/zero_shot/testdev_balanced_questions_id.json -P /tmp/data/xgqa/annotations
24
+ wget https://raw.githubusercontent.com/e-bug/iglue/main/datasets/xGQA/annotations/zero_shot/testdev_balanced_questions_ko.json -P /tmp/data/xgqa/annotations
25
+ wget https://raw.githubusercontent.com/e-bug/iglue/main/datasets/xGQA/annotations/zero_shot/testdev_balanced_questions_pt.json -P /tmp/data/xgqa/annotations
26
+ wget https://raw.githubusercontent.com/e-bug/iglue/main/datasets/xGQA/annotations/zero_shot/testdev_balanced_questions_ru.json -P /tmp/data/xgqa/annotations
27
+ wget https://raw.githubusercontent.com/e-bug/iglue/main/datasets/xGQA/annotations/zero_shot/testdev_balanced_questions_zh.json -P /tmp/data/xgqa/annotations
28
+ wget https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip -P /tmp/data/xgqa/
29
+ unzip /tmp/data/xgqa/images.zip -d /tmp/data/xgqa/
30
+
31
+ Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util):
32
+
33
+ cd big_vision/datasets
34
+ env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=xgqa
35
+
36
+ Example to load:
37
+
38
+ import tensorflow_datasets as tfds
39
+ dataset = tfds.load(
40
+ 'xgqa', split='test_zs_en',
41
+ data_dir='/tmp/tfds')
42
+ """
43
+ import json
44
+ import os
45
+
46
+ import tensorflow_datasets as tfds
47
+
48
+ _DESCRIPTION = """xGQA (uses GQA images)."""
49
+
50
+ # pylint: disable=line-too-long
51
+ _CITATION = (
52
+ '@inproceedings{pfeiffer-etal-2022-xgqa,'
53
+ 'title = "x{GQA}: Cross-Lingual Visual Question Answering",'
54
+ 'author = "Pfeiffer, Jonas and'
55
+ ' Geigle, Gregor and'
56
+ ' Kamath, Aishwarya and'
57
+ ' Steitz, Jan-Martin and'
58
+ ' Roth, Stefan and'
59
+ ' Vuli{\'c}, Ivan and'
60
+ ' Gurevych, Iryna",'
61
+ 'booktitle = "Findings of the Association for Computational Linguistics: '
62
+ 'ACL 2022",'
63
+ 'month = may,'
64
+ 'year = "2022",'
65
+ 'address = "Dublin, Ireland",'
66
+ 'publisher = "Association for Computational Linguistics",'
67
+ 'url = "https://aclanthology.org/2022.findings-acl.196",'
68
+ 'doi = "10.18653/v1/2022.findings-acl.196",'
69
+ 'pages = "2497--2511",'
70
+ '}'
71
+ )
72
+ # pylint: enable=line-too-long
73
+
74
+ # When running locally (recommended), copy files as above an use these:
75
+ _DATA_PATH = '/tmp/data/xgqa/'
76
+ _IMAGE_PATH = '/tmp/data/xgqa/images/'
77
+
78
+ LANGUAGES = frozenset(['bn', 'de', 'en', 'id', 'ko', 'pt', 'ru', 'zh'])
79
+
80
+
81
class XGQA(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for the XGQA dataset."""

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {'1.0.0': 'First release.'}

  def _info(self):
    """Returns the metadata."""
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            'example_id': tfds.features.Text(),
            'image/id': tfds.features.Text(),
            'image': tfds.features.Image(encoding_format='jpeg'),
            'question': tfds.features.Text(),
            'answer': tfds.features.Text(),
        }),
        supervised_keys=None,
        homepage='https://github.com/adapter-hub/xGQA',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators, one zero-shot + several few-shot per language."""
    splits = {}
    for lang in LANGUAGES:
      splits[f'test_zs_{lang}'] = self._generate_examples(
          'test', 'zero_shot', lang)
      splits[f'test_fs_{lang}'] = self._generate_examples(
          'test', 'few_shot', lang)
      # NOTE(review): the few-shot dev split reads the 'test' file, same as
      # test_fs_* -- confirm whether a dedicated 'dev' file was intended.
      splits[f'dev_fs_{lang}'] = self._generate_examples(
          'test', 'few_shot', lang)
      for n in ('1', '5', '10', '20', '25', '48'):
        splits[f'train_fs{n}_{lang}'] = self._generate_examples(
            f'train_{n}', 'few_shot', lang)
    return splits

  def _generate_examples(self, split, num_shots, lang):
    """Yields (key, example) tuples for one split/regime/language."""
    # Locate the per-language annotation file for the requested regime.
    if num_shots == 'few_shot':
      annot_path = os.path.join(
          _DATA_PATH, 'annotations', 'few_shot', lang, f'{split}.json')
    elif num_shots == 'zero_shot':
      annot_path = os.path.join(
          _DATA_PATH, 'annotations', 'zero_shot',
          f'testdev_balanced_questions_{lang}.json')
    else:
      raise ValueError(f'Unknown num_shots: {num_shots}')

    with open(annot_path, 'r') as f:
      questions = json.load(f)

    # One example per question-answer pair; keys are made unique per
    # language since the same question id occurs in every language file.
    for qid, q in questions.items():
      key = f'{qid}_{lang}'
      yield key, {
          'example_id': key,
          'image/id': q['imageId'],
          'image': os.path.join(_IMAGE_PATH, f'{q["imageId"]}.jpg'),
          'question': q['question'],
          'answer': q['answer'],
      }
Tipsomaly/model/big_vision/datasets/xm3600/xm3600.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # pylint: disable=line-too-long
16
+ r"""Generates XM3600 in a TFDS-ready structure.
17
+
18
+ First, download the captions from https://google.github.io/crossmodal-3600/ and the images from https://cocodataset.org/#download.
19
+ The coco Karpathy split is available at http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip:
20
+ mkdir -p /tmp/data/xm3600
21
+ wget https://google.github.io/crossmodal-3600/web-data/captions.zip -P /tmp/data/xm3600
22
+ unzip /tmp/data/xm3600/captions.zip -d /tmp/data/xm3600/
23
+ wget https://open-images-dataset.s3.amazonaws.com/crossmodal-3600/images.tgz ta-P /tmp/data/xm3600
24
+ mkdir /tmp/data/xm3600/images
25
+ tar -xzf /tmp/data/xm3600/images.tgz -C /tmp/data/xm3600/images
26
+
27
+ Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util):
28
+
29
+ cd big_vision/datasets
30
+ env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=xm3600
31
+
32
+ Example to load:
33
+
34
+ import tensorflow_datasets as tfds
35
+ dataset = tfds.load(
36
+ 'xm3600', split='en',
37
+ data_dir='/tmp/tfds')
38
+ """
39
+
40
+ import json
41
+ import os.path
42
+
43
+ import tensorflow_datasets as tfds
44
+
45
+ _DESCRIPTION = """
46
+ COCO image + captions, translated from English to 35 languages (English incl.).
47
+ """
48
+
49
+ # pylint: disable=line-too-long
50
+ _CITATION = """
51
+ @inproceedings{thapliyal-etal-2022-crossmodal,
52
+ title = "Crossmodal-3600: A Massively Multilingual Multimodal Evaluation Dataset",
53
+ author = "Thapliyal, Ashish V. and
54
+ Pont Tuset, Jordi and
55
+ Chen, Xi and
56
+ Soricut, Radu",
57
+ editor = "Goldberg, Yoav and
58
+ Kozareva, Zornitsa and
59
+ Zhang, Yue",
60
+ booktitle = "Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing",
61
+ month = dec,
62
+ year = "2022",
63
+ address = "Abu Dhabi, United Arab Emirates",
64
+ publisher = "Association for Computational Linguistics",
65
+ url = "https://aclanthology.org/2022.emnlp-main.45",
66
+ doi = "10.18653/v1/2022.emnlp-main.45",
67
+ pages = "715--729",
68
+ }
69
+ """
70
+ # pylint: enable=line-too-long
71
+
72
+
73
+ _CAPTIONS_PATH = '/tmp/data/xm3600'
74
+ _IMAGES_PATH = '/tmp/data/xm3600/images'
75
+
76
+ XM3600_LANGUAGES = [
77
+ 'ar', 'bn', 'cs', 'da', 'de', 'el', 'en', 'es', 'fa', 'fi', 'fil', 'fr',
78
+ 'he', 'hi', 'hr', 'hu', 'id', 'it', 'ja', 'ko', 'mi', 'nl', 'no', 'pl',
79
+ 'pt', 'quz', 'ro', 'ru', 'sv', 'sw', 'te', 'th', 'tr', 'uk', 'vi', 'zh'
80
+ ]
81
+
82
+
83
class Xm3600(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for the XM3600 dataset."""

  VERSION = tfds.core.Version('1.0.1')
  RELEASE_NOTES = {
      '1.0.0': 'First release.',
      '1.0.1': 'Add captions/tokenized feature to compute metrics (eg CIDEr).',
  }

  def _info(self):
    """Returns the metadata."""
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            'image/id': tfds.features.Text(),
            'image': tfds.features.Image(encoding_format='jpeg'),
            'captions': tfds.features.Sequence(tfds.features.Text()),
            'captions/tokenized': tfds.features.Sequence(tfds.features.Text()),
            'language': tfds.features.Text(),
        }),
        supervised_keys=None,
        homepage='https://google.github.io/crossmodal-3600/',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators -- one split per language."""
    return {lang: self._generate_examples(lang) for lang in XM3600_LANGUAGES}

  def _generate_examples(self, split: str):
    """Yields (key, example) tuples; `split` is the language code."""
    lang = split
    annot_fname = os.path.join(_CAPTIONS_PATH, 'captions.jsonl')
    # Each jsonl line holds one image with its captions for all languages;
    # stream the file and emit one example per line for this language.
    with open(annot_fname, 'r') as f:
      for line in f:
        record = json.loads(line)
        image_id = f'{record["image/key"]}_{lang}'
        yield image_id, {
            'image/id': image_id,
            # Drop the language suffix again to get the image file name.
            'image': os.path.join(
                _IMAGES_PATH, f'{image_id.split("_")[0]}.jpg'),
            'captions': record[lang]['caption'],
            'captions/tokenized': record[lang]['caption/tokenized'],
            'language': lang,
        }
Tipsomaly/model/big_vision/evaluators/proj/cappa/perplexity.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Evaluator for perplexity of a model."""
16
+ from big_vision.evaluators import mean
17
+ import big_vision.utils as u
18
+ import jax.numpy as jnp
19
+
20
+
21
+ # Temporary global flag to facilitate backwards compatability. Will be removed
22
+ # by the end of year 2023.
23
+ API = 'jit'
24
+
25
+
26
def perplexity(predict_fn, normalize_by_seqlen):
  """Returns a function that computes per-example perplexity.

  Args:
    predict_fn: forward function returning `(logits, aux)` for a batch.
    normalize_by_seqlen: forwarded to the xent; if True, normalizes the loss
      by the (weighted) sequence length.
  """

  def _perplexity_fn(train_state, batch, pad_token=0, **kw):
    logits, _ = predict_fn(train_state, batch, **kw)

    # Padding labels get zero weight so they don't contribute to the loss.
    weights = (batch['labels'] != pad_token).astype(jnp.float32)
    label_masks = batch.get('label_masks')
    if label_masks is not None:
      weights = weights * label_masks

    losses = u.weighted_softmax_xent(
        logits=logits,
        labels=batch['labels'],
        weights=weights,
        label_smoothing=0.0,
        reduction=False,
        normalize=normalize_by_seqlen)
    return {'perplexity': losses}

  return _perplexity_fn
44
+
45
+
46
class Evaluator(mean.Evaluator):
  """Perplexity evaluator.

  Thin wrapper over `mean.Evaluator`: the per-example metric being averaged
  is the perplexity produced by `perplexity(...)` above.
  """

  def __init__(self, predict_fn, *a, normalize_by_seqlen=False, **kw):
    # All remaining positional/keyword args are forwarded to mean.Evaluator.
    super().__init__(perplexity(predict_fn, normalize_by_seqlen), *a, **kw)
Tipsomaly/model/big_vision/evaluators/proj/cappa/scoring_classifier.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Scoring classifier.
16
+
17
+ This one is based on a generative perspective for image classification.
18
+ Here we input the image as well as all the tokenized labels to compute their
19
+ perplexity and select the one with minimum loss as the prediction.
20
+ """
21
+ import functools
22
+ from big_vision.datasets.imagenet import class_names as imagenet_class_names
23
+ from big_vision.evaluators import mean
24
+ from big_vision.pp import builder as pp_builder
25
+ import jax.numpy as jnp
26
+ import numpy as np
27
+
28
+ # Temporary global flag to facilitate backwards compatability. Will be removed
29
+ # by the end of year 2023.
30
+ API = "jit"
31
+
32
+
33
+ CLASS_NAMES = {
34
+ "imagenet2012": imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES,
35
+ }
36
+
37
+
38
# As a separate function to cache result across instances.
@functools.lru_cache(maxsize=None)
def get_classes(dataset_name, pp_txt):
  """Load the class label strings and tokenize them using pp_txt."""
  tokenize = pp_builder.get_preprocess_fn(pp_txt, log_data=False)
  tokens = [tokenize({"label": name})["labels"]
            for name in CLASS_NAMES[dataset_name]]
  return np.array(tokens)
45
+
46
+
47
def scoring(predict_fn, tokenized_labels):
  """Returns an eval step scoring every class and comparing the argmax.

  The tokenized class labels are exposed to `predict_fn` under the
  "_label_tokens" batch key; the class with the highest score is compared
  against the ground-truth "label".
  """

  def _scoring_fn(train_state, batch, *a, **kw):
    # A pre-existing "_label_tokens" entry in the batch takes precedence.
    merged = {"_label_tokens": tokenized_labels, **batch}
    class_scores = predict_fn(train_state, merged, *a, **kw)
    best = jnp.argmax(class_scores, axis=-1)
    return {"prec@1": best == merged["label"]}

  return _scoring_fn
56
+
57
+
58
class Evaluator(mean.Evaluator):
  """Evaluator for classification accuracy based on scoring all classes."""

  def __init__(self, predict_fn, data, pp_fn, pp_txt, *a, **kw):
    # Tokenize the dataset's class names once (lru-cached across instances).
    cls_tokens = get_classes(data["name"], pp_txt)
    super().__init__(scoring(predict_fn, cls_tokens), data, pp_fn, *a, **kw)
Tipsomaly/model/big_vision/evaluators/proj/distill/distance.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Evaluator for the classfication task."""
16
+ from functools import partial, lru_cache
17
+
18
+ from big_vision import input_pipeline
19
+ import big_vision.datasets.core as ds_core
20
+ import big_vision.pp.builder as pp_builder
21
+ import big_vision.utils as u
22
+
23
+ import einops
24
+ import jax
25
+ import jax.numpy as jnp
26
+ from jax.sharding import NamedSharding
27
+ from jax.sharding import PartitionSpec as P
28
+ import numpy as np
29
+
30
+ # Temporary global flag to facilitate backwards compatability. Will be removed
31
+ # by the end of year 2023.
32
+ API = 'jit'
33
+
34
+
35
def dist(student, teacher, kind, feat_axis=-1,
         epsilon=1e-12, t=1, ls=0.0, k=1):
  """Distance function used for distillation.

  Args:
    student: student outputs (e.g. logits), same shape as `teacher`.
    teacher: teacher outputs.
    kind: one of 'euclidean', 'l2', 'hard', 'kl', 'logsoftmax_euclidean',
      'agree'.
    feat_axis: axis holding the feature/class dimension; it is reduced away.
    epsilon: stabilizer inside sqrt for the euclidean variants.
    t: softmax temperature (only used by 'kl').
    ls: label smoothing applied to teacher pseudo-labels (only 'hard').
    k: how many top student predictions to compare (only 'agree').

  Returns:
    Per-example distances with `feat_axis` reduced.

  Raises:
    ValueError: if `kind` is not one of the supported distances.
  """
  diff = student - teacher
  if kind == 'euclidean':
    return jnp.sqrt(jnp.sum(diff * diff, axis=feat_axis) + epsilon)
  elif kind == 'l2':
    return jnp.sum(diff * diff, axis=feat_axis)
  elif kind == 'hard':
    # Cross-entropy against the teacher's (optionally smoothed) argmax.
    pseudolabels = jnp.argmax(teacher, feat_axis)
    pl = u.onehot(pseudolabels, teacher.shape[feat_axis])
    if ls:
      pl = (1.0 - ls) * pl + (ls / (pl.shape[-1] - 1)) * (1.0 - pl)
    return u.softmax_xent(logits=student, labels=pl,
                          reduction=False, kl=True, axis=feat_axis)
  elif kind == 'kl':
    # Temperature-scaled KL; t**2 keeps gradient magnitude comparable.
    return t**2 * u.softmax_xent(
        logits=student / t,
        labels=jax.nn.softmax(teacher / t),
        reduction=False, kl=True, axis=feat_axis)
  elif kind == 'logsoftmax_euclidean':
    logsoftmax_diff = (
        jax.nn.log_softmax(student, axis=feat_axis) -
        jax.nn.log_softmax(teacher, axis=feat_axis))
    return jnp.sqrt(
        jnp.sum(logsoftmax_diff * logsoftmax_diff, axis=feat_axis) + epsilon)
  elif kind == 'agree':
    # Counts how many of the student's top-k match the teacher's top-1.
    def get_top_k(arr, k, ax):
      return jax.lax.top_k(arr.swapaxes(ax, -1), k)[1].swapaxes(ax, -1)
    return (get_top_k(student, k, feat_axis) ==
            get_top_k(teacher, 1, feat_axis)).sum(feat_axis)
  else:
    # Raise (not `assert`) so the check survives `python -O`.
    raise ValueError(f'Unknown kind of distance {kind}.')
68
+
69
+
70
@lru_cache(None)
def get_dist_fn(**kw):
  # Cached so identical distance configs yield the *same* function object;
  # `get_eval_fn` below is lru_cached on the tuple of these, so object
  # identity lets multiple evaluator instances share one compiled eval fn.
  return partial(dist, **kw)
73
+
74
+
75
# To avoid re-compiling the function for every new instance of the same
# evaluator on a different dataset!
@lru_cache(None)
def get_eval_fn(student_teacher_fwd, what, mesh, distances):
  """Produces the eval function, jit-compiled with replicated outputs.

  Args:
    student_teacher_fwd: forward fn returning ((aux_s, out_s), (aux_t, out_t)).
    what: pair of output-tree keys to compare (student key, teacher key).
    mesh: device mesh; outputs use the fully-replicated sharding P().
    distances: tuple of distance callables (must be hashable for lru_cache).

  Returns:
    A jitted fn (train_state, batch, mask) -> (list of per-example distance
    arrays, mask).
  """
  @partial(jax.jit, out_shardings=NamedSharding(mesh, P()))
  def _eval_fn(train_state, batch, mask):
    (_, out_s), (_, out_t) = student_teacher_fwd(train_state, batch)
    repr_s = u.tree_get(out_s, what[0])
    repr_t = u.tree_get(out_t, what[1])

    # Let's flatten any non-vectors (eg feature-maps).
    repr_s = einops.rearrange(repr_s, 'b ... -> b (...)')
    repr_t = einops.rearrange(repr_t, 'b ... -> b (...)')

    all_ds = []
    # NOTE: we're gathering and returning all ; if this becomes too slow, we
    # can change to compute and return summary stats later on.
    for dist_fn in distances:
      ds = dist_fn(repr_s, repr_t)
      all_ds.append(ds)
    all_masks = mask
    return all_ds, all_masks

  return _eval_fn
100
+
101
+
102
class Evaluator:
  """Distillation distance evaluator.

  Runs batches through `student_teacher_fwd` and reports, for each
  configured distance, the full per-example array plus avg/min/max.
  """

  def __init__(
      self,
      student_teacher_fwd,
      data,
      pp_fn,
      distances,
      what=('logits', 'logits'),
      *,
      devices,
      **data_kw,
  ):
    # `what` picks which output-tree entries of (student, teacher) to
    # compare; `distances` is a sequence of dist() keyword-configs.
    data = ds_core.get(**data)
    pp_fn = pp_builder.get_preprocess_fn(pp_fn)
    prefetch = data_kw.pop('prefetch', 1)
    self.ds, self.steps = input_pipeline.make_for_inference(
        data.get_tfdata(ordered=True),
        pp_fn,
        num_ex_per_process=data.num_examples_per_process(),
        **data_kw,
    )
    self.data_iter = input_pipeline.start_global(self.ds, devices, prefetch)
    # Hashable distance fns (required by get_eval_fn's lru_cache) plus
    # stable human-readable names derived from each config.
    dist_fns = tuple(get_dist_fn(**dist) for dist in distances)
    self.dist_names = [
        '_'.join(f'{k}={v}' for k, v in dist.items()) for dist in distances
    ]
    mesh = jax.sharding.Mesh(devices, ('data',))
    self.eval_fn = get_eval_fn(student_teacher_fwd, what, mesh, dist_fns)

  def run(self, train_state):
    """Computes all metrics; yields (metric_name, value) pairs."""
    all_ds = [[] for _ in self.dist_names]
    for _, batch in zip(range(self.steps), self.data_iter):
      mask = batch.pop('_mask')
      batch_ds, batch_ms = self.eval_fn(train_state, batch, mask)
      # All results are a replicated array shaped as follows:
      # (local_devices, per_device_batch_size, elem_shape...)
      # with each local device's entry being identical.
      # So let's just take the first one to the host as numpy.
      batch_ms = np.array(batch_ms)
      for i, val in enumerate(batch_ds):
        # Keep only real examples; mask==0 marks padding added for batching.
        all_ds[i].append(np.array(val)[batch_ms == 1])
    for name, ds in zip(self.dist_names, all_ds):
      ds = np.concatenate(ds)
      yield f'{name}/all', ds
      yield f'{name}/avg', np.mean(ds)
      yield f'{name}/min', np.min(ds)
      yield f'{name}/max', np.max(ds)
Tipsomaly/model/big_vision/evaluators/proj/givt/coco_panoptic.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """COCO17 panoptic evaluation.
16
+
17
+ jax.jit-compatible fork of the evaluator from evaluators/proj/uvim.
18
+ """
19
+ import functools
20
+ import itertools
21
+ import json
22
+ import os
23
+ import tempfile
24
+ import time
25
+ from typing import Any
26
+ import zipfile
27
+
28
+ from absl import flags
29
+ from absl import logging
30
+ from big_vision import input_pipeline
31
+ from big_vision import utils
32
+ from big_vision.datasets import core as ds_core
33
+ import big_vision.pp.builder as pp_builder
34
+ import jax
35
+ import jax.numpy as jnp
36
+ import numpy as np
37
+ from pycocotools.panopticapi import evaluation
38
+ import panopticapi_converters.twochannels2panoptic_coco_format as converter
39
+ import tensorflow as tf
40
+ import tensorflow_datasets as tfds
41
+
42
+ from tensorflow.io import gfile
43
+
44
# Temporary global flag to facilitate backwards compatibility.
45
+ API = 'jit'
46
+
47
+ ROOT = os.environ.get('COCO_DATA_DIR', '.')
48
+
49
+ PANOPTIC_COCO_CATS_FILE = f'{ROOT}/panoptic_coco_categories.json'
50
+ PANOPTIC_2017 = {
51
+ 'train': f'{ROOT}/panoptic_train2017.json',
52
+ 'validation': f'{ROOT}/panoptic_val2017.json',
53
+ }
54
+
55
+ PANOPTIC_GT_ZIP = {
56
+ 'train': f'{ROOT}/panoptic_train2017.zip',
57
+ 'validation': f'{ROOT}/panoptic_val2017.zip',
58
+ }
59
+
60
+
61
+ # Note: global to avoid jax re-compiling across different evaluator instances.
62
+ @functools.cache
63
+ def _get_predict_fn(predict_fn, mesh=None):
64
+ """Wrapper for jit-compiled predict function."""
65
+
66
+ # `out_shardings` annotation is needed because of the `all_gather` ops in the
67
+ # pmap implementation.
68
+ @functools.partial(jax.jit,
69
+ out_shardings=jax.sharding.NamedSharding(
70
+ mesh, jax.sharding.PartitionSpec()))
71
+ def _run_predict_fn(train_state, batch):
72
+ """Run predict_fn and gather all outputs on all devices."""
73
+ y = predict_fn(train_state, batch)
74
+ res = {
75
+ 'image/id': batch['image/id'],
76
+ 'mask': batch['_mask'],
77
+ 'y': jnp.stack([y['semantics'], y['instances']], axis=-1),
78
+ }
79
+ return res
80
+ return _run_predict_fn
81
+
82
+
83
class Evaluator:
  """Panoptic segmentation evaluator: calls official COCO API."""

  def __init__(
      self,
      predict_fn,
      pp_fn,
      batch_size,
      data=None,
      cache_final=True,
      cache_raw=False,
      prefetch=1,
      save_dir=None,
      *,
      devices,
  ):
    """Panoptic segmentation evaluator: calls official COCO API.

    Args:
      predict_fn: jit-compilable function, which accepts arbitrary dictionaries
        of parameters and data, where the data dictionary is produced by the
        `pp_fn`. It is expected to output a 2-channel mask, where the first
        channel encodes semantics, and the second channel encodes instance ids.
      pp_fn: Preprocessing function, specified as string.
      batch_size: Batch size.
      data: Dict specifying name and split of the data set. Defaults to the
        standard COCO (2017).
      cache_final: Whether to cache the data after preprocessing - see
        input_pipeline for details.
      cache_raw: Whether to cache the raw data - see input_pipeline for
        details.
      prefetch: Number of batches to prefetch.
      save_dir: Directory to save the results in.
      devices: List of jax devices.
    """
    self.predict_fn = _get_predict_fn(
        predict_fn, jax.sharding.Mesh(devices, ('devices',)))

    # Default to COCO-2017 panoptic validation; `data` may override any field.
    data_specs = dict(name='coco/2017_panoptic',
                      data_dir=None, split='validation')
    data_specs.update(data or {})
    data = ds_core.get(**data_specs)
    self.dataset, self.steps = input_pipeline.make_for_inference(
        data.get_tfdata(ordered=True), batch_size=batch_size,
        num_ex_per_process=data.num_examples_per_process(),
        preprocess_fn=pp_builder.get_preprocess_fn(pp_fn),
        cache_final=cache_final, cache_raw=cache_raw)
    self.data_iter = input_pipeline.start_global(
        self.dataset, devices, prefetch)

    # Only process 0 runs conversion to png and calls into coco api.
    if jax.process_index() == 0:
      self.result_dir = tempfile.TemporaryDirectory()
      (self.gt_folder, self.gt_json, self.categories_json,
       self.remap, self.size_map) = _prepare_ground_truth(
           data_specs['name'], data_specs['split'],
           data_specs.get('data_dir'))
      if save_dir:
        self.save_dir = save_dir.format(workdir=flags.FLAGS.workdir)
        gfile.makedirs(self.save_dir)
      else:
        self.save_dir = None

  def _compute_png_predictions(
      self, train_state: Any) -> Any:
    """Computes predictions and converts them to png to optimize memory use."""
    count = 0
    logging.info('Panoptic eval: running inference.')
    for batch in itertools.islice(self.data_iter, self.steps):
      out = self.predict_fn(train_state, batch)

      # Outputs are replicated on all devices; only process 0 writes pngs.
      if jax.process_index():
        continue

      out = jax.device_get(out)
      mask = out['mask']
      # Drop padding examples via boolean indexing with the batch mask.
      pan_recs = out['y'][mask]
      ids = out['image/id'][mask]

      for pan_rec, image_id in zip(pan_recs, ids):
        sem = pan_rec[..., 0]
        ins = pan_rec[..., 1]

        # Remap model class ids to official COCO category ids.
        sem_remapped = np.array(sem)
        for v in np.unique(sem):
          sem_remapped[sem == v] = self.remap[v]
        sem = sem_remapped

        # 2-channel (semantics, instances) png format expected by the
        # converter; the third png channel is unused (zeros).
        pan_mask = np.stack([sem, ins, np.zeros_like(sem)], axis=-1)
        pan_mask = utils.put_cpu(pan_mask)
        # Resize back to the original image resolution for scoring.
        pan_mask = _resize_nearest(pan_mask, self.size_map[image_id])
        pan_mask_png = tf.io.encode_png(pan_mask.astype('uint8')).numpy()

        fname = f'{self.result_dir.name}/{image_id:012d}.png'
        with open(fname, 'wb') as f:
          f.write(pan_mask_png)
        count += 1

      logging.log_every_n_seconds(
          logging.INFO, 'Panoptic eval: processed %i examples so far.', 30,
          count)

    if jax.process_index():
      return None

    logging.info('Panoptic eval: inference done. Processed %d examples.', count)
    return self.result_dir

  def run(self, train_state):
    """Run panoptic segmentation evaluation.

    Args:
      train_state: pytree containing the model parameters.

    Yields:
      Tuples consisting of metric name and value.
    """
    # Note result_dir is constant, but files inside are mutated.
    result_dir = self._compute_png_predictions(train_state)

    # Non-zero processes only contributed inference work above.
    if jax.process_index():
      return

    if self.save_dir:
      gfile.RecursivelyCopyDir(result_dir.name, self.save_dir, overwrite=True)

    with tempfile.TemporaryDirectory() as pred_folder, \
        tempfile.NamedTemporaryFile(mode='w') as pred_json:

      # Convert 2-channel pngs to the official COCO panoptic format.
      logging.info('Panoptic eval: running conversion.')
      converter.converter(
          source_folder=result_dir.name,
          images_json_file=self.gt_json,
          categories_json_file=self.categories_json,
          segmentations_folder=pred_folder,
          predictions_json_file=pred_json.name)
      logging.info('Panoptic eval: conversion done.')

      logging.info('Panoptic eval: running metrics computation.')
      res = evaluation.pq_compute(gt_json_file=self.gt_json,
                                  gt_folder=self.gt_folder,
                                  pred_json_file=pred_json.name,
                                  pred_folder=pred_folder)
      logging.info('Panoptic eval: metrics computation done.')

      # Panoptic/recognition/segmentation quality, overall and per subset.
      for k in ['All', 'Stuff', 'Things']:
        for m in ['pq', 'rq', 'sq']:
          yield f'{k}_{m}', res[k][m]
230
+
231
+
232
def _prepare_ground_truth(dataset, split, data_dir):
  """Dispatches ground-truth preparation to the official-zip fast path."""
  use_official_zips = dataset == 'coco/2017_panoptic' and data_dir is None
  if use_official_zips:
    return _prepare_ground_truth_from_zipfiles(split)
  return _prepare_ground_truth_from_dataset(dataset, split, data_dir)
237
+
238
+
239
@functools.lru_cache(maxsize=None)
def _prepare_ground_truth_from_dataset(dataset, split, data_dir):
  """Prepare ground truth from a tf.data.Dataset.

  Args:
    dataset: TFDS-compatible dataset specification.
    split: Data set split to use.
    data_dir: Folder containing the data.

  Returns:
    A tuple containing the folder containing the ground-truth data, the
    ground truth annotations loaded from json, the categories loaded from
    json, a map for remapping, and a map mapping image id to image size.
  """
  tfds_dataset = tfds.builder(
      dataset, data_dir=data_dir).as_dataset(split=split)

  categories_json = _make_local_copy(PANOPTIC_COCO_CATS_FILE)
  with gfile.GFile(categories_json, 'rb') as f:
    categories = json.loads(f.read())

  # Build map from tfds class ids to COCO class ids; id 0 maps to itself.
  # Note: `categories` is already loaded above — the original code opened the
  # json file a second time here without ever reading from it.
  remap = {0: 0}
  remap.update({(i + 1): x['id'] for i, x in enumerate(categories)})

  gt_folder = tempfile.mkdtemp()
  gfile.makedirs(gt_folder)
  size_map = {}
  annotations = []
  images = []
  for example in tfds_dataset:
    image_id = int(example['image/id'])
    panoptic_image = example['panoptic_image']
    ann_ids = example['panoptic_objects']['id']
    ann_labels = example['panoptic_objects']['label']
    ann_iscrowd = example['panoptic_objects']['is_crowd']
    ann_area = example['panoptic_objects']['area']

    # Write ground-truth panoptic png under the COCO file naming scheme.
    fname = f'{image_id:012d}.png'
    with gfile.GFile(os.path.join(gt_folder, fname), 'wb') as f:
      f.write(tf.io.encode_png(panoptic_image).numpy())

    size_map[image_id] = (panoptic_image.shape[0], panoptic_image.shape[1])

    segments_info = []
    for i in range(len(ann_ids)):
      segments_info.append({
          'id': int(ann_ids[i]),
          'category_id': remap[int(ann_labels[i] + 1)],
          'iscrowd': int(ann_iscrowd[i]),
          'area': int(ann_area[i]),
      })

    annotations.append({
        'file_name': str(fname),
        'image_id': int(image_id),
        'segments_info': segments_info
    })
    images.append({
        'id': image_id,
        'file_name': f'{image_id:012d}.jpg',
    })

  # Write annotations.json needed for pq_compute.
  gt_json = os.path.join(gt_folder, 'annotations.json')
  with gfile.GFile(gt_json, 'wb') as f:
    f.write(json.dumps({
        'images': images,
        'annotations': annotations,
        'categories': categories,
    }))

  return gt_folder, gt_json, categories_json, remap, size_map
314
+
315
+
316
def _prepare_ground_truth_from_zipfiles(split):
  """Prepare ground truth from coco zip files.

  Args:
    split: dataset split to prepare ground truth for.

  Returns:
    A tuple containing the folder containing the ground-truth data, the ground
    truth annotations loaded from json, the categories loaded from json, a map
    for remapping, and a map mapping image id to image size.
  """
  # Strip any slice spec, e.g. 'validation[:100]' -> 'validation'.
  split_prefix = split.split('[')[0]
  if split_prefix not in ('train', 'validation'):
    raise ValueError(f'Split {split} not supported')

  # The following 4 calls are cached. This allows to save significant time
  # in use cases like sweeping predict_fn hparams on the same run.
  gt_json = _make_local_copy(PANOPTIC_2017[split_prefix])
  gt_folder = _make_local_unzip_copy(PANOPTIC_GT_ZIP[split_prefix])
  categories_json = _make_local_copy(PANOPTIC_COCO_CATS_FILE)
  image_ids = _list_image_ids('coco/2017_panoptic', split)

  gt_folder = os.path.join(
      gt_folder, 'panoptic_val2017'
      if split_prefix == 'validation' else 'panoptic_train2017')

  # Build map from tfds class ids to COCO class ids.
  remap = {0: 0}
  with gfile.GFile(categories_json, 'r') as f:
    remap = {**remap, **{(i + 1): x['id'] for i, x in enumerate(json.load(f))}}

  # Filters gt_json to contain only annotations for images in dataset.
  with gfile.GFile(gt_json) as f:
    data = json.load(f)
  logging.info(
      'Panoptic eval: pre-filter %d annotations.',
      len(data['annotations'])
  )
  data['images'] = [x for x in data['images'] if x['id'] in image_ids]
  data['annotations'] = [
      x for x in data['annotations'] if x['image_id'] in image_ids
  ]
  logging.info(
      'Panoptic eval: post-filter %d annotations.',
      len(data['annotations'])
  )
  # Use mkstemp + close instead of NamedTemporaryFile(delete=False): the
  # original kept the temp-file handle open forever, leaking a descriptor.
  fd, filtered_gt_json = tempfile.mkstemp(suffix='.json')
  os.close(fd)
  with open(filtered_gt_json, 'w') as f:
    json.dump(data, f)

  # Precompute images sizes.
  size_map = {x['id']: (x['height'], x['width']) for x in data['images']}

  return gt_folder, filtered_gt_json, categories_json, remap, size_map
370
+
371
+
372
@functools.lru_cache(maxsize=None)
def _list_image_ids(dataset, split):
  """Returns the frozenset of all image ids present in `dataset`/`split`."""
  ids_ds = tfds.load(dataset, split=split)
  ids_only = ids_ds.map(lambda ex: ex['image/id'])
  return frozenset(ids_only.as_numpy_iterator())
376
+
377
+
378
@functools.lru_cache(maxsize=None)
def _make_local_copy(fname) -> str:
  """Copies (possibly remote) `fname` to a local temp file; returns its path."""
  start = time.monotonic()
  local_file = tempfile.NamedTemporaryFile(delete=False)
  # Close our handle right away: gfile.copy works on the path, and keeping the
  # file object open leaked one file descriptor per cached call.
  local_file.close()
  gfile.copy(fname, local_file.name, overwrite=True)
  logging.info('Copy %s in %d seconds.', fname, time.monotonic() - start)
  return local_file.name
385
+
386
+
387
@functools.lru_cache(maxsize=None)
def _make_local_unzip_copy(fname) -> str:
  """Downloads zip `fname` and extracts it into a fresh local folder."""
  start = time.monotonic()
  folder = tempfile.mkdtemp()
  # Stage the (possibly remote) zip locally, then unpack it.
  with tempfile.NamedTemporaryFile() as staged_zip:
    gfile.copy(fname, staged_zip.name, overwrite=True)
    with zipfile.ZipFile(staged_zip.name, 'r') as archive:
      archive.extractall(folder)
  logging.info('Copy %s in %d seconds.', fname, time.monotonic() - start)
  return folder
397
+
398
+
399
@utils.jit_cpu(static_argnums=(1,))
def _resize_nearest(image, shape):
  # Nearest-neighbor resize of an (h, w, c) mask to `shape`, preserving the
  # channel dimension; jitted on CPU so large masks stay off the accelerator.
  return jax.image.resize(image, shape + image.shape[-1:], 'nearest')
Tipsomaly/model/big_vision/evaluators/proj/givt/nyu_depth.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Evaluation for NYU depth.
16
+
17
+ jax.jit-compatible fork of the evaluator from evaluators/proj/uvim.
18
+
19
+ At evaluation time the ground truth is cropped and clipped. Values outside of
20
+ the test crop or clipping range are not included in eval calculations.
21
+
22
In this evaluator, it is assumed that the ground truth is already cropped, so the
23
+ entire image is evaluated. However, the evaluator does perform the clipping.
24
+
25
+ Reference implementations:
26
+ https://github.com/zhyever/Monocular-Depth-Estimation-Toolbox/blo(internal link)a0f341244260ff61541191a613dd74bc/depth/datasets/nyu.py
27
+ https://github.com/vinvino02/GLPDepth/blob/7f3c78df4ecd6e7c79fd0c4b73c95d61f4aa2121/code/utils/metrics.py
28
+ https://github.com/shariqfarooq123/AdaBins/blob/2fb686a66a304f0a719bc53d77412460af97fd61/evaluate.py
29
+ """
30
+
31
+ import functools
32
+ import itertools
33
+
34
+ from big_vision import input_pipeline
35
+ from big_vision import utils
36
+ from big_vision.datasets import core as ds_core
37
+ import big_vision.pp.builder as pp_builder
38
+ import jax
39
+ import jax.numpy as jnp
40
+ import numpy as np
41
+
42
# Temporary global flag to facilitate backwards compatibility.
43
+ API = "jit"
44
+
45
+
46
+ # Note: global to avoid jax re-compiling across different evaluator instances.
47
+ @functools.cache
48
+ def _get_predict_fn(predict_fn, mesh=None):
49
+ """Wrapper for jit-compiled predict function."""
50
+
51
+ # `out_shardings` annotation is needed because of the `all_gather` ops in the
52
+ # pmap implementation.
53
+ @functools.partial(jax.jit,
54
+ out_shardings=jax.sharding.NamedSharding(
55
+ mesh, jax.sharding.PartitionSpec()))
56
+ def _run_predict_fn(train_state, batch):
57
+ """Run predict_fn and gather all outputs on all devices."""
58
+ pred = predict_fn(train_state, batch)
59
+ return {"mask": batch["_mask"],
60
+ "gt": jnp.squeeze(batch["ground_truth"], axis=-1),
61
+ "y": pred["depth"]}
62
+ return _run_predict_fn
63
+
64
+
65
class Evaluator:
  """Evaluator for NYU depth."""

  def __init__(self,
               predict_fn,
               pp_fn,
               batch_size,
               data,
               cache_final=True,
               cache_raw=False,
               prefetch=1,
               min_depth=1e-3,
               max_depth=10,
               *,
               devices):
    """Evaluator for NYU depth.

    Args:
      predict_fn: jit-compilable function, accepts arbitrary dictionaries of
        parameters and data, where the data dictionary is produced by the
        `pp_fn` op. It is expected to output a dict with `depth` containing an
        2D array with the predicted depth. The prediction is resized to the
        ground_truth size with nearest neighbour.
      pp_fn: Preprocessing function, specified as string. `pp_fn` must also
        output a 'ground_truth' as a 2D array of ground truth. Further, it has
        to apply a crop, if one wants to compute metrics with the eval crop
        typically used for NYU Depth metrics.
      batch_size: Batch size.
      data: Dict specifying name and split of the data set.
      cache_final: Whether to cache the data after preprocessing - see
        input_pipeline for details.
      cache_raw: Whether to cache the raw data - see input_pipeline for
        details.
      prefetch: Number of batches to prefetch.
      min_depth: Minimum depth value (pixels outside the range are ignored).
      max_depth: Maximum depth value (pixels outside the range are ignored).
      devices: List of jax devices.
    """
    self.min_depth = min_depth
    self.max_depth = max_depth
    self.predict_fn = _get_predict_fn(
        predict_fn, jax.sharding.Mesh(devices, ("devices",)))

    data = ds_core.get(**data)
    self.dataset, self.steps = input_pipeline.make_for_inference(
        data.get_tfdata(ordered=True), batch_size=batch_size,
        num_ex_per_process=data.num_examples_per_process(),
        preprocess_fn=pp_builder.get_preprocess_fn(pp_fn),
        cache_final=cache_final, cache_raw=cache_raw)
    self.data_iter = input_pipeline.start_global(
        self.dataset, devices, prefetch)

  def run(self, train_state):
    """Run NYU depth eval.

    Args:
      train_state: pytree containing the model parameters.

    Yields:
      Tuples consisting of metric name and value.
    """
    # Per-image metric accumulators.
    rmses = []
    abs_res = []
    abs_logs = []
    d1s = []
    d2s = []
    d3s = []
    for batch in itertools.islice(self.data_iter, self.steps):
      # Outputs is a dict with values shaped (gather/same, devices, batch, ...)
      out = self.predict_fn(train_state, batch)

      if jax.process_index():  # Host0 gets all preds and does eval.
        continue

      out = jax.device_get(out)
      # Then the bool-indexing with mask resulting in flat (global_batch, ...)
      out = jax.tree_map(lambda x: x[out["mask"]], out)  # pylint:disable=cell-var-from-loop

      for gt, pred in zip(out["gt"], out["y"]):
        # put_cpu and conversion to numpy arrays below to avoid unwanted
        # host-to-device transfers
        pred, gt = utils.put_cpu((pred, gt))
        # Upsample the prediction to the ground-truth resolution for scoring.
        pred = _resize_nearest(pred, (gt.shape[0], gt.shape[1]))
        pred, gt = np.array(pred), np.array(gt)
        # Only pixels whose ground truth lies strictly inside the clipping
        # range contribute to the metrics.
        valid_mask = np.logical_and(gt > self.min_depth, gt < self.max_depth)

        rmses.append(_compute_rmse(gt[valid_mask], pred[valid_mask]))
        abs_res.append(_compute_abs_re(gt[valid_mask], pred[valid_mask]))
        abs_logs.append(_compute_abs_log(gt[valid_mask], pred[valid_mask]))
        d1s.append(_compute_delta(gt[valid_mask], pred[valid_mask], order=1))
        d2s.append(_compute_delta(gt[valid_mask], pred[valid_mask], order=2))
        d3s.append(_compute_delta(gt[valid_mask], pred[valid_mask], order=3))

    if jax.process_index():  # Host0 gets all preds and does eval.
      return

    # Metrics are averaged per-image, not over the pooled pixel set.
    yield "RMSE", np.mean(rmses)
    yield "abs_RE", np.mean(abs_res)
    yield "log10", np.mean(abs_logs)
    yield "delta1", np.mean(d1s)
    yield "delta2", np.mean(d2s)
    yield "delta3", np.mean(d3s)
167
+
168
+
169
@utils.jit_cpu(static_argnums=(1,))
def _resize_nearest(image, shape):
  # Nearest-neighbor resize of a 2D depth map to `shape`; jitted on CPU so
  # ground-truth-sized arrays never hit the accelerator.
  return jax.image.resize(image, shape, "nearest")
172
+
173
+
174
+ def _compute_rmse(gt, pred):
175
+ diff = gt - pred
176
+ return np.sqrt(np.mean(np.power(diff, 2)))
177
+
178
+
179
+ def _compute_abs_re(gt, pred):
180
+ diff = np.abs(gt - pred)
181
+ return np.mean(diff / gt)
182
+
183
+
184
+ def _compute_abs_log(gt, pred):
185
+ diff = np.abs(np.log10(gt) - np.log10(pred))
186
+ return np.mean(diff)
187
+
188
+
189
+ def _compute_delta(gt, pred, order):
190
+ rel_diff = np.maximum(gt / pred, pred / gt)
191
+ return np.sum(rel_diff < 1.25**order) / rel_diff.size
Tipsomaly/model/big_vision/evaluators/proj/givt/save_predictions.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Evaluator to save predictions."""
16
+ # pylint: disable=consider-using-from-import
17
+ import functools
18
+ import io # pylint: disable=unused-import
19
+ import itertools
20
+ import os
21
+
22
+ from absl import flags
23
+ from absl import logging
24
+ from big_vision import input_pipeline
25
+ from big_vision.datasets import core as ds_core
26
+ import big_vision.pp.builder as pp_builder
27
+ import big_vision.utils as u
28
+ import jax
29
+ import numpy as np
30
+
31
+ from tensorflow.io import gfile # pylint: disable=unused-import
32
+
33
+ # Temporary global flag to facilitate backwards compatability.
34
+ API = 'jit'
35
+
36
+
37
+ # Note: global to avoid jax re-compiling across different evaluator instances.
38
+ @functools.cache
39
+ def _get_predict_fn(predict_fn, mesh=None):
40
+ """Wrapper for jit-compiled predict function."""
41
+
42
+ # `out_shardings` annotation is needed because of the `all_gather` ops in the
43
+ # pmap implementation.
44
+ @functools.partial(jax.jit,
45
+ out_shardings=jax.sharding.NamedSharding(
46
+ mesh, jax.sharding.PartitionSpec()))
47
+ def _run_predict_fn(train_state, batch):
48
+ """Run predict_fn and gather all outputs on all devices."""
49
+ y = predict_fn(train_state, batch)
50
+ return {'inputs': batch, 'outputs': y, 'mask': batch['_mask']}
51
+ return _run_predict_fn
52
+
53
+
54
class Evaluator:
  """Save predictions in "{FLAGS.workdir}/{outfile}".

  Results can then be easily inspected in a notebook such as:

  ```
  results = utils.load_checkpoint("<full_path_to_outfile>")
  inputs, outputs = (results["inputs"], results["outputs"])
  ```
  """

  def __init__(self, predict_fn, pp_fn, batch_size, data, outfile,
               cache_final=True, cache_raw=False, prefetch=1, *, devices):
    # jit-compiled wrapper that replicates predictions across all devices.
    self.predict_fn = _get_predict_fn(
        predict_fn, jax.sharding.Mesh(devices, ('devices',)))

    # Prepare data for each process and pad with zeros so all processes have the
    # same number of batches.
    data = ds_core.get(**data)
    self.dataset, self.steps = input_pipeline.make_for_inference(
        data.get_tfdata(ordered=True), batch_size=batch_size,
        num_ex_per_process=data.num_examples_per_process(),
        preprocess_fn=pp_builder.get_preprocess_fn(pp_fn),
        cache_final=cache_final, cache_raw=cache_raw)
    self.data_iter = input_pipeline.start_global(
        self.dataset, devices, prefetch)

    # Output path under the run's working directory.
    self.path = os.path.join(flags.FLAGS.workdir, outfile)

  def run(self, train_state):
    """Compute all predictions, gather in main host and save in outfile."""
    count = 0
    outputs = []
    for batch in itertools.islice(self.data_iter, self.steps):
      out = self.predict_fn(train_state, batch)
      # Only process 0 accumulates and writes; others just run inference.
      if jax.process_index():
        continue

      out = jax.device_get(out)
      # Note that we need to access `out['mask']` here `x` does not have that
      # field during the tree map.
      out = jax.tree_map(lambda x: x[out['mask']], out)  # pylint: disable=cell-var-from-loop
      count += out['mask'].shape[0]
      out.pop('mask')
      outputs.append(out)

      logging.log_every_n_seconds(
          logging.INFO, 'Save predictions: processed %i examples so far.', 30,
          count)

    if jax.process_index():
      return

    logging.info('Save predictions: processed %d examples.', count)

    # Actually save in filesystem.
    outputs = jax.tree_map(lambda *x: np.concatenate(x, axis=0), *outputs)
    names_and_vals, _ = u.tree_flatten_with_names(outputs)
    io_buffer = io.BytesIO()
    np.savez_compressed(io_buffer, **{k: v for k, v in names_and_vals})
    with gfile.GFile(self.path, 'wb') as f:
      f.write(io_buffer.getvalue())
    return

    # The unreachable `yield` below deliberately turns `run` into a generator,
    # matching the (name, value)-yielding interface of the other evaluators.
    yield None  # pylint: disable=unreachable
Tipsomaly/model/big_vision/evaluators/proj/image_text/contrastive.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Evaluator for the contrastive task.
16
+
17
+ DON'T COMPARE ACROSS RUNS, use for training health monitoring only.
18
+
19
+ Note that this evaluator's `ncorrect_minibatch` is only a rough proxy for
20
+ training progress and does not report the actual `ncorrect`: when the same
21
+ labels found multiple times in a batch, then the reported value is biased
22
+ towards lower values.
23
+
24
+ Also note that the `ncorrect_minibatch` is a function of batch size (it's a lot
25
+ easier to find correct values in small batches).
26
+ """
27
+ import functools
28
+
29
+ from big_vision import input_pipeline
30
+ import big_vision.datasets.core as ds_core
31
+ import big_vision.pp.builder as pp_builder
32
+ import big_vision.utils as u
33
+ import jax
34
+ import jax.numpy as jnp
35
+ import numpy as np
36
+
37
+
38
def _all_gather(z):
  """All gather and flatten first two dims."""
  # Inside pmap: all_gather yields (num_devices, per_device_batch, ...);
  # concatenating along axis 0 flattens that back into one global batch.
  gather_flat = lambda x: jnp.concatenate(jax.lax.all_gather(x, "batch"), 0)
  return jax.tree_map(gather_flat, z)
42
+
43
+
44
+ # To avoid re-compiling the function for every new instance of the same
45
+ # evaluator on a different dataset!
46
@functools.lru_cache(None)
def get_eval_fn(predict_fn, use_global_batch):
  """Produces eval function, also applies pmap."""

  @functools.partial(jax.pmap, axis_name="batch")
  def _eval_fn(params, images, labels, mask):
    zimg, ztxt, extras = predict_fn(params, images, labels)

    if use_global_batch:
      # Score against the full cross-device batch rather than only the local
      # shard (a harder retrieval task; see the module docstring caveats).
      zimg, ztxt, mask = _all_gather((zimg, ztxt, mask))

    # Temperature won't affect ranking for accuracy, but impacts loss magnitude.
    losses, measurements = u.bidirectional_contrastive_loss(
        zimg, ztxt, extras["t"], mask, reduction=False)
    # Zero out padded examples via `mask` before the cross-device sums.
    l = jax.lax.psum(losses * mask, axis_name="batch")
    c = jax.lax.psum(measurements["ncorrect"] * mask, axis_name="batch")
    n = jax.lax.psum(mask, axis_name="batch")
    return c, l, n

  return _eval_fn
+
67
+
68
class Evaluator:
  """Contrastive evaluator."""

  def __init__(self, predict_fn, data, pp_fn, batch_size,
               use_global_batch, cache_final=True,
               cache_raw=False, prefetch=1, label_key="labels"):
    # `data` is a dataset spec dict (passed to ds_core.get), `pp_fn` a
    # preprocessing string, `label_key` the name of the text/label field.
    data = ds_core.get(**data)
    pp_fn = pp_builder.get_preprocess_fn(pp_fn)
    self.ds, self.steps = input_pipeline.make_for_inference(
        data.get_tfdata(ordered=True), pp_fn, batch_size,
        num_ex_per_process=data.num_examples_per_process(),
        cache_final=cache_final, cache_raw=cache_raw)
    self.data_iter = input_pipeline.start_input_pipeline(self.ds, prefetch)
    self.eval_fn = get_eval_fn(predict_fn, use_global_batch)
    self.label_key = label_key

  def run(self, params):
    """Computes all metrics."""
    # Running sums of loss, ncorrect, and number of real examples seen.
    l, c, nseen = 0, 0, 0
    for _, batch in zip(range(self.steps), self.data_iter):
      labels, mask = batch.pop(self.label_key), batch.pop("_mask")
      batch_ncorrect, batch_losses, batch_n = self.eval_fn(
          params, batch["image"], labels, mask)
      # All results are a replicated array shaped as follows:
      # (local_devices, per_device_batch_size, elem_shape...)
      # with each local device's entry being identical as they got psum'd.
      # So let's just take the first one to the host as numpy.
      c += np.sum(np.array(batch_ncorrect[0]))
      l += np.sum(np.array(batch_losses[0]))
      nseen += np.sum(np.array(batch_n[0]))
    yield ("ncorrect_minibatch", c / nseen)
    yield ("loss", l / nseen)
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Discriminative zero-shot classification evaluator.
16
+ """
17
+
18
+ import functools
19
+ import time
20
+
21
+ from absl import logging
22
+ from big_vision import input_pipeline
23
+ from big_vision import utils
24
+ from big_vision.evaluators.proj.image_text import prompt_engineering
25
+ from big_vision.pp import ops_general # pylint: disable=unused-import
26
+ from big_vision.pp import ops_image # pylint: disable=unused-import
27
+ import big_vision.pp.builder as pp_builder
28
+ import jax
29
+ import jax.numpy as jnp
30
+ from jax.sharding import NamedSharding
31
+ from jax.sharding import PartitionSpec as P
32
+ import numpy as np
33
+ import tensorflow as tf
34
+ import tensorflow_datasets as tfds
35
+
36
+
37
+ # Temporary global flag to facilitate backwards compatability. Will be removed
38
+ # by the end of year 2023.
39
+ API = "jit"
40
+
41
+
42
+ DATASET_NAMES = ("imagenet2012", "cifar100", "oxford_iiit_pet")
43
+ DEFAULT_OVERRIDES = (
44
+ ("imagenet2012", (
45
+ ("class_names", "clip"),
46
+ ("split", "validation"),
47
+ )),
48
+ )
49
+
50
+
51
+ def _with_infinite_padding(dataset):
52
+ """Adds "infinite padding" to the dataset."""
53
+ filler_element = tf.nest.map_structure(
54
+ lambda spec: tf.zeros(spec.shape, spec.dtype)[None], dataset.element_spec)
55
+ filler_element["mask"] = [False]
56
+ filler_dataset = tf.data.Dataset.from_tensor_slices(filler_element)
57
+ dataset = dataset.map(
58
+ lambda features: dict(mask=True, **features),
59
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
60
+ return dataset.concatenate(filler_dataset.repeat(None))
61
+
62
+
63
+ # This is needed so retrieval_test can replace dataset info.
64
+ def _get_dataset_info(builder):
65
+ return builder.info
66
+
67
+
68
+ def prepare_datasets(img_dataset,
69
+ class_names,
70
+ *,
71
+ prompt_templates,
72
+ pp_img,
73
+ pp_txt,
74
+ cache_final=False,
75
+ pre_filter_fn=None,
76
+ class_name_offset=0):
77
+ """Returns unbatched `ds_images, ds_texts` datasets."""
78
+
79
+ assert prompt_templates, "Must specify prompt templates (e.g. simply ['{}'])"
80
+
81
+ def expand_aliases(idx, class_name):
82
+ class_names = tf.strings.split(class_name, ",")
83
+ return tf.data.Dataset.from_tensor_slices((
84
+ tf.repeat([idx + class_name_offset], len(class_names), axis=0),
85
+ class_names,
86
+ ))
87
+
88
+ def add_prompts(idx, class_name):
89
+ return tf.data.Dataset.from_tensor_slices({
90
+ "label": tf.repeat([idx], len(prompt_templates), axis=0),
91
+ "class_name": tf.repeat([class_name], len(prompt_templates), axis=0),
92
+ "prompt_template": prompt_templates,
93
+ })
94
+
95
+ def substitute_prompt(features):
96
+ parts = tf.strings.split(features["prompt_template"], "{}")
97
+ tf.debugging.assert_equal(len(parts), 2, features["prompt_template"])
98
+ return {
99
+ "label": features["label"],
100
+ "texts": tf.strings.join([parts[0], features["class_name"], parts[1]])
101
+ }
102
+
103
+ if pre_filter_fn:
104
+ img_dataset = img_dataset.filter(pre_filter_fn)
105
+ ds_images = img_dataset.map(
106
+ pp_builder.get_preprocess_fn(f"{pp_img}|keep('label', 'image')"))
107
+ ds_texts = tf.data.Dataset.from_tensor_slices(list(class_names)).enumerate(
108
+ ).flat_map(expand_aliases).flat_map(add_prompts).map(substitute_prompt).map(
109
+ pp_builder.get_preprocess_fn(f"{pp_txt}|keep('label', 'labels')"))
110
+
111
+ if cache_final:
112
+ ds_images, ds_texts = ds_images.cache(), ds_texts.cache()
113
+
114
+ return ds_images, ds_texts
115
+
116
+
117
+ def _split_and_batch(dataset_name, data_dir, class_names, batch_size, split,
118
+ get_ds):
119
+ """Splits dataset, calls `get_ds` and returns padded + batched datasets."""
120
+ assert not batch_size % jax.device_count(), (
121
+ f"batch_size={batch_size} % jax.device_count()={jax.device_count()}")
122
+ builder = tfds.builder(dataset_name, data_dir=data_dir)
123
+
124
+ # Split class names (last process gets remainder).
125
+ if len(class_names) < jax.process_count():
126
+ # See (internal link) for more details.
127
+ class_names += [""] * (jax.process_count() - len(class_names))
128
+ per_process = len(class_names) // jax.process_count()
129
+ class_name_offset = per_process * jax.process_index()
130
+ if jax.process_index() == jax.process_count() - 1:
131
+ class_names = class_names[class_name_offset:]
132
+ else:
133
+ class_names = class_names[class_name_offset:class_name_offset + per_process]
134
+
135
+ ds_images, ds_texts = get_ds(
136
+ builder.as_dataset(split=tfds.split_for_jax_process(split)),
137
+ class_names,
138
+ class_name_offset=class_name_offset)
139
+ return (
140
+ _with_infinite_padding(ds_images).batch(batch_size),
141
+ _with_infinite_padding(ds_texts).batch(batch_size),
142
+ )
143
+
144
+
145
+ def _average_embeddings(embeddings, *, labels, num_classes, normalize):
146
+ """Computes per-class averages of `embeddings`."""
147
+ assert embeddings.ndim == 2, f"Expected {embeddings.ndim}==2"
148
+ assert labels.ndim == 1, f"Expected {labels.ndim}==1"
149
+ assert len(labels) == len(embeddings), (
150
+ f"Expected {len(labels)}=={len(embeddings)}")
151
+
152
+ byidx = [[] for _ in range(num_classes)]
153
+ for label, embedding in zip(labels, embeddings):
154
+ byidx[label].append(embedding)
155
+ missing = set(range(num_classes)) - set(
156
+ idx for idx, embs in enumerate(byidx) if len(embs))
157
+ assert not missing, f"Classes without embeddings: {missing}"
158
+ embeddings = [np.array(embedding).mean(axis=0) for embedding in byidx]
159
+ embeddings = np.stack(embeddings)
160
+
161
+ assert len(embeddings) == num_classes
162
+ if normalize:
163
+ embeddings /= 1e-8 + np.linalg.norm(embeddings, axis=1, keepdims=True)
164
+ return embeddings
165
+
166
+
167
+ class Evaluator:
168
+ """Zero-shot classification evaluator."""
169
+
170
+ def __init__(self,
171
+ predict_fn,
172
+ *,
173
+ batch_size,
174
+ devices,
175
+ dataset_names=DATASET_NAMES,
176
+ data_dir=None,
177
+ class_names="dataset_info:label",
178
+ split="test",
179
+ prompt_templates="clip_paper",
180
+ canonicalize=True,
181
+ pp_img="resize(224)|value_range(-1,1)",
182
+ pp_txt="tokenize(max_len=16, eos='sticky', "
183
+ "pad_value=1, inkey='texts', outkey='labels')",
184
+ cache_final=False,
185
+ pre_filter_fn=None,
186
+ first_class_name_only=True,
187
+ dataset_overrides=DEFAULT_OVERRIDES,
188
+ async_delay=1):
189
+ """Initializes a new zero-shot classification evaluator.
190
+
191
+ See `prepare_datasets()` for details on how the dataset is pre-processed.
192
+
193
+ Args:
194
+ predict_fn: Prediction function with signature
195
+ `zimg, ztxt, out = predict_fn(params, images, texts)`
196
+ batch_size: Global batch size.
197
+ devices: list of devices.
198
+ dataset_names: Names of TFDS datasets to evaluate on.
199
+ data_dir: Optional argument to `tfds.builder()`.
200
+ class_names: Usually specified as a string that is interpreted by
201
+ `prompt_engineering.get_class_names()` to look up class names.
202
+ Alternatively, this attribute can be a list of class names (using ","
203
+ to separate multiple aliases).
204
+ split: Which dataset split to use for evaluation.
205
+ prompt_templates: Specifies which prompt templates to use. See module
206
+ big_vision.evaluators.proj.image_text.prompte_engineering
207
+ for valid values.
208
+ canonicalize: Whether class names and prompt templates should be
209
+ canonicalized. See `prompt_engineering.py` for details.
210
+ pp_img: Preprocessing string for images. Preprocessed features should
211
+ contain key "image" with value that can be batched and is suitable for
212
+ the `images` argument of `predict_fn` input``.
213
+ pp_txt: Preprocessing string for texts. Can expect "texts" key as an input
214
+ (shape=[], dtype=string), and is expected to produce "labels" key that
215
+ is suitable for the `text` argument of `predict_fn` input.
216
+ cache_final: Wether preprocesse dataset should be cached.
217
+ pre_filter_fn: Predicate applied to the dataset for filtering records.
218
+ first_class_name_only: Whether only the first class name should be
219
+ considered (i.e. not using any aliases).
220
+ dataset_overrides: Mapping `dataset_name` to an optional dictionary that
221
+ can override parameters `dataset_name`, `data_dir`, `pp_img`, `pp_txt`,
222
+ `class_names`, `split`, `pre_filter_fn`, and the extra
223
+ `class_names_dataset_name`.
224
+ Works with tuple/dict of tuples/dicts.
225
+ async_delay: How many steps to wait before checking if all hosts have
226
+ finished their batch. A value > 1 allows for more parallelized
227
+ processing, but will results in more unnecessary steps with padded data.
228
+ """
229
+ t0 = time.monotonic()
230
+ self.datasets = {}
231
+ self.prompt_templates = prompt_engineering.get_prompt_templates(
232
+ prompt_templates, canonicalize=canonicalize)
233
+ self._axis_name = "batch"
234
+ dataset_overrides = {k: dict(v) for k, v in dict(dataset_overrides).items()}
235
+
236
+ for dataset_name in dataset_names:
237
+ overrides = dataset_overrides.pop(dataset_name, {})
238
+ dataset_name_ = overrides.pop("dataset_name", dataset_name)
239
+ data_dir_ = overrides.pop("data_dir", data_dir)
240
+ class_names_dataset_name = overrides.pop("class_names_dataset_name",
241
+ dataset_name_)
242
+ class_names_ = overrides.pop("class_names", class_names)
243
+ class_names_ = prompt_engineering.get_class_names(
244
+ dataset_name=class_names_dataset_name,
245
+ source=class_names_,
246
+ canonicalize=canonicalize)
247
+ pp_img_ = overrides.pop("pp_img", pp_img)
248
+ pp_txt_ = overrides.pop("pp_txt", pp_txt)
249
+ cache_final_ = overrides.pop("cache_final", cache_final)
250
+ split_ = overrides.pop("split", split)
251
+ pre_filter_fn_ = overrides.pop("pre_filter_fn", pre_filter_fn)
252
+ prompt_templates_ = overrides.pop("prompt_templates", prompt_templates)
253
+ canonicalize_ = overrides.pop("canonicalize", canonicalize)
254
+ prompt_templates_ = prompt_engineering.get_prompt_templates(
255
+ prompt_templates_, canonicalize=canonicalize_)
256
+ assert not overrides, f"Unknown overrides {dataset_name}: {overrides}"
257
+
258
+ if first_class_name_only:
259
+ class_names_ = [name.split(",")[0] for name in class_names_]
260
+ ds_images, ds_texts = _split_and_batch(
261
+ dataset_name=dataset_name_,
262
+ data_dir=data_dir_,
263
+ class_names=class_names_,
264
+ batch_size=batch_size,
265
+ split=split_,
266
+ get_ds=functools.partial(
267
+ prepare_datasets,
268
+ pp_img=pp_img_,
269
+ pp_txt=pp_txt_,
270
+ cache_final=cache_final_,
271
+ pre_filter_fn=pre_filter_fn_,
272
+ prompt_templates=prompt_templates_))
273
+ self.datasets[dataset_name] = dict(
274
+ images=ds_images, texts=ds_texts, class_names=class_names_,
275
+ dataset_name=dataset_name_, split=split_)
276
+
277
+ assert not dataset_overrides, f"Extra overrides: {dataset_overrides}"
278
+
279
+ def embed_texts(train_state, texts):
280
+ """Returns text embeddings."""
281
+ _, ztxt, _ = predict_fn(train_state, {"labels": texts})
282
+ return ztxt
283
+
284
+ def count_correct(train_state, return_embeddings, *, mask, labels, image,
285
+ ztxt):
286
+ """Returns count of correct predictions (and optionally embeddings)."""
287
+ zimg, _, _ = predict_fn(train_state, {"image": image})
288
+ best_txt = (zimg @ ztxt.T).argmax(axis=1)
289
+ # labels has format [[1, -1, -1], [5, -1, -1], [7, 2, -1], ...]
290
+ # so here we count "any" correct, such that the counting matches the
291
+ # multilabel scenario described in "are we done with imagenet"
292
+ # (http://arxiv.org/abs/2006.07159) section 3.1
293
+ if labels.ndim == 1:
294
+ labels = labels[..., None]
295
+ assert labels.ndim == 2, labels.shape
296
+ matching = (best_txt[:, None] == labels).sum(axis=1)
297
+ correct = jnp.where(mask, (matching > 0).astype(jnp.int32), 0).sum()
298
+ correct = jnp.sum(correct)
299
+ if return_embeddings:
300
+ return correct, zimg
301
+ else:
302
+ return correct, None
303
+
304
+ self.devices = devices
305
+ self.mesh = jax.sharding.Mesh(devices, ("devices",))
306
+
307
+ self._embed_texts_p = jax.jit(
308
+ embed_texts, out_shardings=NamedSharding(self.mesh, P()))
309
+ self._count_correct_p = jax.jit(count_correct, static_argnums=(1,),
310
+ out_shardings=NamedSharding(self.mesh, P()))
311
+ self._count_p = jax.jit(jnp.sum,
312
+ out_shardings=NamedSharding(self.mesh, P()))
313
+ self._all_gather_p = jax.jit(
314
+ lambda x: x, out_shardings=NamedSharding(self.mesh, P()))
315
+
316
+ self._compiled = set()
317
+ assert async_delay > 0, f"async_delay must be >0, not {async_delay}"
318
+ self._async_delay = async_delay
319
+ logging.info("Initialized evaluator in %.1f seconds", time.monotonic() - t0)
320
+
321
+ def _embed_texts(self, train_state, dataset_name):
322
+ """Returns per-class averaged text embeddings."""
323
+ t0 = time.monotonic()
324
+ logging.info("Starting text embedding...")
325
+ ns = []
326
+ embeddings = []
327
+ data = {"label": [], "mask": []}
328
+
329
+ ds_b = input_pipeline.start_global(
330
+ self.datasets[dataset_name]["texts"], self.devices)
331
+ for batch in ds_b:
332
+ ns.append(jax.device_get(self._count_p(batch["mask"])))
333
+ if len(ns) >= self._async_delay and ns[-self._async_delay] == 0:
334
+ break
335
+
336
+ embeddings.append(jax.device_get(self._embed_texts_p(
337
+ train_state, batch["labels"])))
338
+ for name in data:
339
+ data[name].append(jax.device_get(self._all_gather_p(batch[name])))
340
+
341
+ if self._embed_texts_p not in self._compiled:
342
+ logging.info("Compiled text embeddings in %.1fs", time.monotonic() - t0)
343
+ t0 = time.monotonic()
344
+ self._compiled.add(self._embed_texts_p)
345
+
346
+ ns = np.array(ns)
347
+ n = ns.sum()
348
+ data["embedding"] = embeddings
349
+ data = {k: np.concatenate(v, axis=0) for k, v in data.items()}
350
+ mask = data.pop("mask").astype(bool)
351
+ data = {k: v[mask] for k, v in data.items()}
352
+ data["average_embedding"] = _average_embeddings(
353
+ data["embedding"],
354
+ labels=data["label"],
355
+ num_classes=len(self.datasets[dataset_name]["class_names"]),
356
+ normalize=True)
357
+
358
+ logging.info("Embedded %s text in %d steps - ...%s", dataset_name, len(ns),
359
+ ns[-10:])
360
+ logging.info("Totalling %d text in %.1fs", n, time.monotonic() - t0)
361
+ logging.info("Total texts embeddings size %.1fM",
362
+ data["embedding"].nbytes / 1e6)
363
+ return data
364
+
365
+ def evaluate(self,
366
+ train_state,
367
+ dataset_name,
368
+ *,
369
+ return_embeddings=False):
370
+ """Returns evaluation results."""
371
+ texts = self._embed_texts(train_state, dataset_name)
372
+ ztxt_p = texts["average_embedding"]
373
+ ztxt_p = utils.reshard(ztxt_p, NamedSharding(self.mesh, P()))
374
+
375
+ t0 = time.monotonic()
376
+ logging.info("Starting image embedding...")
377
+
378
+ ns = []
379
+ embeddings = []
380
+ corrects = []
381
+ data = {"mask": [], "label": []} if return_embeddings else {}
382
+
383
+ ds_b = input_pipeline.start_global(
384
+ self.datasets[dataset_name]["images"], self.devices)
385
+ for batch in ds_b:
386
+ ns.append(jax.device_get(self._count_p(batch["mask"])))
387
+ if len(ns) >= self._async_delay and ns[-self._async_delay] == 0:
388
+ break
389
+
390
+ labels = batch["label"]
391
+ correct_p, embs_p = self._count_correct_p(
392
+ train_state,
393
+ return_embeddings,
394
+ mask=batch["mask"],
395
+ labels=labels,
396
+ image=batch["image"],
397
+ ztxt=ztxt_p,
398
+ )
399
+ corrects.append(jax.device_get(correct_p))
400
+ if self._count_correct_p not in self._compiled:
401
+ logging.info("Compiled image embeddings in %.1fs",
402
+ time.monotonic() - t0)
403
+ t0 = time.monotonic()
404
+ self._compiled.add(self._count_correct_p)
405
+
406
+ if return_embeddings:
407
+ embeddings.append(jax.device_get(self._all_gather_p(embs_p)))
408
+ for name in data:
409
+ data[name].append(jax.device_get(self._all_gather_p(batch[name])))
410
+
411
+ ns = np.array(ns)
412
+ n = ns.sum()
413
+ correct = np.array(corrects).sum()
414
+
415
+ logging.info("Embedded %s image in %d steps - ...%s", dataset_name, len(ns),
416
+ ns[-10:])
417
+ logging.info("Totalling %d image in %.1fs", n, time.monotonic() - t0)
418
+ ret = {
419
+ "accuracy": correct / n,
420
+ "correct": correct,
421
+ "count": n,
422
+ }
423
+ logging.info("Dataset %s, results %s", dataset_name, ret)
424
+
425
+ if return_embeddings:
426
+ data["embedding"] = embeddings
427
+ data = {k: np.concatenate(v, axis=0) for k, v in data.items()}
428
+ logging.info("Total images embeddings size %.1fM",
429
+ data["embedding"].nbytes / 1e6)
430
+ mask = data.pop("mask").astype(bool)
431
+ ret["images"] = {k: v[mask] for k, v in data.items()}
432
+ ret["texts"] = texts
433
+
434
+ return ret
435
+
436
+ def run(self, train_state):
437
+ """Returns metrics."""
438
+ return [(f"{dataset_name}_accuracy",
439
+ self.evaluate(train_state, dataset_name)["accuracy"])
440
+ for dataset_name in self.datasets]
Tipsomaly/model/big_vision/evaluators/proj/image_text/discriminative_classifier_test.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for discriminative_classifier."""
16
+
17
+ from unittest import mock
18
+
19
+ from big_vision.evaluators.proj.image_text import discriminative_classifier
20
+ from big_vision.pp import ops_general # pylint: disable=unused-import
21
+ from big_vision.pp import ops_image # pylint: disable=unused-import
22
+ from big_vision.pp.registry import Registry
23
+ import flax.linen as nn
24
+ import jax
25
+ import jax.numpy as jnp
26
+ import numpy as np
27
+ import tensorflow as tf
28
+ import tensorflow_datasets as tfds
29
+
30
+
31
+ @Registry.register("preprocess_ops.test_texts2labels")
32
+ def _get_test_texts2labels():
33
+
34
+ def pp(features):
35
+ features["labels"] = tf.strings.to_number(features["texts"])
36
+ return features
37
+
38
+ return pp
39
+
40
+
41
+ @Registry.register("preprocess_ops.copy_from")
42
+ def _get_copy_from(**key_map):
43
+
44
+ def copy_from(d):
45
+ d = dict(d)
46
+ for k1, k2 in key_map.items():
47
+ d[k1] = d[k2]
48
+ return d
49
+
50
+ return copy_from
51
+
52
+
53
+ class _Model(nn.Module):
54
+
55
+ @nn.compact
56
+ def __call__(self, image, texts):
57
+ self.param("x", lambda _: 0.)
58
+
59
+ def z(x):
60
+ if x is not None:
61
+ # Note that the returned vector is most similar with other vectors
62
+ # generated from the same underlying `x[:]`.
63
+ return jnp.stack([jnp.cos(x / 10.), jnp.sin(x / 10.)]).T
64
+
65
+ if texts is not None:
66
+ texts %= 5 # For testing `pre_filter_fn` below.
67
+ return z(image), z(texts), None
68
+
69
+
70
+ class DiscriminativeClassifierTest(tf.test.TestCase):
71
+
72
+ def test_prepare_datasets(self):
73
+
74
+ def generator():
75
+ yield {
76
+ "image": tf.ones([5, 5, 3], tf.float32),
77
+ "label": 1,
78
+ }
79
+ yield {
80
+ "image": tf.ones([4, 4, 3], tf.float32),
81
+ "label": 2,
82
+ }
83
+
84
+ ds = tf.data.Dataset.from_generator(
85
+ generator,
86
+ output_signature={
87
+ "image": tf.TensorSpec(shape=[None, None, 3], dtype=tf.float32),
88
+ "label": tf.TensorSpec(shape=[], dtype=tf.int64),
89
+ })
90
+ class_names = [
91
+ "class1,class1a",
92
+ "class2",
93
+ ]
94
+ prompt_templates = [
95
+ "test {}",
96
+ "test {} test",
97
+ ]
98
+ ds_img, ds_txt = discriminative_classifier.prepare_datasets(
99
+ ds,
100
+ class_names,
101
+ prompt_templates=prompt_templates,
102
+ pp_img="resize(2)",
103
+ pp_txt="copy_from(labels='texts')",
104
+ )
105
+
106
+ it_img = iter(ds_img)
107
+ batch = next(it_img)
108
+ self.assertAllEqual(1, batch["label"])
109
+ self.assertAllEqual(tf.ones([2, 2, 3]), batch["image"])
110
+ batch = next(it_img)
111
+ self.assertAllEqual(2, batch["label"])
112
+ self.assertAllEqual(tf.ones([2, 2, 3]), batch["image"])
113
+
114
+ it_txt = iter(ds_txt)
115
+ batch = next(it_txt)
116
+ self.assertAllEqual(0, batch["label"])
117
+ self.assertAllEqual("test class1", batch["labels"])
118
+ batch = next(it_txt)
119
+ self.assertAllEqual(0, batch["label"])
120
+ self.assertAllEqual("test class1 test", batch["labels"])
121
+ batch = next(it_txt)
122
+ self.assertAllEqual(0, batch["label"])
123
+ self.assertAllEqual("test class1a", batch["labels"])
124
+ batch = next(it_txt)
125
+ self.assertAllEqual(0, batch["label"])
126
+ self.assertAllEqual("test class1a test", batch["labels"])
127
+ batch = next(it_txt)
128
+ self.assertAllEqual(1, batch["label"])
129
+ self.assertAllEqual("test class2", batch["labels"])
130
+ batch = next(it_txt)
131
+ self.assertAllEqual(1, batch["label"])
132
+ self.assertAllEqual("test class2 test", batch["labels"])
133
+
134
+ def test_average_embeddings(self):
135
+ self.assertAllEqual(jnp.array([
136
+ [2.], [4.], [8.],
137
+ ]), discriminative_classifier._average_embeddings(
138
+ embeddings=jnp.array([
139
+ 1., 3., 3., 1., # label1
140
+ 8., 0., # label2
141
+ 32., 0., 0., 0., # label3
142
+ ])[..., None],
143
+ labels=jnp.array([
144
+ 0, 0, # label1
145
+ 0, 0, # label1 (alias)
146
+ 1, 1, # label2
147
+ 2, 2, # label3
148
+ 2, 2, # label3 (alias)
149
+ ], jnp.int32),
150
+ num_classes=3, normalize=False))
151
+ self.assertAllEqual(
152
+ jnp.array([
153
+ [2**-.5, 2**-.5],
154
+ ]),
155
+ discriminative_classifier._average_embeddings(
156
+ embeddings=jnp.array([[2., 2.]]),
157
+ labels=jnp.array([0], jnp.int32),
158
+ num_classes=1,
159
+ normalize=True))
160
+
161
+ @mock.patch("big_vision.evaluators.proj."
162
+ "image_text.prompt_engineering.get_class_names")
163
+ @mock.patch("big_vision.evaluators.proj."
164
+ "image_text.prompt_engineering.get_prompt_templates")
165
+ @mock.patch("big_vision.evaluators.proj."
166
+ "image_text.discriminative_classifier._get_dataset_info")
167
+ def test_evaluate(self, get_dataset_info_mock, get_prompt_templates_mock,
168
+ get_class_names_mock):
169
+ per_device_batch_size = 10 # Make sure we have some unfiltered examples.
170
+ global_batch_size = per_device_batch_size * jax.device_count()
171
+ per_host_num_examples = int(
172
+ np.ceil(global_batch_size / jax.process_count()))
173
+ splits = {
174
+ "test":
175
+ tfds.core.SplitInfo(
176
+ name="test", shard_lengths=[per_host_num_examples], num_bytes=0)
177
+ }
178
+
179
+ model = _Model()
180
+ params = model.init(jax.random.PRNGKey(0), None, None)["params"]
181
+
182
+ prompt_templates = [
183
+ "test prompt 1 {}",
184
+ "test prompt 2 {}",
185
+ ]
186
+ class_names = [
187
+ f"test_class_{i}" for i in range(10)
188
+ ]
189
+
190
+ get_prompt_templates_mock.return_value = prompt_templates
191
+ get_class_names_mock.return_value = class_names
192
+ get_dataset_info_mock.return_value.splits = splits
193
+
194
+ def pre_filter_fn(features):
195
+ return features["label"] < 5 # matches `texts %= 5` above
196
+
197
+ dataset_name = "cifar10_test"
198
+ with tfds.testing.mock_data(num_examples=per_host_num_examples):
199
+ evaluator = discriminative_classifier.Evaluator(
200
+ lambda p, b: model.apply({"params": p},
201
+ b.get("image", None),
202
+ b.get("labels", None)),
203
+ dataset_names=[dataset_name],
204
+ prompt_templates="test_prompts",
205
+ batch_size=global_batch_size,
206
+ devices=jax.devices(),
207
+ pp_img="copy_from(image='label')",
208
+ pp_txt="copy_from(labels='label')",
209
+ dataset_overrides={
210
+ dataset_name: {
211
+ "dataset_name": "cifar10",
212
+ "class_names": "test_classes",
213
+ "pre_filter_fn": pre_filter_fn,
214
+ }
215
+ },
216
+ first_class_name_only=True,
217
+ )
218
+ results = evaluator.evaluate(
219
+ params,
220
+ dataset_name,
221
+ return_embeddings=True)
222
+ metrics = dict(evaluator.run(params))
223
+
224
+ # Assert all examples were processed.
225
+ self.assertLen(results["texts"]["embedding"],
226
+ len(class_names) * len(prompt_templates))
227
+ self.assertLen(results["texts"]["average_embedding"], len(class_names))
228
+ self.assertAllEqual(
229
+ sorted(results["texts"]["label"]),
230
+ [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9])
231
+ # Note that above model makes perfect predictions by design.
232
+ self.assertEqual(1.0, results["accuracy"])
233
+ self.assertEqual(1.0, metrics[f"{dataset_name}_accuracy"])
234
+
235
+
236
+ if __name__ == "__main__":
237
+ tf.test.main()
Tipsomaly/model/big_vision/evaluators/proj/image_text/image_text_retrieval.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Evaluates image-text retrieval results."""
16
+ from typing import List, Mapping
17
+
18
+ import numpy as np
19
+
20
+ RECALL_THRESHOLDS = (1, 5, 10)
21
+
22
+
23
+ def text_to_image_retrieval_eval(
24
+ dist_matrix: np.ndarray,
25
+ text_image_correspondence: List[int]) -> Mapping[str, float]:
26
+ """Runs the text-to-image retrieval eval from the distance matrix.
27
+
28
+ Args:
29
+ dist_matrix: Distance matrix between text and image embeddings (shape
30
+ N_IMAGES x N_TEXTS).
31
+ text_image_correspondence: Mapping between rows and columns of
32
+ `dist_matrix`, that is, a list of N_TEXTS integers n_i that represent that
33
+ the text embedding in column i corresponds to the image embedding in row
34
+ n_i. Please note that many texts can be assigned to the same image. For
35
+ instance, if we have 2 images and 4 texts (i.e. dist_matrix is 2x4), then
36
+ `text_image_correspondence = [0, 0, 1, 1]` means that the two first texts
37
+ correspond to the first image and the two last texts to the second image.
38
+
39
+ Returns:
40
+ A dictionary with the Recall@k scores for k in RECALL_THRESHOLDS.
41
+ """
42
+ per_text_ranks = dist_matrix.argsort(axis=0)
43
+ text_image_correspondence = np.array(text_image_correspondence)
44
+
45
+ def recall_at(k):
46
+ wins = per_text_ranks[:k, :] == text_image_correspondence[None]
47
+ return wins.any(axis=0).mean()
48
+
49
+ return {
50
+ f'Recall@{k}': recall_at(k)
51
+ for k in RECALL_THRESHOLDS
52
+ }
53
+
54
+
55
+ def image_to_text_retrieval_eval(
56
+ dist_matrix: np.ndarray,
57
+ text_image_correspondence: List[int]) -> Mapping[str, float]:
58
+ """Runs the image-to-text retrieval eval from the distance matrix.
59
+
60
+ Args:
61
+ dist_matrix: Distance matrix between text and image embeddings (shape
62
+ N_IMAGES x N_TEXTS).
63
+ text_image_correspondence: Mapping between rows and columns of
64
+ `dist_matrix`, that is, a list of N_TEXTS integers n_i that represent that
65
+ the text embedding in column i corresponds to the image embedding in row
66
+ n_i. Please note that many texts can be assigned to the same image. For
67
+ instance, if we have 2 images and 4 texts (i.e. dist_matrix is 2x4), then
68
+ `text_image_correspondence = [0, 0, 1, 1]` means that the two first texts
69
+ correspond to the first image and the two last texts to the second image.
70
+
71
+ Returns:
72
+ A dictionary with the Recall@k scores for k in RECALL_THRESHOLDS.
73
+ """
74
+ per_image_ranks = dist_matrix.argsort(axis=1)
75
+ text_image_correspondence = np.array(text_image_correspondence)
76
+
77
+ def recall_at(k):
78
+ top_k_images = text_image_correspondence[per_image_ranks[:, :k]]
79
+ wins = top_k_images == np.arange(len(per_image_ranks))[:, None]
80
+ return wins.any(axis=1).mean()
81
+
82
+ return {
83
+ f'Recall@{k}': recall_at(k)
84
+ for k in RECALL_THRESHOLDS
85
+ }
Tipsomaly/model/big_vision/evaluators/proj/image_text/image_text_retrieval_test.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Unit tests for image_text_retrieval."""
16
+ from typing import Mapping
17
+
18
+ from absl.testing import absltest
19
+ from absl.testing import parameterized
20
+ from big_vision.evaluators.proj.image_text import image_text_retrieval
21
+ import numpy as np
22
+
23
+
24
+ class ImTextRetrievalTest(parameterized.TestCase):
25
+
26
+ @parameterized.parameters(
27
+ (np.array([[0.0, 0.0, 0.1, 0.5, 0.1, 0.2, 0.5, 0.1],
28
+ [0.5, 0.4, 0.0, 0.0, 0.4, 0.2, 0.6, 0.4],
29
+ [0.5, 0.4, 0.1, 0.5, 0.0, 0.0, 0.8, 0.3],
30
+ [0.5, 0.4, 0.1, 0.5, 0.3, 0.2, 0.0, 0.0]]), {
31
+ 'Recall@1': 1.0,
32
+ 'Recall@5': 1.0,
33
+ 'Recall@10': 1.0
34
+ }), #
35
+ (np.array([[0.8, 0.8, 0.1, 0.5, 0.1, 0.2, 0.5, 0.1],
36
+ [0.5, 0.4, 0.0, 0.0, 0.4, 0.2, 0.6, 0.4],
37
+ [0.5, 0.4, 0.1, 0.5, 0.0, 0.8, 0.8, 0.3],
38
+ [0.5, 0.4, 0.1, 0.5, 0.4, 0.2, 0.3, 0.3]]), {
39
+ 'Recall@1': 0.5,
40
+ 'Recall@5': 0.75,
41
+ 'Recall@10': 1.0
42
+ }))
43
+ def test_image_to_text_retrieval_eval(self, dist_matrix: np.ndarray,
44
+ expected: Mapping[str, float]):
45
+ """Checks `image_to_text_retrieval_eval`.
46
+
47
+ Args:
48
+ dist_matrix: Distance matrix between image (rows) and text (columns).
49
+ expected: Expected eval results.
50
+ """
51
+ self.assertEqual(
52
+ image_text_retrieval.image_to_text_retrieval_eval(
53
+ dist_matrix, [0, 0, 1, 1, 2, 2, 3, 3]), expected)
54
+
55
+ @parameterized.parameters(
56
+ (np.array([[0.0, 0.0, 0.1, 0.5, 0.1, 0.2, 0.5, 0.1],
57
+ [0.5, 0.4, 0.0, 0.0, 0.4, 0.2, 0.6, 0.4],
58
+ [0.5, 0.4, 0.1, 0.5, 0.0, 0.0, 0.8, 0.3],
59
+ [0.5, 0.4, 0.1, 0.5, 0.3, 0.2, 0.0, 0.0]]), {
60
+ 'Recall@1': 1.0,
61
+ 'Recall@5': 1.0,
62
+ 'Recall@10': 1.0
63
+ }), #
64
+ (np.array([[0.8, 0.8, 0.1, 0.5, 0.1, 0.2, 0.1, 0.1],
65
+ [0.5, 0.4, 0.0, 0.0, 0.4, 0.2, 0.6, 0.4],
66
+ [0.5, 0.4, 0.1, 0.5, 0.0, 0.8, 0.8, 0.3],
67
+ [0.5, 0.4, 0.1, 0.5, 0.4, 0.2, 0.3, 0.3]]), {
68
+ 'Recall@1': 0.375,
69
+ 'Recall@5': 1.0,
70
+ 'Recall@10': 1.0
71
+ }))
72
+ def test_image_text_retrieval(self, dist_matrix: np.ndarray,
73
+ expected: Mapping[str, float]):
74
+ """Checks `text_to_image_retrieval_eval`.
75
+
76
+ Args:
77
+ dist_matrix: Distance matrix between image (rows) and text (columns).
78
+ expected: Expected eval results.
79
+ """
80
+ self.assertEqual(
81
+ image_text_retrieval.text_to_image_retrieval_eval(
82
+ dist_matrix, [0, 0, 1, 1, 2, 2, 3, 3]), expected)
83
+
84
+
85
+ if __name__ == '__main__':
86
+ absltest.main()
Tipsomaly/model/big_vision/evaluators/proj/image_text/prompt_engineering.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Utilities for generating zero-shot prompts."""
16
+
17
+ import re
18
+ import string
19
+ from typing import Sequence
20
+
21
+ from absl import logging
22
+ from big_vision.datasets.imagenet import class_names as imagenet_class_names
23
+ from big_vision.evaluators.proj.image_text import prompt_engineering_constants
24
+ import tensorflow_datasets as tfds
25
+
26
+
27
+ _CLASS_NAMES = { # For each dataset, maps from a source to its class names.
28
+ "imagenet2012": {
29
+ "clip": imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES,
30
+ },
31
+ "grand-vision:imagenet2012": {
32
+ "clip": imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES,
33
+ },
34
+ "imagenet_a": {
35
+ "clip": [
36
+ imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES[i]
37
+ for i in imagenet_class_names.IMAGENET_A_LABELSET
38
+ ]
39
+ },
40
+ "imagenet_r": {
41
+ "clip": [
42
+ imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES[i]
43
+ for i in imagenet_class_names.IMAGENET_R_LABELSET
44
+ ]
45
+ },
46
+ "imagenet_v2": {
47
+ "clip": imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES,
48
+ },
49
+ }
50
+
51
+ _PROMPT_TEMPLATES = {
52
+ "class_name_only": ["{}"],
53
+ "clip_paper": prompt_engineering_constants.CLIP_PAPER_PROMPT_TEMPLATES,
54
+ "clip_best": prompt_engineering_constants.CLIP_BEST_PROMPT_TEMPLATES,
55
+ }
56
+
57
+
58
+ def get_class_names(*, dataset_name, source="dataset_info", canonicalize=True):
59
+ """Returns class name for `dataset_name` from `source`."""
60
+ if isinstance(source, str):
61
+ if source.startswith("dataset_info:"):
62
+ name = source[len("dataset_info:"):]
63
+ class_names = tfds.builder(dataset_name).info.features[name].names
64
+ else:
65
+ class_names = _CLASS_NAMES[dataset_name][source]
66
+ else:
67
+ assert isinstance(source, Sequence) and all(
68
+ map(lambda s: isinstance(s, str), source)), source
69
+ class_names = source
70
+ if canonicalize:
71
+ class_names = [
72
+ canonicalize_text(name, keep_punctuation_exact_string=",")
73
+ for name in class_names
74
+ ]
75
+ logging.info("Using %d class_names: %s", len(class_names), class_names)
76
+ return class_names
77
+
78
+
79
+ def get_prompt_templates(prompt_templates_name,
80
+ *,
81
+ canonicalize=True):
82
+ """Returns prompt templates."""
83
+ prompts_templates = _PROMPT_TEMPLATES[prompt_templates_name]
84
+ if canonicalize:
85
+ prompts_templates = [
86
+ canonicalize_text(name, keep_punctuation_exact_string="{}")
87
+ for name in prompts_templates
88
+ ]
89
+ logging.info("Using %d prompts_templates: %s", len(prompts_templates),
90
+ prompts_templates)
91
+ return prompts_templates
92
+
93
+
94
+ def canonicalize_text(text, *, keep_punctuation_exact_string=None):
95
+ """Returns canonicalized `text` (lowercase and puncuation removed).
96
+
97
+ Args:
98
+ text: string to be canonicalized.
99
+ keep_punctuation_exact_string: If provided, then this exact string kept.
100
+ For example providing '{}' will keep any occurrences of '{}' (but will
101
+ still remove '{' and '}' that appear separately).
102
+ """
103
+ text = text.replace("_", " ")
104
+ if keep_punctuation_exact_string:
105
+ text = keep_punctuation_exact_string.join(
106
+ part.translate(str.maketrans("", "", string.punctuation))
107
+ for part in text.split(keep_punctuation_exact_string))
108
+ else:
109
+ text = text.translate(str.maketrans("", "", string.punctuation))
110
+ text = text.lower()
111
+ text = re.sub(r"\s+", " ", text)
112
+ return text.strip()
Tipsomaly/model/big_vision/evaluators/proj/image_text/prompt_engineering_constants.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Constants used by the module `prompt_engineering` in the same directory."""
16
+
17
+ CLIP_PAPER_PROMPT_TEMPLATES = [
18
+ 'a bad photo of a {}.',
19
+ 'a photo of many {}.',
20
+ 'a sculpture of a {}.',
21
+ 'a photo of the hard to see {}.',
22
+ 'a low resolution photo of the {}.',
23
+ 'a rendering of a {}.',
24
+ 'graffiti of a {}.',
25
+ 'a bad photo of the {}.',
26
+ 'a cropped photo of the {}.',
27
+ 'a tattoo of a {}.',
28
+ 'the embroidered {}.',
29
+ 'a photo of a hard to see {}.',
30
+ 'a bright photo of a {}.',
31
+ 'a photo of a clean {}.',
32
+ 'a photo of a dirty {}.',
33
+ 'a dark photo of the {}.',
34
+ 'a drawing of a {}.',
35
+ 'a photo of my {}.',
36
+ 'the plastic {}.',
37
+ 'a photo of the cool {}.',
38
+ 'a close-up photo of a {}.',
39
+ 'a black and white photo of the {}.',
40
+ 'a painting of the {}.',
41
+ 'a painting of a {}.',
42
+ 'a pixelated photo of the {}.',
43
+ 'a sculpture of the {}.',
44
+ 'a bright photo of the {}.',
45
+ 'a cropped photo of a {}.',
46
+ 'a plastic {}.',
47
+ 'a photo of the dirty {}.',
48
+ 'a jpeg corrupted photo of a {}.',
49
+ 'a blurry photo of the {}.',
50
+ 'a photo of the {}.',
51
+ 'a good photo of the {}.',
52
+ 'a rendering of the {}.',
53
+ 'a {} in a video game.',
54
+ 'a photo of one {}.',
55
+ 'a doodle of a {}.',
56
+ 'a close-up photo of the {}.',
57
+ 'a photo of a {}.',
58
+ 'the origami {}.',
59
+ 'the {} in a video game.',
60
+ 'a sketch of a {}.',
61
+ 'a doodle of the {}.',
62
+ 'a origami {}.',
63
+ 'a low resolution photo of a {}.',
64
+ 'the toy {}.',
65
+ 'a rendition of the {}.',
66
+ 'a photo of the clean {}.',
67
+ 'a photo of a large {}.',
68
+ 'a rendition of a {}.',
69
+ 'a photo of a nice {}.',
70
+ 'a photo of a weird {}.',
71
+ 'a blurry photo of a {}.',
72
+ 'a cartoon {}.',
73
+ 'art of a {}.',
74
+ 'a sketch of the {}.',
75
+ 'a embroidered {}.',
76
+ 'a pixelated photo of a {}.',
77
+ 'itap of the {}.',
78
+ 'a jpeg corrupted photo of the {}.',
79
+ 'a good photo of a {}.',
80
+ 'a plushie {}.',
81
+ 'a photo of the nice {}.',
82
+ 'a photo of the small {}.',
83
+ 'a photo of the weird {}.',
84
+ 'the cartoon {}.',
85
+ 'art of the {}.',
86
+ 'a drawing of the {}.',
87
+ 'a photo of the large {}.',
88
+ 'a black and white photo of a {}.',
89
+ 'the plushie {}.',
90
+ 'a dark photo of a {}.',
91
+ 'itap of a {}.',
92
+ 'graffiti of the {}.',
93
+ 'a toy {}.',
94
+ 'itap of my {}.',
95
+ 'a photo of a cool {}.',
96
+ 'a photo of a small {}.',
97
+ 'a tattoo of the {}.',
98
+ '{}',
99
+ ]
100
+
101
+ CLIP_BEST_PROMPT_TEMPLATES = [
102
+ 'itap of a {}.',
103
+ 'a bad photo of the {}.',
104
+ 'a origami {}.',
105
+ 'a photo of the large {}.',
106
+ 'a {} in a video game.',
107
+ 'art of the {}.',
108
+ 'a photo of the small {}.',
109
+ '{}',
110
+ ]
Tipsomaly/model/big_vision/evaluators/proj/image_text/prompt_engineering_test.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for prompt_engineering."""
16
+
17
+ from absl.testing import absltest
18
+ from big_vision.evaluators.proj.image_text import prompt_engineering
19
+
20
+
21
+ class PromptEngineeringTest(absltest.TestCase):
22
+
23
+ def test_canonicalize_text(self):
24
+ self.assertEqual(prompt_engineering.canonicalize_text("test_test"), "test test")
25
+ self.assertEqual(
26
+ prompt_engineering.canonicalize_text("test___test"), "test test")
27
+ self.assertEqual(prompt_engineering.canonicalize_text("test"), "test")
28
+ self.assertEqual(prompt_engineering.canonicalize_text("test."), "test")
29
+ self.assertEqual(prompt_engineering.canonicalize_text(" test "), "test")
30
+ self.assertEqual(
31
+ prompt_engineering.canonicalize_text("test\ntest"), "test test")
32
+ self.assertEqual(
33
+ prompt_engineering.canonicalize_text("test test"), "test test")
34
+ self.assertEqual(prompt_engineering.canonicalize_text("test {}"), "test")
35
+ self.assertEqual(
36
+ prompt_engineering.canonicalize_text(
37
+ "test {}", keep_punctuation_exact_string="{}"), "test {}")
38
+ self.assertEqual(
39
+ prompt_engineering.canonicalize_text(
40
+ " test {}...", keep_punctuation_exact_string="{}"), "test {}")
41
+ self.assertEqual(
42
+ prompt_engineering.canonicalize_text(
43
+ "test {} {} {}", keep_punctuation_exact_string="{}"),
44
+ "test {} {} {}")
45
+
46
+
47
+ if __name__ == "__main__":
48
+ absltest.main()
Tipsomaly/model/big_vision/evaluators/proj/image_text/retrieval.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Multi-host image->text and text->image retrieval evaluation.
16
+
17
+ Example how to add to config:
18
+
19
+ config.evals {}
20
+ config.evals.retieval = dict(log_steps=1200, type='proj.image_text.retrieval')
21
+ config.evals.retrieval.dataset = 'coco_captions'
22
+ config.evals.retrieval.txt_name = ('captions', 'text')
23
+ # Note that initial "decode|" is not needed.
24
+ config.evals.retrieval.pp_img = 'resize(224)|value_range(-1,1)'
25
+ # Raw text strings use key "texts" in feature dict. The evaluator expects
26
+ # tokenized text with key "labels".
27
+ config.evals.retrieval.pp_txt = (
28
+ 'tokenize(max_len=16, eos="sticky", pad_value=1, inkey="texts", '
29
+ ' outkey="labels")')
30
+
31
+ Example to support precomputed data:
32
+ See `big_vision/configs/proj/image_text/lit.py`.
33
+ """
34
+
35
+ import functools
36
+ import operator
37
+ import time
38
+
39
+ from absl import logging
40
+ from big_vision import input_pipeline
41
+ from big_vision.evaluators.proj.image_text import image_text_retrieval
42
+ import big_vision.pp.builder as pp_builder
43
+ import jax
44
+ import jax.numpy as jnp
45
+ from jax.sharding import NamedSharding
46
+ from jax.sharding import PartitionSpec as P
47
+ import numpy as np
48
+ import tensorflow as tf
49
+ import tensorflow_datasets as tfds
50
+
51
+
52
+ # Temporary global flag to facilitate backwards compatability. Will be removed
53
+ # by the end of year 2023.
54
+ API = "jit"
55
+
56
+
57
+ def _with_infinite_padding(dataset):
58
+ """Adds "infinite padding" to the dataset."""
59
+ filler_element = tf.nest.map_structure(
60
+ lambda spec: tf.zeros(spec.shape, spec.dtype)[None], dataset.element_spec)
61
+ filler_element["mask"] = [False]
62
+ filler_dataset = tf.data.Dataset.from_tensor_slices(filler_element)
63
+ dataset = dataset.map(
64
+ lambda features: dict(mask=True, **features),
65
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
66
+ return dataset.concatenate(filler_dataset.repeat(None))
67
+
68
+
69
+ # This is needed so retrieval_test can replace dataset info.
70
+ def _get_dataset_info(builder):
71
+ return builder.info
72
+
73
+
74
+ def prepare_datasets(
75
+ dataset, *, pp_img, pp_txt, txt_name, offset=0, cache_final=False
76
+ ):
77
+ """Returns unbatched `ds_images, ds_texts` datasets.
78
+
79
+ Args:
80
+ dataset: An image-text `tf.data.Dataset` that is expected to contain the
81
+ following features: "image" (dtype=uint8, shape=[None, None, 3]),
82
+ `txt_name` (dtype=string, shape=[None]).
83
+ pp_img: String defining pre-processing for images. The pre-processing can
84
+ expect the following features to be prepared: "image", "id". The
85
+ pre-processing should convert the "image" (dtype=uint8,
86
+ shape=[None, None, 3]) to "image" (dtype=float32, shape=[sz, sz, 3]).
87
+ pp_txt: String defining pre-processing for text. The pre-processing can
88
+ expect the following features to be prepared: "texts", "id", "caption_id".
89
+ The pre-processing should convert the "texts" (dtype=string, shape=[])
90
+ into a tokenized "labels" (dtype=int32, shape=[max_len]).
91
+ txt_name: Name of the text feature to unroll in the original `dataset`. Can
92
+ be a simple string feature name, or an iterable of strings to specify a
93
+ nested feature (e.g. for "coco_captions", this would be
94
+ `('captions', 'text')`).
95
+ offset: Offset that should be added to enumerated examples to generate IDs.
96
+ In a multi-host setup, this is typically set to a value large enough to
97
+ make all IDs distinct.
98
+ cache_final: Whether the dataset should be cached.
99
+
100
+ Returns:
101
+ Image and text datasets.
102
+ """
103
+
104
+ def get_feature_value(data, feature_name):
105
+ if isinstance(feature_name, str):
106
+ feature_name = [feature_name]
107
+ return functools.reduce(operator.getitem, feature_name, data)
108
+
109
+ def get_captions(idx, features):
110
+ """Returns a dataset with unrolled "caption" for every example."""
111
+ texts = get_feature_value(features, txt_name)
112
+ texts = tf.experimental.numpy.atleast_1d(texts) # For single-text GT.
113
+ texts_n = tf.shape(texts)[0]
114
+ return tf.data.Dataset.from_tensor_slices({
115
+ "id": tf.tile([idx + offset], [texts_n]),
116
+ "caption_i": tf.stack(tf.range(texts_n)),
117
+ "texts": tf.stack(texts),
118
+ })
119
+
120
+ def add_id(idx, features):
121
+ return {**features, "id": idx + offset}
122
+
123
+ ds_images = dataset.enumerate().map(add_id).map(
124
+ pp_builder.get_preprocess_fn(f"{pp_img}|keep('id', 'image')"))
125
+ ds_texts = dataset.enumerate().flat_map(get_captions).map(
126
+ pp_builder.get_preprocess_fn(
127
+ f"{pp_txt}|keep('id', 'caption_i', 'labels')"))
128
+ if cache_final:
129
+ ds_images, ds_texts = ds_images.cache(), ds_texts.cache()
130
+ return ds_images, ds_texts
131
+
132
+
133
+ def _split_and_batch(dataset_name, batch_size, split, get_ds, data_dir=None):
134
+ """Splits dataset, calls `get_ds` and returns padded + batched datasets."""
135
+ assert not batch_size % jax.device_count(), (
136
+ f"batch_size={batch_size} % jax.device_count()={jax.device_count()}")
137
+ builder = tfds.builder(dataset_name, data_dir=data_dir)
138
+ info = _get_dataset_info(builder)
139
+ num_examples = info.splits[split].num_examples
140
+ ds_images, ds_texts = get_ds(
141
+ builder.as_dataset(split=tfds.split_for_jax_process(split)),
142
+ offset=jax.process_index() * num_examples,
143
+ )
144
+ return (
145
+ _with_infinite_padding(ds_images).batch(batch_size),
146
+ _with_infinite_padding(ds_texts).batch(batch_size),
147
+ )
148
+
149
+
150
+ class Evaluator:
151
+ """Image/text retrieval evaluator."""
152
+
153
+ def __init__(self,
154
+ predict_fn,
155
+ *,
156
+ dataset,
157
+ pp_img,
158
+ pp_txt,
159
+ txt_name,
160
+ batch_size,
161
+ devices,
162
+ data_dir=None,
163
+ split="test",
164
+ cache_final=True):
165
+ """Initializes a new zero-shot image/text retrieval evaluator.
166
+
167
+ See `prepare_datasets()` for details on how the dataset is pre-processed.
168
+
169
+ Args:
170
+ predict_fn: Prediction function with signature
171
+ `zimg, ztxt, out = predict_fn(params, images, texts)`
172
+ dataset: The TFDS dataset name of the eval data.
173
+ pp_img: Preprocessing string for images. Preprocessed features should
174
+ contain key "image" with value that can be batched and is suitable for
175
+ `predict_fn(images)` input``.
176
+ pp_txt: Preprocessing string for texts. Can expect "texts" key as an input
177
+ (shape=[], dtype=string), and is expected to produce "labels" key that
178
+ is suitable for `predict_fn(texts)` input.
179
+ txt_name: The name of the feature of captions (can be a tuple to look up a
180
+ value in a nested feature dictionary). Expected shape=[None],
181
+ dtype=string. specified then items are used as lookup path.
182
+ batch_size: Global batch size.
183
+ devices: list of devices.
184
+ data_dir: Optional dir to load the TFDS dataset from.
185
+ split: The split of the eval data.
186
+ cache_final: Wether preprocessed dataset should be cached.
187
+ """
188
+ self.ds_images, self.ds_texts = _split_and_batch(
189
+ dataset,
190
+ batch_size,
191
+ split,
192
+ functools.partial(
193
+ prepare_datasets,
194
+ pp_img=pp_img,
195
+ pp_txt=pp_txt,
196
+ txt_name=txt_name,
197
+ cache_final=cache_final,
198
+ ),
199
+ data_dir=data_dir,
200
+ )
201
+ self._axis_name = "batch"
202
+
203
+ self.devices = devices
204
+ mesh = jax.sharding.Mesh(devices, ("devices",))
205
+
206
+ def embed_images(train_state, images):
207
+ zimg, _, _ = predict_fn(train_state, {"image": images})
208
+ return zimg
209
+
210
+ def embed_texts(train_state, texts):
211
+ _, ztxt, _ = predict_fn(train_state, {"labels": texts})
212
+ return ztxt
213
+
214
+ self._embed_images_p = jax.jit(embed_images,
215
+ out_shardings=NamedSharding(mesh, P()))
216
+ self._embed_texts_p = jax.jit(embed_texts,
217
+ out_shardings=NamedSharding(mesh, P()))
218
+ self._all_gather_p = jax.jit(
219
+ lambda x: x, out_shardings=NamedSharding(mesh, P()))
220
+ self._count_p = jax.jit(jnp.sum, out_shardings=NamedSharding(mesh, P()))
221
+ self._compiled = set()
222
+
223
+ def _embed(self, name, train_state, ds, embed_fn, id_names):
224
+ """Embeds features name `name` using `embed_fn`.
225
+
226
+ Args:
227
+ name: Feature name to be embedded.
228
+ train_state: train_state for the predict_fn.
229
+ ds: The dataset.
230
+ embed_fn: A pmapped function that returns the embeddings.
231
+ id_names: An iterable of feature names that should be collected.
232
+
233
+ Returns:
234
+ A dictionary with "embeddings" and `id_names` as keys.
235
+ """
236
+ ns = []
237
+ embeddings = []
238
+ ids = {id_name: [] for id_name in list(id_names) + ["mask"]}
239
+
240
+ t0 = time.time()
241
+
242
+ ds_b = input_pipeline.start_global(ds, self.devices)
243
+ for batch in ds_b:
244
+ ns.append(jax.device_get(self._count_p(batch["mask"])))
245
+
246
+ # Due to infinite padding, this loop will never end. We will stop once
247
+ # all processes only process padded data. We don't check the latest
248
+ # DeviceArray `ns[-1]` Because we want to keep our computation async for
249
+ # efficiency reasons.
250
+ if len(ns) >= 2 and ns[-2] == 0:
251
+ break
252
+
253
+ embs = embed_fn(train_state, batch[name])
254
+ if embed_fn not in self._compiled:
255
+ logging.info("Compiled %s embeddings in %.3fs", name, time.time() - t0)
256
+ t0 = time.time()
257
+ self._compiled.add(embed_fn)
258
+
259
+ embeddings.append(jax.device_get(embs))
260
+ for id_name in ids:
261
+ ids[id_name].append(jax.device_get(self._all_gather_p(batch[id_name])))
262
+
263
+ # Only access DeviceArray at end of loop for better efficiency.
264
+ ns = np.array(ns)
265
+ embeddings = np.concatenate(embeddings)
266
+ ids = {k: np.concatenate(v) for k, v in ids.items()}
267
+ masks = ids.pop("mask").astype(bool)
268
+ logging.info("Processed %s in %d steps - ...%s", name, len(ns), ns[-10:])
269
+ n = ns.sum()
270
+ logging.info("Totalling %d %s in %.3fs", n, name, time.time() - t0)
271
+ return {
272
+ "embeddings": embeddings[masks],
273
+ **{k: v[masks] for k, v in ids.items()},
274
+ }
275
+
276
+ def evaluate(self, train_state):
277
+ """Returns evaluation results."""
278
+ images = self._embed("image", train_state, self.ds_images,
279
+ self._embed_images_p, ("id",))
280
+ texts = self._embed("labels", train_state, self.ds_texts,
281
+ self._embed_texts_p, ("id", "caption_i"))
282
+ # Shapes: (nimg, emb) * (emb, ntxt) -> (nimg, ntxt)
283
+ similarities = np.dot(images["embeddings"], texts["embeddings"].T)
284
+
285
+ t0 = time.time()
286
+ id2img = {id_: i for i, id_ in enumerate(images["id"])}
287
+ text_image_correspondence = [id2img[id_] for id_ in texts["id"]]
288
+ img2txt = image_text_retrieval.image_to_text_retrieval_eval(
289
+ -similarities, text_image_correspondence)
290
+ txt2img = image_text_retrieval.text_to_image_retrieval_eval(
291
+ -similarities, text_image_correspondence)
292
+ logging.info("Computed retrieval metrics in %.3fs", time.time() - t0)
293
+
294
+ return dict(
295
+ images=images,
296
+ texts=texts,
297
+ img2txt=img2txt,
298
+ txt2img=txt2img,
299
+ )
300
+
301
+ def run(self, train_state):
302
+ """Returns metrics."""
303
+ results = self.evaluate(train_state)
304
+ return [(f"{direction}_{k.lower()}", v)
305
+ for direction in ("img2txt", "txt2img")
306
+ for k, v in results[direction].items()]
Tipsomaly/model/big_vision/evaluators/proj/image_text/retrieval_test.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for retrieval."""
16
+
17
+ from unittest import mock
18
+
19
+ from big_vision.evaluators.proj.image_text import retrieval
20
+ from big_vision.pp import ops_general # pylint: disable=unused-import
21
+ from big_vision.pp import ops_image # pylint: disable=unused-import
22
+ from big_vision.pp import registry
23
+ import chex
24
+ import flax.linen as nn
25
+ import jax
26
+ import jax.numpy as jnp
27
+ import tensorflow as tf
28
+ import tensorflow_datasets as tfds
29
+
30
+
31
+ def _get_test_texts2labels():
32
+
33
+ def pp(features):
34
+ features["labels"] = tf.strings.to_number(features["texts"])
35
+ return features
36
+
37
+ return pp
38
+
39
+
40
+ def _get_copy_from(**key_map):
41
+
42
+ def copy_from(d):
43
+ d = dict(d)
44
+ for k1, k2 in key_map.items():
45
+ d[k1] = d[k2]
46
+ return d
47
+
48
+ return copy_from
49
+
50
+
51
+ class _Model(nn.Module):
52
+
53
+ @nn.compact
54
+ def __call__(self, image, texts):
55
+ self.param("x", lambda _: 0.)
56
+
57
+ def z(x):
58
+ if x is not None:
59
+ batch_size = len(x)
60
+ # Note that the returned vector is most similar with other vectors
61
+ # generated from the same underlying `x[:]`.
62
+ x = jnp.concatenate([100 * jnp.ones([batch_size, 1]), x[:, None]],
63
+ axis=1)
64
+ return x / jnp.linalg.norm(x, axis=1)[:, None]
65
+
66
+ return z(image), z(texts), None
67
+
68
+
69
+ def setUpModule():
70
+ chex.set_n_cpu_devices(8)
71
+
72
+
73
+ class RetrievalTest(tf.test.TestCase):
74
+
75
+ def test_prepare_datasets(self):
76
+
77
+ def generator():
78
+ yield {
79
+ "image": tf.ones([5, 5, 3], tf.float32),
80
+ "captions": {
81
+ "text": tf.constant(["11", "12"])
82
+ }
83
+ }
84
+ yield {
85
+ "image": tf.ones([4, 4, 3], tf.float32),
86
+ "captions": {
87
+ "text": tf.constant(["21", "22", "23"])
88
+ }
89
+ }
90
+
91
+ ds = tf.data.Dataset.from_generator(
92
+ generator,
93
+ output_signature={
94
+ "image": tf.TensorSpec(shape=[None, None, 3], dtype=tf.float32),
95
+ "captions": {
96
+ "text": tf.TensorSpec(shape=[None], dtype=tf.string),
97
+ },
98
+ })
99
+ with registry.temporary_ops(test_texts2labels=_get_test_texts2labels):
100
+ ds_img, ds_txt = retrieval.prepare_datasets(
101
+ ds,
102
+ pp_img="resize(2)",
103
+ pp_txt="test_texts2labels()",
104
+ txt_name=("captions", "text"),
105
+ )
106
+ it_img = iter(ds_img)
107
+ it_txt = iter(ds_txt)
108
+ batch = next(it_img)
109
+ self.assertAllEqual(batch["id"], 0)
110
+ self.assertAllEqual(batch["image"], tf.ones([2, 2, 3]))
111
+ batch = next(it_img)
112
+ self.assertAllEqual(batch["id"], 1)
113
+ self.assertAllEqual(batch["image"], tf.ones([2, 2, 3]))
114
+ batch = next(it_txt)
115
+ self.assertAllEqual(batch["id"], 0)
116
+ self.assertAllEqual(batch["caption_i"], 0)
117
+ self.assertAllEqual(batch["labels"], 11.0)
118
+ batch = next(it_txt)
119
+ self.assertAllEqual(batch["id"], 0)
120
+ self.assertAllEqual(batch["caption_i"], 1)
121
+ self.assertAllEqual(batch["labels"], 12.0)
122
+ batch = next(it_txt)
123
+ self.assertAllEqual(batch["id"], 1)
124
+ self.assertAllEqual(batch["caption_i"], 0)
125
+ self.assertAllEqual(batch["labels"], 21.0)
126
+ batch = next(it_txt)
127
+ self.assertAllEqual(batch["id"], 1)
128
+ self.assertAllEqual(batch["caption_i"], 1)
129
+ self.assertAllEqual(batch["labels"], 22.0)
130
+ batch = next(it_txt)
131
+ self.assertAllEqual(batch["id"], 1)
132
+ self.assertAllEqual(batch["caption_i"], 2)
133
+ self.assertAllEqual(batch["labels"], 23.0)
134
+
135
+ def test_evaluate(self):
136
+ per_device_batch_size = 2
137
+ batch_size = per_device_batch_size * jax.device_count()
138
+ num_examples = 1 * batch_size + 1
139
+ splits = {
140
+ "test":
141
+ tfds.core.SplitInfo(
142
+ name="test", shard_lengths=[num_examples], num_bytes=0)
143
+ }
144
+
145
+ model = _Model()
146
+ params = model.init(jax.random.PRNGKey(0), None, None)["params"]
147
+
148
+ with tfds.testing.mock_data(num_examples=num_examples):
149
+ info_mock = mock.Mock()
150
+ info_mock.splits = splits
151
+ with mock.patch.object(retrieval, "_get_dataset_info",
152
+ lambda _: info_mock):
153
+ with registry.temporary_ops(copy_from=_get_copy_from):
154
+ evaluator = retrieval.Evaluator(
155
+ lambda p, b: model.apply({"params": p},
156
+ b.get("image", None),
157
+ b.get("labels", None)),
158
+ dataset="coco_captions",
159
+ batch_size=batch_size,
160
+ devices=jax.devices(),
161
+ txt_name=("captions", "text"),
162
+ pp_img="copy_from(image='id')",
163
+ pp_txt="copy_from(labels='id')",
164
+ )
165
+ results = evaluator.evaluate(params)
166
+
167
+ # Assert all examples were processed.
168
+ self.assertLen(results["images"]["embeddings"], num_examples)
169
+ self.assertLen(results["images"]["id"], num_examples)
170
+ # Assert no padding was processed (expects exactly one (=first) image.id=0
171
+ self.assertEqual((results["images"]["id"] == 0).sum(), 1)
172
+ # Expect perfect ITR with above _Model()...
173
+ self.assertEqual(results["img2txt"]["Recall@1"], 1.0)
174
+ self.assertEqual(results["txt2img"]["Recall@5"], 1.0)
175
+
176
+
177
+ if __name__ == "__main__":
178
+ tf.test.main()
Tipsomaly/model/big_vision/evaluators/proj/paligemma/perplexity.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Evaluator for perplexity of a model."""
16
+ import functools
17
+
18
+ from big_vision.evaluators import mean
19
+ import big_vision.utils as u
20
+ import jax.numpy as jnp
21
+
22
+
23
+ # Temporary global flag to facilitate backwards compatability. Will be removed
24
+ # by the end of year 2023.
25
+ API = 'jit'
26
+
27
+
28
+ # Cache the function such that it won't always recompile (in mean evaluator).
29
+ @functools.cache
30
+ def perplexity(
31
+ predict_fn, key='labels', shift_labels=True, pad_token=None):
32
+ """Returns a function that computes perplexity."""
33
+
34
+ def _perplexity_fn(train_state, batch, **kw):
35
+ logits, _ = predict_fn(train_state, batch, **kw)
36
+
37
+ labels = batch[key]
38
+ weights = batch.get('mask_loss', jnp.ones_like(labels))
39
+
40
+ if pad_token is not None:
41
+ weights = weights * (labels != pad_token).astype(jnp.float32)
42
+
43
+ if shift_labels:
44
+ labels = labels[:, 1:]
45
+ weights = weights[:, 1:]
46
+
47
+ losses = u.weighted_softmax_xent(
48
+ logits=logits, labels=labels, weights=weights,
49
+ reduction=False, normalize=False)
50
+ normalizer = jnp.clip(weights.sum(axis=1), 2e-38)
51
+
52
+ return {'sum': losses, 'avg': losses / normalizer}
53
+ return _perplexity_fn
54
+
55
+
56
+ class Evaluator(mean.Evaluator):
57
+ """Perplexity evaluator."""
58
+
59
+ def __init__(self, predict_fn, *a,
60
+ key='labels', shift_labels=False, pad_token=None, **kw):
61
+ kw.setdefault('prefetch', 0) # More memory-saving default.
62
+ super().__init__(
63
+ perplexity(predict_fn, key, shift_labels, pad_token), *a, **kw)
Tipsomaly/model/big_vision/evaluators/proj/paligemma/transfers/chartqa.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Evaluator for ChartQA variants."""
16
+
17
+ import functools
18
+
19
+ import big_vision.evaluators.common as c
20
+ import big_vision.pp.tokenizer
21
+ import big_vision.utils as u
22
+
23
+
24
+ # Temporary global flag to facilitate backwards compatability. Will be removed
25
+ # by the end of year 2023.
26
+ API = "jit"
27
+
28
+
29
+ class Evaluator:
30
+ """Evaluator for simple VQA tasks."""
31
+
32
+ def __init__(
33
+ self, predict_fn, tokenizer, to_lower=False,
34
+ outfile="{workdir}/{split}.json",
35
+ out_question_key="question_id", out_answer_key="answer",
36
+ *, data, devices, **kw):
37
+ self.get_data_iter, self.steps = c.eval_input_pipeline(
38
+ keep_on_cpu={"answer", "question_id"}, data=data, devices=devices, **kw)
39
+
40
+ self.outfile = c.resolve_outfile(outfile, split=data.get("split"))
41
+ self.out_question_key = out_question_key
42
+ self.out_answer_key = out_answer_key
43
+
44
+ # We'll need the tokenizer to detokenize the model outputs later.
45
+ self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer)
46
+ self.postproc = (lambda s: s.lower()) if to_lower else lambda s: s
47
+ self.decode = functools.partial(
48
+ predict_fn, devices=devices, eos_token=self.tok.eos_token)
49
+
50
+ def run(self, train_state):
51
+ """Does one evaluation run, yields metrics."""
52
+
53
+ accuracies = []
54
+ relaxed_accuracies = []
55
+ json_out = []
56
+ for _, batch in zip(range(self.steps), self.get_data_iter()):
57
+ # (batch, seqlen) array of decoded generated tokens.
58
+ tokens = self.decode(train_state, batch)
59
+
60
+ # (local_batch,) that indicates padding examples (0) vs real examples (1).
61
+ tokens = u.get_local_slice_from_fsarray(tokens)
62
+ ex_masks = u.get_local_slice_from_fsarray(batch["_mask"])
63
+
64
+ # Turn predictions into texts and then scores, one by one.
65
+ for i in range(len(tokens)):
66
+ if ex_masks[i] == 0: # Skip last-batch padding examples
67
+ continue
68
+
69
+ answer = self.postproc(self.tok.to_str(tokens[i], stop_at_eos=True))
70
+
71
+ gt = self.postproc(batch["answer"][i])
72
+ accuracies.append(float(answer == gt))
73
+ relaxed_accuracies.append(_relaxed_match(gt, answer))
74
+ json_out.append({
75
+ self.out_question_key: batch["question_id"][i].item(),
76
+ self.out_answer_key: answer,
77
+ "gt": gt,
78
+ "relaxed_match": relaxed_accuracies[-1],
79
+ })
80
+
81
+ # At this point `accuracies` is a list of per-example scores. However,
82
+ # remember that each host holds a different subset of the examples! So if
83
+ # we were to just return the mean accuracy here, we would effectively only
84
+ # have evaluated on the main host's (who writes metrics) subset!
85
+ # So now, we need to compute global means.
86
+ # There is one more caveat: `process_sum` needs the summands on each host
87
+ # to have the same size. So we either need to include dummy values for
88
+ # the padding examples (last batch, annoying), or we only sum scalars as in
89
+ # sufficient statistics, which we do here.
90
+ sum_accs, sum_relaxed_accs, num = c.process_sum(
91
+ [sum(accuracies), sum(relaxed_accuracies), len(accuracies)])
92
+
93
+ # Yielding metric_name, value means logging the metric.
94
+ yield "acc", sum_accs / num
95
+ yield "relaxed_acc", sum_relaxed_accs / num
96
+ yield "num", num # Just for sanity checks.
97
+ c.multiprocess_write_json(self.outfile, json_out)
98
+
99
+
100
+ def _to_float(text: str) -> float | None:
101
+ try:
102
+ if text.endswith("%"):
103
+ # Convert percentages to floats.
104
+ return float(text.rstrip("%")) / 100.0
105
+ else:
106
+ return float(text)
107
+ except ValueError:
108
+ return None
109
+
110
+
111
+ def _relaxed_match(target: str,
112
+ prediction: str,
113
+ max_relative_error: float = 0.05) -> bool:
114
+ """Calculates relaxed correctness.
115
+
116
+ The correctness tolerates certain error ratio defined by max_relative_error.
117
+ See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1:
118
+ “Following Methani et al. (2020), we use a relaxed accuracy measure for the
119
+ numeric answers to allow a minor inaccuracy that may result from the automatic
120
+ data extraction process. We consider an answer to be correct if it is within
121
+ 5% of the gold answer. For non-numeric answers, we still need an exact match
122
+ to consider an answer to be correct.”
123
+
124
+ Args:
125
+ target: Target string.
126
+ prediction: Predicted string.
127
+ max_relative_error: Maximum relative error.
128
+
129
+ Returns:
130
+ Whether the prediction was correct given the specified tolerance.
131
+ """
132
+ prediction_float = _to_float(prediction)
133
+ target_float = _to_float(target)
134
+ # When the target is 0 is always required an exact match.
135
+ if prediction_float is not None and target_float:
136
+ relative_error = abs(prediction_float - target_float) / abs(target_float)
137
+ return relative_error <= max_relative_error
138
+ else:
139
+ return prediction == target
Tipsomaly/model/big_vision/evaluators/proj/paligemma/transfers/pope.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Evaluator for the POPE dataset (https://github.com/RUCAIBox/POPE).
16
+
17
+ POPE is a binary classification dataset with ground-truth answers being either
18
+ 'yes' or 'no'.
19
+ """
20
+
21
+ import functools
22
+
23
+ import big_vision.datasets.core
24
+ import big_vision.evaluators.common as c
25
+ import big_vision.input_pipeline
26
+ import big_vision.pp.builder
27
+ import big_vision.pp.tokenizer
28
+ import big_vision.utils as u
29
+
30
+
31
+ # Temporary global flag to facilitate backwards compatability. Will be removed
32
+ # by the end of year 2023.
33
+ API = "jit"
34
+
35
+
36
+ class Evaluator:
37
+ """Evaluator for the POPE task.
38
+
39
+ This evaluator expects the batch to contain a field `question_id` and a field
40
+ `answer` for single ground truth or `answers` for multiple ground truths.
41
+
42
+ The field names used when writting the json result can be controlled with
43
+ `out_question_key` and `out_answer_key`.
44
+ """
45
+
46
+ def __init__(
47
+ self,
48
+ predict_fn,
49
+ data,
50
+ pp_fn,
51
+ tokenizer,
52
+ batch_size,
53
+ *,
54
+ devices,
55
+ outfile="{workdir}/{split}.json",
56
+ out_question_key="question_id",
57
+ out_answer_key="answer"
58
+ ):
59
+
60
+ self.outfile = c.resolve_outfile(outfile, split=data.get("split"))
61
+ self.out_question_key = out_question_key
62
+ self.out_answer_key = out_answer_key
63
+ # This will mostly look the same across all evaluators, preparing data:
64
+ data = big_vision.datasets.core.get(**data)
65
+ pp_fn = big_vision.pp.builder.get_preprocess_fn(pp_fn)
66
+ self.ds, self.steps = big_vision.input_pipeline.make_for_inference(
67
+ data.get_tfdata(ordered=True),
68
+ pp_fn,
69
+ batch_size,
70
+ num_ex_per_process=data.num_examples_per_process(),
71
+ )
72
+ # The `keep_on_cpu=` argument lists the data keys that, if they exist, we
73
+ # do NOT want to ship to the TPUs and instead just keep in host memory.
74
+ # Typically ground-truth and metadata, that is often of string type.
75
+ self.data_iter = big_vision.input_pipeline.start_global(
76
+ self.ds, devices, keep_on_cpu={"answer", "question_id"}
77
+ )
78
+ # We'll need the tokenizer to detokenize the model outputs later.
79
+ self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer)
80
+ self.decode = functools.partial(
81
+ predict_fn, devices=devices, eos_token=self.tok.eos_token
82
+ )
83
+
84
+ def run(self, train_state):
85
+ """Does one evaluation run, yields metrics."""
86
+
87
+ accuracies = []
88
+ valid = []
89
+ json_out = []
90
+ for _, batch in zip(range(self.steps), self.data_iter):
91
+ # (batch, seqlen) array of decoded generated tokens.
92
+ tokens = self.decode(train_state, batch)
93
+
94
+ # (local_batch,) that indicates padding examples (0) vs real examples (1).
95
+ tokens = u.get_local_slice_from_fsarray(tokens)
96
+ ex_masks = u.get_local_slice_from_fsarray(batch["_mask"])
97
+
98
+ # Turn predictions into texts and then scores, one by one.
99
+ for i in range(len(tokens)):
100
+ if ex_masks[i] == 0: # Skip last-batch padding examples
101
+ continue
102
+
103
+ answer = self.tok.to_str(tokens[i], stop_at_eos=True).lower()
104
+ gt = batch["answer"][i]
105
+ accuracies.append(float(answer == gt))
106
+ valid.append(float(answer in ("yes", "no")))
107
+
108
+ json_out.append(
109
+ {
110
+ self.out_question_key: batch["question_id"][i].item(),
111
+ self.out_answer_key: answer,
112
+ }
113
+ )
114
+
115
+ # At this point `accuracies` is a list of per-example scores. However,
116
+ # remember that each host holds a different subset of the examples! So if
117
+ # we were to just return the mean accuracy here, we would effectively only
118
+ # have evaluated on the main host's (who writes metrics) subset!
119
+ # So now, we need to compute global means.
120
+ # There is one more caveat: `process_sum` needs the summands on each host
121
+ # to have the same size. So we either need to include dummy values for
122
+ # the padding examples (last batch, annoying), or we only sum scalars as in
123
+ # sufficient statistics, which we do here.
124
+ sum_accs, sum_valid, num = c.process_sum([
125
+ sum(accuracies),
126
+ sum(valid),
127
+ len(accuracies),
128
+ ])
129
+
130
+ if num:
131
+ yield "acc", sum_accs / num
132
+ yield "valid_percent", sum_valid / num
133
+ yield "num", num
134
+
135
+ c.multiprocess_write_json(self.outfile, json_out)
Tipsomaly/model/big_vision/evaluators/proj/uvim/coco_panoptic.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """COCO17 panoptic evaluation."""
16
+ import functools
17
+ from functools import partial
18
+ import json
19
+ import os
20
+ import tempfile
21
+ import time
22
+ import zipfile
23
+
24
+ from absl import logging
25
+ from big_vision.evaluators.proj.uvim import common
26
+ import big_vision.pp.builder as pp_builder
27
+ import jax
28
+ import numpy as np
29
+ import panopticapi_converters.twochannels2panoptic_coco_format as converter
30
+ from panopticapi.evaluation import pq_compute
31
+ import tensorflow as tf
32
+ import tensorflow_datasets as tfds
33
+
34
+ from tensorflow.io import gfile
35
+
36
+
37
+ ROOT = os.environ.get('COCO_DATA_DIR', '.')
38
+
39
+ PANOPTIC_COCO_CATS_FILE = f'{ROOT}/panoptic_coco_categories.json'
40
+ PANOPTIC_2017 = {
41
+ 'train': f'{ROOT}/panoptic_train2017.json',
42
+ 'validation': f'{ROOT}/panoptic_val2017.json',
43
+ }
44
+
45
+ PANOPTIC_GT_ZIP = {
46
+ 'train': f'{ROOT}/panoptic_train2017.zip',
47
+ 'validation': f'{ROOT}/panoptic_val2017.zip',
48
+ }
49
+
50
+
51
+ class Evaluator:
52
+ """Panoptic segmentation evaluator: calls official COCO API.
53
+
54
+ `predict_fn` accepts arbitrary dictionaries of parameters and data, where
55
+ the data dictionary is produced by the `pp` op. It is expected to output a
56
+ 2-channel mask, where the first channel encodes semantics, and the second
57
+ channel encodes instance ids.
58
+ """
59
+
60
+ def __init__(self,
61
+ predict_fn,
62
+ pp_fn,
63
+ batch_size,
64
+ dataset='coco/2017_panoptic',
65
+ dataset_dir=None,
66
+ split='validation',
67
+ predict_kwargs=None):
68
+ # Prepare to run predict on all processes and gather predictions on all
69
+ # devices. Note: if needed consider only gather across processes.
70
+ def predict(params, batch):
71
+ res = {
72
+ 'image/id': batch['image/id'],
73
+ 'mask': batch['mask'],
74
+ 'y': predict_fn(params, batch['input'], **(predict_kwargs or {})),
75
+ }
76
+ return jax.lax.all_gather(res, axis_name='data', axis=0)
77
+
78
+ self.predict_fn = jax.pmap(predict, axis_name='data')
79
+
80
+ # Prepare data for each process and pad with zeros so all processes have the
81
+ # same number of batches.
82
+ def preprocess(example):
83
+ return {
84
+ 'image/id': example['image/id'],
85
+ 'mask': tf.constant(1),
86
+ 'input': pp_builder.get_preprocess_fn(pp_fn)(example),
87
+ }
88
+
89
+ self.data = common.get_jax_process_dataset(
90
+ dataset, split, dataset_dir=dataset_dir,
91
+ global_batch_size=batch_size,
92
+ pp_fn=preprocess)
93
+
94
+ # Only process 0 runs conversion to png and calls into coco api.
95
+ if jax.process_index() == 0:
96
+ self.result_dir = tempfile.TemporaryDirectory()
97
+ (self.gt_folder, self.gt_json, self.categories_json,
98
+ self.remap, self.size_map) = _prepare_ground_truth(
99
+ dataset, split, dataset_dir)
100
+
101
+ def _compute_png_predictions(self, params):
102
+ """Computes predictions and converts then to png to optimize memory use."""
103
+ count = 0
104
+ logging.info('Panoptic eval: running inference.')
105
+ for batch in self.data.as_numpy_iterator():
106
+ out = self.predict_fn(params, batch)
107
+
108
+ if jax.process_index():
109
+ continue
110
+
111
+ out = jax.device_get(jax.tree_map(lambda x: x[0], out))
112
+ mask = out['mask']
113
+ pan_recs = out['y'][mask != 0]
114
+ ids = out['image/id'][mask != 0]
115
+
116
+ for pan_rec, image_id in zip(pan_recs, ids):
117
+ sem = pan_rec[..., 0]
118
+ ins = pan_rec[..., 1]
119
+
120
+ sem_remapped = np.array(sem)
121
+ for v in np.unique(sem):
122
+ sem_remapped[sem == v] = self.remap[v]
123
+ sem = sem_remapped
124
+
125
+ pan_mask = np.stack([sem, ins, np.zeros_like(sem)], axis=-1)
126
+ pan_mask = _resize_nearest(pan_mask, self.size_map[image_id])
127
+ pan_mask_png = tf.io.encode_png(pan_mask.astype('uint8')).numpy()
128
+
129
+ fname = f'{self.result_dir.name}/{image_id:012d}.png'
130
+ with open(fname, 'wb') as f:
131
+ f.write(pan_mask_png)
132
+ count += 1
133
+
134
+ logging.log_every_n_seconds(
135
+ logging.INFO, 'Panoptic eval: processed %i examples so far.', 30,
136
+ count)
137
+
138
+ if jax.process_index():
139
+ return None
140
+
141
+ logging.info('Panoptic eval: inference done. Processed %d examples.', count)
142
+ return self.result_dir
143
+
144
+ def run(self, params):
145
+ """Run eval."""
146
+ # Note result_dir is constant, but files inside are mutated.
147
+ result_dir = self._compute_png_predictions(params)
148
+
149
+ if not result_dir:
150
+ return
151
+
152
+ with tempfile.TemporaryDirectory() as pred_folder, \
153
+ tempfile.NamedTemporaryFile(mode='w') as pred_json:
154
+
155
+ logging.info('Panoptic eval: running conversion.')
156
+ converter.converter(
157
+ source_folder=result_dir.name,
158
+ images_json_file=self.gt_json,
159
+ categories_json_file=self.categories_json,
160
+ segmentations_folder=pred_folder,
161
+ predictions_json_file=pred_json.name)
162
+ logging.info('Panoptic eval: conversion done.')
163
+
164
+ logging.info('Panoptic eval: running metrics computation.')
165
+ res = pq_compute(gt_json_file=self.gt_json,
166
+ gt_folder=self.gt_folder,
167
+ pred_json_file=pred_json.name,
168
+ pred_folder=pred_folder)
169
+ logging.info('Panoptic eval: metrics computation done.')
170
+
171
+ for k in ['All', 'Stuff', 'Things']:
172
+ for m in ['pq', 'rq', 'sq']:
173
+ yield f'{k}_{m}', res[k][m]
174
+
175
+
176
+ def _prepare_ground_truth(dataset, split, data_dir):
177
+ """Prepare ground truth from tf.data.Dataset."""
178
+ if dataset == 'coco/2017_panoptic' and data_dir is None:
179
+ return _prepare_ground_truth_from_zipfiles(split)
180
+ else:
181
+ return _prepare_ground_truth_from_dataset(dataset, split, data_dir)
182
+
183
+
184
+ @functools.lru_cache(maxsize=None)
185
+ def _prepare_ground_truth_from_dataset(dataset, split, data_dir):
186
+ """Prepare ground truth from a tf.data.Dataset."""
187
+ dataset = tfds.builder(dataset, data_dir=data_dir).as_dataset(split=split)
188
+
189
+ categories_json = _make_local_copy(PANOPTIC_COCO_CATS_FILE)
190
+ with gfile.GFile(categories_json, 'rb') as f:
191
+ categories = json.loads(f.read())
192
+
193
+ # Build map from tfds class ids to COCO class ids.
194
+ remap = {0: 0}
195
+ with gfile.GFile(categories_json, 'r') as f:
196
+ remap = {**remap, **{(i + 1): x['id'] for i, x in enumerate(categories)}}
197
+
198
+ gt_folder = tempfile.mkdtemp()
199
+ gfile.makedirs(gt_folder)
200
+ size_map = {}
201
+ annotations = []
202
+ images = []
203
+ for example in dataset:
204
+ image_id = int(example['image/id'])
205
+ panoptic_image = example['panoptic_image']
206
+ ann_ids = example['panoptic_objects']['id']
207
+ ann_labels = example['panoptic_objects']['label']
208
+ ann_iscrowd = example['panoptic_objects']['is_crowd']
209
+ ann_area = example['panoptic_objects']['area']
210
+
211
+ fname = f'{image_id:012d}.png'
212
+ with gfile.GFile(os.path.join(gt_folder, fname), 'wb') as f:
213
+ f.write(tf.io.encode_png(panoptic_image).numpy())
214
+
215
+ size_map[image_id] = (panoptic_image.shape[0], panoptic_image.shape[1])
216
+
217
+ segments_info = []
218
+ for i in range(len(ann_ids)):
219
+ segments_info.append({
220
+ 'id': int(ann_ids[i]),
221
+ 'category_id': remap[int(ann_labels[i] + 1)],
222
+ 'iscrowd': int(ann_iscrowd[i]),
223
+ 'area': int(ann_area[i]),
224
+ })
225
+
226
+ annotations.append({
227
+ 'file_name': str(fname),
228
+ 'image_id': int(image_id),
229
+ 'segments_info': segments_info
230
+ })
231
+ images.append({
232
+ 'id': image_id,
233
+ 'file_name': f'{image_id:012d}.jpg',
234
+ })
235
+
236
+ # Write annotations.json needed for pq_compute.
237
+ gt_json = os.path.join(gt_folder, 'annotations.json')
238
+ with gfile.GFile(gt_json, 'wb') as f:
239
+ f.write(json.dumps({
240
+ 'images': images,
241
+ 'annotations': annotations,
242
+ 'categories': categories,
243
+ }))
244
+
245
+ return gt_folder, gt_json, categories_json, remap, size_map
246
+
247
+
248
+ def _prepare_ground_truth_from_zipfiles(split):
249
+ """Prepare ground truth from coco zip files."""
250
+ split_prefix = split.split('[')[0]
251
+ if split_prefix not in ('train', 'validation'):
252
+ raise ValueError(f'Split {split} not supported')
253
+
254
+ # The following 4 calls are cached. This allows to save significant time
255
+ # in use cases like sweeping predict_fn hparams on the same run.
256
+ gt_json = _make_local_copy(PANOPTIC_2017[split_prefix])
257
+ gt_folder = _make_local_unzip_copy(PANOPTIC_GT_ZIP[split_prefix])
258
+ categories_json = _make_local_copy(PANOPTIC_COCO_CATS_FILE)
259
+ image_ids = _list_image_ids('coco/2017_panoptic', split)
260
+
261
+ gt_folder = os.path.join(
262
+ gt_folder, 'panoptic_val2017'
263
+ if split_prefix == 'validation' else 'panoptic_train2017')
264
+
265
+ # Build map from tfds class ids to COCO class ids.
266
+ remap = {0: 0}
267
+ with gfile.GFile(categories_json, 'r') as f:
268
+ remap = {**remap, **{(i + 1): x['id'] for i, x in enumerate(json.load(f))}}
269
+
270
+ # Filters gt_json to contain only annotations for images in dataset.
271
+ with gfile.GFile(gt_json) as f:
272
+ data = json.load(f)
273
+ logging.info(
274
+ 'Panoptic eval: pre-filter %d annotations.',
275
+ len(data['annotations'])
276
+ )
277
+ data['images'] = [x for x in data['images'] if x['id'] in image_ids]
278
+ data['annotations'] = [
279
+ x for x in data['annotations'] if x['image_id'] in image_ids
280
+ ]
281
+ logging.info(
282
+ 'Panoptic eval: post-filter %d annotations.',
283
+ len(data['annotations'])
284
+ )
285
+ filtered_gt_json = tempfile.NamedTemporaryFile(delete=False).name
286
+ with open(filtered_gt_json, 'w') as f:
287
+ json.dump(data, f)
288
+
289
+ # Precompute images sizes.
290
+ size_map = {x['id']: (x['height'], x['width']) for x in data['images']}
291
+
292
+ return gt_folder, filtered_gt_json, categories_json, remap, size_map
293
+
294
+
295
+ @functools.lru_cache(maxsize=None)
296
+ def _list_image_ids(dataset, split):
297
+ d = tfds.load(dataset, split=split).map(lambda x: x['image/id'])
298
+ return frozenset(d.as_numpy_iterator())
299
+
300
+
301
+ @functools.lru_cache(maxsize=None)
302
+ def _make_local_copy(fname) -> str:
303
+ start = time.monotonic()
304
+ local_file = tempfile.NamedTemporaryFile(delete=False)
305
+ gfile.copy(fname, local_file.name, overwrite=True)
306
+ logging.info('Copy %s in %d seconds.', fname, time.monotonic() - start)
307
+ return local_file.name
308
+
309
+
310
+ @functools.lru_cache(maxsize=None)
311
+ def _make_local_unzip_copy(fname) -> str:
312
+ start = time.monotonic()
313
+ folder = tempfile.mkdtemp()
314
+ with tempfile.NamedTemporaryFile() as tmp_zip_file:
315
+ gfile.copy(fname, tmp_zip_file.name, overwrite=True)
316
+ with zipfile.ZipFile(tmp_zip_file.name, 'r') as f:
317
+ f.extractall(folder)
318
+ logging.info('Copy %s in %d seconds.', fname, time.monotonic() - start)
319
+ return folder
320
+
321
+
322
+ @partial(jax.jit, static_argnums=(1,), backend='cpu')
323
+ def _resize_nearest(image, shape):
324
+ return jax.image.resize(image, shape + image.shape[-1:], 'nearest')
Tipsomaly/model/big_vision/evaluators/proj/uvim/coltran_fid.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Evaluation producing ColTran FID-5K metric."""
16
+
17
+ import functools
18
+ import os
19
+
20
+ from absl import logging
21
+ import einops
22
+ import jax
23
+ import numpy as np
24
+ import tensorflow as tf
25
+ import tensorflow_datasets as tfds
26
+ import tensorflow_gan as tfgan
27
+ import tensorflow_hub as tfhub
28
+
29
+ from tensorflow.io import gfile
30
+
31
+
32
+ ROOT = os.environ.get("FID_DATA_DIR", ".")
33
+
34
+
35
+ def _preprocess(image, resolution=512):
36
+ """ColTran dataset preprocessing.
37
+
38
+ See,
39
+ github.com/google-research/google-research/blob/master/coltran/datasets.py#L44
40
+
41
+ Args:
42
+ image: ImageNet example from TFDS.
43
+ resolution: Integer representing output size.
44
+
45
+ Returns:
46
+ An int32 image of size (resolution, resolution, 3).
47
+ """
48
+ image_shape = tf.shape(image)
49
+ height, width = image_shape[0], image_shape[1]
50
+ side_size = tf.minimum(height, width)
51
+ image = tf.image.resize_with_crop_or_pad(
52
+ image, target_height=side_size, target_width=side_size)
53
+ image = tf.image.resize(image, method="area", antialias=True,
54
+ size=(resolution, resolution))
55
+ image = tf.cast(tf.round(image), dtype=tf.int32)
56
+ return image
57
+
58
+
59
+ def _normalize(x):
60
+ """Coltran normalization to expected range for Inception module.
61
+
62
+ Args:
63
+ x: Image with values in [0,255].
64
+
65
+ Returns:
66
+ Image with values in [-1,1].
67
+ """
68
+ x = tf.cast(x, tf.float32)
69
+ x = (x / 128.0) - 1.0 # note: 128.0 is the value used in ColTran.
70
+ return x
71
+
72
+
73
+ class Evaluator:
74
+ """ColTran FID-5K Evaluator.
75
+
76
+ This Evaluator aims to mirror the evaluation pipeline used by Kumar et.al.
77
+ in Colorization Transformer (https://arxiv.org/abs/2102.04432).
78
+
79
+ To be clear: much of this code is direct snippets from ColTran code.
80
+
81
+ See,
82
+ github.com/google-research/google-research/blob/master/coltran/datasets.py#L44
83
+
84
+ The ColTran pipeline has numerous stages, where serialied data is passed
85
+ between binaries via file, etc... While we don't physically write the same
86
+ files, we simulate the effects of the serialization (e.g., quantization).
87
+ """
88
+
89
+ def __init__(self,
90
+ predict_fn,
91
+ batch_size, # ignored
92
+ device_batch_size=5,
93
+ coltran_seed=1,
94
+ predict_kwargs=None):
95
+ """Create Evaluator.
96
+
97
+ Args:
98
+ predict_fn: Colorization prediction function. Expects grayscale images
99
+ of size (512, 512, 3) in keys `image` and `image_ctx` with values in
100
+ the range [-1,1]. Outputs `color` image in range [-1,1].
101
+ batch_size: ignored.
102
+ device_batch_size: number of images per batch, per device.
103
+ coltran_seed: used to specify the block of 5_000 images used to generate
104
+ the reference pool. Value of `1` matches default ColTran code.
105
+ predict_kwargs: arguments passed to `predict_fn`.
106
+ """
107
+ del batch_size
108
+
109
+ self.num_devices = jax.local_device_count()
110
+ self.device_batch_size = device_batch_size
111
+ logging.log(logging.INFO, "Colorizing with batch size %i on %i devices.",
112
+ self.device_batch_size, self.num_devices)
113
+ assert 5_000 % (self.device_batch_size * self.num_devices) == 0
114
+
115
+ predict = functools.partial(predict_fn, **(predict_kwargs or {}))
116
+ self.predict_fn = jax.pmap(predict)
117
+
118
+ module = tfhub.load(tfgan.eval.INCEPTION_TFHUB)
119
+ def _pools(x):
120
+ return np.squeeze(module(x)[tfgan.eval.INCEPTION_FINAL_POOL].numpy())
121
+
122
+ self.inception_pool = _pools
123
+
124
+ # Setup the colorization dataset.
125
+ # TRICKY: ColTran FID-5k uses the first 5_000 images returned as read by
126
+ # default from tensorflow_datasets (that is: with shard interleaving).
127
+ # In particular note that it is different than the set of images returned
128
+ # by "validation[:5000]".
129
+ def _eval_data_preprocess(example):
130
+ # Colorization happens at 512x512 resolution.
131
+ image = _preprocess(example["image"], resolution=512)
132
+ image = _normalize(image)
133
+ grayscale = tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image))
134
+ return {
135
+ "image": image,
136
+ "grayscale": grayscale,
137
+ "file_name": example["file_name"]
138
+ }
139
+
140
+ ds = tfds.load("imagenet2012", split="validation")
141
+ ds = ds.map(_eval_data_preprocess)
142
+ ds = ds.take(5_000)
143
+ ds = ds.batch(self.device_batch_size)
144
+ ds = ds.batch(self.num_devices)
145
+ self.eval_data = ds.cache().prefetch(tf.data.AUTOTUNE)
146
+
147
+ # Setup the reference dataset.
148
+ def _reference_data_preprocess(example):
149
+ # ColTran eval operates on 256x256.
150
+ image = _preprocess(example["image"], resolution=256)
151
+ image = _normalize(image)
152
+ return {"image": image, "file_name": example["file_name"]}
153
+
154
+ ds = tfds.load("imagenet2012", split="validation")
155
+ ds = ds.map(_reference_data_preprocess)
156
+ # Skip the images used in colorization.
157
+ ds = ds.skip(5_000)
158
+ # ColTran eval w/ seed=1 effectively uses 10_000:15_000 to
159
+ # calculate reference.
160
+ ds = ds.skip(coltran_seed * 5_000)
161
+ ds = ds.take(5_000)
162
+ ds = ds.batch(device_batch_size)
163
+ self.reference_data = ds.cache().prefetch(tf.data.AUTOTUNE)
164
+
165
+ def _get_file(name):
166
+ return os.path.join(ROOT, name)
167
+
168
+ with gfile.GFile(_get_file("eval_file_names.txt")) as f:
169
+ self.eval_file_names = frozenset(f.read().splitlines())
170
+
171
+ with gfile.GFile(_get_file("reference_file_names.txt")) as f:
172
+ self.reference_file_names = frozenset(f.read().splitlines())
173
+
174
+ def run(self, params):
175
+ """Run eval."""
176
+
177
+ if jax.process_index(): # Host0 does all work.
178
+ return
179
+
180
+ color_pools = []
181
+ color_file_names = set()
182
+ for i, batch in enumerate(self.eval_data.as_numpy_iterator()):
183
+ predict_batch = {
184
+ "labels": batch["image"],
185
+ "image": batch["grayscale"],
186
+ "image_ctx": batch["grayscale"],
187
+ }
188
+ y = self.predict_fn(params, predict_batch)
189
+ y = y["color"]
190
+ y = einops.rearrange(y, "d b h w c -> (d b) h w c")
191
+
192
+ # Return to the ColTran eval size of 256x256.
193
+ y = tf.image.resize(y, (256, 256), "area")
194
+
195
+ # Mimic effect of serializing image as integers and map back to [-1, 1].
196
+ y = np.clip(np.round((y + 1.) * 128.), 0, 255)
197
+ y = _normalize(y)
198
+
199
+ color_pools.append(self.inception_pool(y))
200
+
201
+ file_names = einops.rearrange(batch["file_name"], "d b -> (d b)")
202
+ color_file_names.update([f.decode() for f in file_names])
203
+
204
+ logging.log_every_n_seconds(
205
+ logging.INFO,
206
+ "ColTran FID eval: processed %i colorized examples so far.", 30,
207
+ (i + 1) * self.device_batch_size * self.num_devices)
208
+
209
+ reference_pools = []
210
+ reference_file_names = set()
211
+ for i, batch in enumerate(self.reference_data.as_numpy_iterator()):
212
+ image = batch["image"]
213
+ assert np.array_equal(image.shape, (self.device_batch_size, 256, 256, 3))
214
+ reference_pools.append(self.inception_pool(image))
215
+ reference_file_names.update([f.decode() for f in batch["file_name"]])
216
+
217
+ logging.log_every_n_seconds(
218
+ logging.INFO,
219
+ "ColTran FID eval: processed %i reference examples so far.", 30,
220
+ (i + 1) * self.device_batch_size)
221
+
222
+ if color_file_names != self.eval_file_names:
223
+ raise ValueError("unknown: {}\nmissing: {}".format(
224
+ color_file_names - self.eval_file_names,
225
+ self.eval_file_names - color_file_names))
226
+
227
+ if reference_file_names != self.reference_file_names:
228
+ raise ValueError("unknown: {}\nmissing: {}".format(
229
+ reference_file_names - self.reference_file_names,
230
+ self.reference_file_names - reference_file_names))
231
+
232
+ color = np.concatenate(color_pools, axis=0)
233
+ reference = np.concatenate(reference_pools, axis=0)
234
+
235
+ if color.shape[0] != 5_000:
236
+ raise ValueError(color.shape)
237
+
238
+ if reference.shape[0] != 5_000:
239
+ raise ValueError(reference.shape)
240
+
241
+ yield "FID_5k", tfgan.eval.frechet_classifier_distance_from_activations(
242
+ color, reference)
Tipsomaly/model/big_vision/evaluators/proj/uvim/coltran_fid_data/eval_file_names.txt ADDED
The diff for this file is too large to render. See raw diff
 
Tipsomaly/model/big_vision/evaluators/proj/uvim/coltran_fid_data/reference_file_names.txt ADDED
The diff for this file is too large to render. See raw diff
 
Tipsomaly/model/big_vision/evaluators/proj/uvim/common.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Common utilities used in evaluators."""
16
+ import math
17
+ import jax
18
+ import tensorflow as tf
19
+ import tensorflow_datasets as tfds
20
+
21
+
22
def get_jax_process_dataset(dataset, split, global_batch_size, pp_fn,
                            dataset_dir=None, cache=True, add_tfds_id=False):
  """Returns the dataset shard to be processed by the current jax process.

  Each process receives its own even shard of `split`, padded with all-zero
  examples so that every process yields the same number of batches
  (`num_batches`, computed from the *global* example count). The first two
  dimensions of each element are [local_device_count, device_batch_size].

  Args:
    dataset: tfds dataset name.
    split: tfds split spec.
    global_batch_size: number of examples consumed per iteration across all
      processes and devices. Must be divisible by `jax.device_count()`.
    pp_fn: preprocessing function applied per example.
    dataset_dir: optional path for tfds to find the prepared data.
    cache: whether to cache the dataset after batching.
    add_tfds_id: whether to add the unique `tfds_id` string to each example.

  Returns:
    A `tf.data.Dataset` yielding exactly `num_batches` batches.
  """
  assert global_batch_size % jax.device_count() == 0

  # Global batch count: the zero-padding below guarantees every process can
  # produce this many batches regardless of shard-size imbalance.
  total_examples = tfds.load(
      dataset, split=split, data_dir=dataset_dir).cardinality()
  num_batches = math.ceil(total_examples / global_batch_size)

  process_split = tfds.even_splits(
      split, n=jax.process_count(), drop_remainder=False)[jax.process_index()]
  data = tfds.load(
      dataset,
      split=process_split,
      data_dir=dataset_dir,
      read_config=tfds.ReadConfig(add_tfds_id=add_tfds_id)).map(pp_fn)

  # Infinite stream of all-zero examples with the same structure/dtypes,
  # concatenated after the real data as padding.
  # Note: jax.tree_util.tree_map is used instead of the deprecated
  # jax.tree_map alias, which has been removed in recent JAX releases.
  pad_data = tf.data.Dataset.from_tensors(
      jax.tree_util.tree_map(
          lambda x: tf.zeros(x.shape, x.dtype), data.element_spec)
  ).repeat()

  data = data.concatenate(pad_data)
  # Batch twice: innermost per-device batch, then one batch per local device.
  data = data.batch(global_batch_size // jax.device_count())
  data = data.batch(jax.local_device_count())
  data = data.take(num_batches)
  if cache:
    # Eval datasets are often used many times and caching the dataset after
    # batching allows one to have the buffers ready to be used and not have
    # to wait for preprocessing to be done over and over.
    data = data.cache()
  return data
Tipsomaly/model/big_vision/evaluators/proj/uvim/compute_mean.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Evaluator for computing mean of per-example metrics."""
16
+ import functools
17
+ from typing import Mapping
18
+
19
+ from big_vision import input_pipeline
20
+ from big_vision.datasets import core as ds_core
21
+ from big_vision.pp import builder as pp_builder
22
+
23
+ import jax
24
+ import jax.numpy as jnp
25
+ import numpy as np
26
+
27
+
28
+ # Note: global to avoid jax re-compiling across different evaluator instances.
29
+ @functools.partial(jax.pmap, static_broadcasted_argnums=0, axis_name='batch')
30
+ def _run_predict_fn(predict_fn, params, batch):
31
+ """Sum per-example metrics weighted by `_mask`."""
32
+ mask = batch['_mask']
33
+ metrics = predict_fn(params, batch)
34
+ # Sanity check output format of predict_fn.
35
+ assert isinstance(metrics, Mapping), 'predict_fn must return a dict'
36
+ for y in jax.tree_leaves(metrics):
37
+ if y.shape != mask.shape:
38
+ raise ValueError(
39
+ f'Expected per-example metrics of shape {mask.shape} found '
40
+ f'{jax.tree_map(lambda x: x.shape, metrics)}.')
41
+ metrics = {**metrics, '_mask': mask}
42
+ metrics = jax.tree_map(lambda x: jnp.inner(x, mask), metrics)
43
+ return jax.lax.psum(metrics, axis_name='batch')
44
+
45
+
46
class Evaluator:
  """Report the mean of per-example metrics computed by predict_fn.

  `predict_fn(params, batch)` must return a dict from metric name to
  per-example metrics of shape [batch_size]. Padded examples are excluded
  from the mean via the pipeline-provided `_mask` field.
  """

  def __init__(self, predict_fn, data, pp_fn, batch_size,
               cache_final=True, cache_raw=False, prefetch=1):
    """Builds the inference input pipeline and stores `predict_fn`.

    Args:
      predict_fn: per-example metric function, see class docstring.
      data: kwargs for `ds_core.get` describing the dataset.
      pp_fn: preprocessing spec string for `pp_builder`.
      batch_size: global batch size for inference.
      cache_final: cache the final (preprocessed, batched) dataset.
      cache_raw: cache the raw dataset before preprocessing.
      prefetch: number of batches to prefetch to device.
    """
    data = ds_core.get(**data)
    self.dataset, self.steps = input_pipeline.make_for_inference(
        data.get_tfdata(ordered=True), batch_size=batch_size,
        num_ex_per_process=data.num_examples_per_process(),
        preprocess_fn=pp_builder.get_preprocess_fn(pp_fn),
        cache_final=cache_final, cache_raw=cache_raw)
    self.data_iter = input_pipeline.start_input_pipeline(self.dataset, prefetch)
    self.predict_fn = predict_fn

  def run(self, params):
    """Computes all metrics; yields (name, mean_value) pairs."""
    metrics = []

    # Compute batch metrics without blocking: pmap calls are dispatched
    # asynchronously and only the device_get below synchronizes.
    for _, batch in zip(range(self.steps), self.data_iter):
      batch_metrics = _run_predict_fn(self.predict_fn, params, batch)
      metrics.append(batch_metrics)

    # Results are psum-replicated across devices; take device 0's copy and
    # transfer to host (blocking).
    # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias,
    # which has been removed in recent JAX releases.
    metrics = jax.device_get(
        jax.tree_util.tree_map(lambda x: x[0], metrics))

    metrics_sum = jax.tree_util.tree_map(lambda *x: np.sum(x), *metrics)
    mask_sum = metrics_sum.pop('_mask')  # Number of real (unpadded) examples.
    for key, value_sum in metrics_sum.items():
      yield (key, value_sum / mask_sum)
Tipsomaly/model/big_vision/evaluators/proj/uvim/nyu_depth.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Evaluation for NYU depth.
16
+
17
+ At evaluation time the ground truth is cropped and clipped. Values outside of
18
+ the test crop or clipping range are not included in eval calculations.
19
+
20
+ In this evaluator, it is assumed that the ground truth is already cropped, so the
21
+ entire image is evaluated. However, the evaluator does perform the clipping.
22
+
23
+ Reference implementations:
24
+ https://github.com/zhyever/Monocular-Depth-Estimation-Toolbox/blo(internal link)a0f341244260ff61541191a613dd74bc/depth/datasets/nyu.py
25
+ https://github.com/vinvino02/GLPDepth/blob/7f3c78df4ecd6e7c79fd0c4b73c95d61f4aa2121/code/utils/metrics.py
26
+ https://github.com/shariqfarooq123/AdaBins/blob/2fb686a66a304f0a719bc53d77412460af97fd61/evaluate.py
27
+ """
28
+
29
+ import functools
30
+
31
+ import big_vision.evaluators.proj.uvim.common as common
32
+ import big_vision.pp.builder as pp_builder
33
+ import jax
34
+ import jax.numpy as jnp
35
+ import numpy as np
36
+ import tensorflow as tf
37
+
38
+ EVAL_CROP_H = 426
39
+ EVAL_CROP_W = 560
40
+
41
+
42
class Evaluator:
  """Evaluator for NYU depth.

  Gathers depth predictions from all devices, then host 0 alone computes
  the standard depth metrics (RMSE, absolute relative error, log10 error
  and delta thresholds) over valid pixels.
  """

  def __init__(self,
               predict_fn,
               pp_fn,
               batch_size,
               dataset,
               split,
               min_depth=1e-3,
               max_depth=10,
               dataset_dir=None,
               predict_kwargs=None):
    """Creates the pmapped predict function and the eval input pipeline.

    Args:
      predict_fn: `(params, batch, **predict_kwargs) -> dict` with a
        "depth" entry holding per-example depth predictions.
      pp_fn: preprocessing spec string for `pp_builder`.
      batch_size: global batch size.
      dataset: tfds dataset name.
      split: tfds split spec.
      min_depth: ground-truth values <= this are excluded from metrics.
      max_depth: ground-truth values >= this are excluded from metrics.
      dataset_dir: optional tfds data directory.
      predict_kwargs: extra kwargs forwarded to `predict_fn`.
    """
    self.min_depth = min_depth
    self.max_depth = max_depth

    def predict(params, batch):
      pred = predict_fn(params, batch, **(predict_kwargs or {}))

      # all_gather so that every device (and thus host 0) ends up holding
      # the full global batch of predictions and ground truths.
      return jax.lax.all_gather({
          "mask": batch["mask"],
          "gt": jnp.squeeze(batch["ground_truth"], axis=-1),
          "y": pred["depth"],
      }, axis_name="data", axis=0)

    self.predict_fn = jax.pmap(predict, axis_name="data")

    # Prepare data for each process and pad with zeros so all processes have
    # the same number of batches.
    def preprocess(example):
      # "mask" is 1 for real examples; zero-padded examples get mask 0.
      return {
          "mask": tf.constant(1),
          **pp_builder.get_preprocess_fn(pp_fn)(example),
      }

    self.process_batch_size = batch_size // jax.process_count()

    self.data = common.get_jax_process_dataset(
        dataset=dataset,
        dataset_dir=dataset_dir,
        split=split,
        global_batch_size=batch_size,
        pp_fn=preprocess)

  def run(self, params):
    """Run eval; yields (metric_name, value) pairs on host 0 only."""
    # Assumes that the ground truth is processed by the eval crop.
    eval_mask = np.ones((EVAL_CROP_H, EVAL_CROP_W), dtype=np.bool_)
    rmses = []
    abs_res = []
    abs_logs = []
    d1s = []
    d2s = []
    d3s = []
    for batch in self.data.as_numpy_iterator():
      # Outputs is a dict with values shaped (gather/same, devices, batch, ...)
      out = self.predict_fn(params, batch)

      if jax.process_index():  # Host0 gets all preds and does eval.
        continue

      # First, we remove the "gather" dim and transfer the result to host,
      # leading to numpy arrays of (devices, device_batch, ...)
      # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias,
      # which has been removed in recent JAX releases.
      out = jax.tree_util.tree_map(lambda x: jax.device_get(x[0]), out)
      # Then the bool-indexing with mask resulting in flat (global_batch, ...)
      out = jax.tree_util.tree_map(lambda x: x[out["mask"] == 1], out)  # pylint:disable=cell-var-from-loop

      for gt, pred in zip(out["gt"], out["y"]):
        # Predictions are resized to the fixed eval crop before comparison.
        pred = _resize_nearest(pred, (EVAL_CROP_H, EVAL_CROP_W))
        # Only pixels with ground truth inside (min_depth, max_depth) count.
        valid_mask = np.logical_and(gt > self.min_depth, gt < self.max_depth)
        valid_mask = np.logical_and(valid_mask, eval_mask)

        rmses.append(_compute_rmse(gt[valid_mask], pred[valid_mask]))
        abs_res.append(_compute_abs_re(gt[valid_mask], pred[valid_mask]))
        abs_logs.append(_compute_abs_log(gt[valid_mask], pred[valid_mask]))
        d1s.append(_compute_delta(gt[valid_mask], pred[valid_mask], order=1))
        d2s.append(_compute_delta(gt[valid_mask], pred[valid_mask], order=2))
        d3s.append(_compute_delta(gt[valid_mask], pred[valid_mask], order=3))

    if jax.process_index():  # Host0 gets all preds and does eval.
      return

    yield "RMSE", np.mean(rmses)
    yield "abs_RE", np.mean(abs_res)
    yield "log10", np.mean(abs_logs)
    yield "delta1", np.mean(d1s)
    yield "delta2", np.mean(d2s)
    yield "delta3", np.mean(d3s)
130
+
131
+
132
+ @functools.partial(jax.jit, static_argnums=(1,), backend="cpu")
133
+ def _resize_nearest(image, shape):
134
+ return jax.image.resize(image, shape, "nearest")
135
+
136
+
137
+ def _compute_rmse(gt, pred):
138
+ diff = gt - pred
139
+ return np.sqrt(np.mean(np.power(diff, 2)))
140
+
141
+
142
+ def _compute_abs_re(gt, pred):
143
+ diff = np.abs(gt - pred)
144
+ return np.mean(diff / gt)
145
+
146
+
147
+ def _compute_abs_log(gt, pred):
148
+ diff = np.abs(np.log10(gt) - np.log10(pred))
149
+ return np.mean(diff)
150
+
151
+
152
+ def _compute_delta(gt, pred, order):
153
+ rel_diff = np.maximum(gt / pred, pred / gt)
154
+ return np.sum(rel_diff < 1.25**order) / rel_diff.size