dabbu2000 committed on
Commit
7398026
·
1 Parent(s): aae822c

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +112 -0
utils.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import base64
4
+ import sys
5
+ import os
6
+ from urllib import request
7
+
8
def germanShepherdConvertRawtoDegreesPixel(germanShepherdRawString, germanShepherdDecimalString, germanShepherdRawUnit, germanShepherdDecimalUnit):
    """Convert a pair of coordinate strings into two floats.

    Each input may be either a plain decimal number or three
    colon-separated fields:

    * a colon-separated raw string is read as ``HH:MM:SS`` and scaled by
      ``360 / 24``;
    * a colon-separated decimal string is read as signed ``DD:MM:SS``;
      the sign is taken from a leading ``-`` so that values such as
      ``-00:30:00`` keep their sign.

    Args:
        germanShepherdRawString: First coordinate as a string.
        germanShepherdDecimalString: Second coordinate as a string.
        germanShepherdRawUnit: Text inserted into the error message shown
            when the first coordinate cannot be parsed.
        germanShepherdDecimalUnit: Text inserted into the error message
            shown when the second coordinate cannot be parsed.

    Returns:
        Tuple ``(raw_value, decimal_value)`` of floats.

    Raises:
        SystemExit: If either string cannot be parsed. The message is
            also written to the Streamlit page before exiting.
    """
    germanShepherdRawUnitError = "The German Shepherd Raw Unit Entered is incorrect: {:s}".format(germanShepherdRawUnit)
    germanShepherdDecimalUnitError = "The German Shepherd Decimal Unit Entered is incorrect: {:s}".format(germanShepherdDecimalUnit)

    if ':' in germanShepherdRawString:
        try:
            # Unpacking anything but exactly three fields also raises ValueError.
            HH, MM, SS = [float(w) for w in germanShepherdRawString.split(':')]
        except ValueError:
            st.write(germanShepherdRawUnitError)
            sys.exit(germanShepherdRawUnitError)

        germanShepherdRawString = 360. / 24 * (HH + MM / 60 + SS / 3600)

    if ':' in germanShepherdDecimalString:
        try:
            DD, MM, SS = [float(w) for w in germanShepherdDecimalString.split(':')]
        except ValueError:
            st.write(germanShepherdDecimalUnitError)
            sys.exit(germanShepherdDecimalUnitError)

        # Read the sign from the text instead of DD / abs(DD): that form
        # divides by zero when the degrees field is 0 and cannot express
        # "-00:MM:SS".
        sign = -1.0 if germanShepherdDecimalString.lstrip().startswith('-') else 1.0
        germanShepherdDecimalString = sign * (abs(DD) + MM / 60 + SS / 3600)

    try:
        germanShepherdRawValue = float(germanShepherdRawString)
    except ValueError:
        st.write(germanShepherdRawUnitError)
        sys.exit(germanShepherdRawUnitError)

    try:
        # Convert the coordinate itself, not the unit label.
        germanShepherdDecimalValue = float(germanShepherdDecimalString)
    except ValueError:
        st.write(germanShepherdDecimalUnitError)
        sys.exit(germanShepherdDecimalUnitError)

    return germanShepherdRawValue, germanShepherdDecimalValue
44
+
45
+
46
def germanShepherdSimilarityMeasure(germanShepherdRepresentative, germanShepherdQueryIndex, germanShepherdMetric='IP', germanShepherdNumNearest=10):
    """Find the nearest neighbours of one row using a flat faiss index.

    All rows of ``germanShepherdRepresentative`` are added to a freshly
    built index, which is then searched with the row at
    ``germanShepherdQueryIndex``.

    Args:
        germanShepherdRepresentative: Array whose last axis is the feature
            dimension (assumed 2D, ``(N_rep, N_dim)``).
        germanShepherdQueryIndex: Row index of the query vector.
        germanShepherdMetric: ``'IP'`` to use ``faiss.IndexFlatIP`` or
            ``'L2'`` to use ``faiss.IndexFlatL2``.
        germanShepherdNumNearest: Number of neighbours to request.

    Returns:
        Tuple ``(similar_indices, distances)`` for the single query row.

    Raises:
        SystemExit: If the metric is not a string or is not one of the
            supported names.
    """
    if not isinstance(germanShepherdMetric, str):
        sys.exit('Metric {0} must be a string'.format(germanShepherdMetric))
    if germanShepherdMetric not in ('IP', 'L2'):
        sys.exit('Metric {0} does not exist'.format(germanShepherdMetric))

    # The visible module imports do not include faiss, so import it here,
    # after argument validation, to keep this function self-contained.
    import faiss

    germanShepherdDimension = germanShepherdRepresentative.shape[-1]  # assuming 2D array (N_rep, N_dim)

    if germanShepherdMetric == 'IP':
        germanShepherdIndex = faiss.IndexFlatIP(germanShepherdDimension)
    else:
        germanShepherdIndex = faiss.IndexFlatL2(germanShepherdDimension)

    germanShepherdIndex.add(germanShepherdRepresentative)

    # The query is passed as a single-row batch, hence the added leading axis.
    germanShepherdQuery = germanShepherdRepresentative[germanShepherdQueryIndex][None, ...]
    germanShepherdDistance, germanShepherdSimilarIndices = germanShepherdIndex.search(germanShepherdQuery, germanShepherdNumNearest)

    return germanShepherdSimilarIndices[0], germanShepherdDistance[0]
66
+
67
+
68
def germanShepherdEvaluateSimilarity(germanShepherdRepresentative, germanShepherdQueryIndex, germanShepherdNumNearest=10, germanShepherdSimilarityMetric=False):
    """Rank every row by its dot product with one query row.

    Args:
        germanShepherdRepresentative: Array of row vectors; the row at
            ``germanShepherdQueryIndex`` is the query.
        germanShepherdQueryIndex: Row index of the query vector.
        germanShepherdNumNearest: Maximum number of results to return.
        germanShepherdSimilarityMetric: When false (default), results are
            ordered from largest to smallest dot product; when true, from
            smallest to largest.

    Returns:
        Tuple ``(similar_indices, products)``, both truncated to
        ``germanShepherdNumNearest`` entries and in matching order.
    """
    germanShepherdDistance = germanShepherdRepresentative @ germanShepherdRepresentative[germanShepherdQueryIndex]

    # np.argsort is ascending; reverse it unless inverse ordering was requested.
    germanShepherdSimilarIndices = np.argsort(germanShepherdDistance)
    if not germanShepherdSimilarityMetric:
        germanShepherdSimilarIndices = germanShepherdSimilarIndices[::-1]

    germanShepherdSimilarIndices = germanShepherdSimilarIndices[:germanShepherdNumNearest]
    germanShepherdDistance = germanShepherdDistance[germanShepherdSimilarIndices]

    return germanShepherdSimilarIndices, germanShepherdDistance
81
+
82
def _read_similarity_row(url, skip_bytes, row_bytes, dtype):
    """Fetch ``row_bytes`` bytes starting at ``skip_bytes`` from ``url`` as a 1D array."""
    # Ask only for the bytes of this row rather than everything to end of file.
    byte_range = 'bytes={:d}-{:d}'.format(skip_bytes, skip_bytes + row_bytes - 1)
    with request.urlopen(request.Request(url, headers={'Range': byte_range})) as f:
        return np.frombuffer(f.read(row_bytes), dtype=dtype)


def retrieve_similarity(query_ind, model_version='v1'):
    """Download the precomputed nearest-neighbour row for one index.

    The remote arrays are split into files of 10000 rows each. The file
    containing ``query_ind`` is selected, and an HTTP Range request reads
    just that row: 1000 values of 4 bytes each from the ``dist_`` file
    (decoded as float32) and from the ``inds_`` file (decoded as int32).

    Args:
        query_ind: Integer index of the query row.
        model_version: ``'v1'`` or ``'v2'``; selects the remote directory.

    Returns:
        Tuple ``(similar_inds, dist)`` of 1D numpy arrays.

    Raises:
        ValueError: If ``model_version`` is not a known version.
    """
    sim_chunksize = 10000
    nnearest = 1000
    bytes_per_dtype = 4
    ngal_tot = 42272646

    model_strings = {'v1': '8hour_south', 'v2': '8hour_south_torgb'}
    if model_version not in model_strings:
        # Previously an unknown version left model_string unbound and
        # surfaced as an UnboundLocalError further down.
        raise ValueError("Unknown model_version {!r}; expected one of {}".format(model_version, sorted(model_strings)))
    model_string = model_strings[model_version]

    url_head = 'https://portal.nersc.gov/project/cusp/ssl_galaxy_surveys/galaxy_search/data/similarity_arrays/{:s}/small_chunks/'.format(model_string)

    ichunk = query_ind // sim_chunksize
    istart = ichunk * sim_chunksize
    # The end index used in the last file's name is capped at the total count.
    iend = min((ichunk + 1) * sim_chunksize, ngal_tot)

    # url_head already ends with '/', so plain concatenation builds the URL
    # without relying on the platform's path separator.
    url_dist = url_head + 'dist_knearest1000_{:09d}_{:09d}.bin'.format(istart, iend)
    url_inds = url_head + 'inds_knearest1000_{:09d}_{:09d}.bin'.format(istart, iend)

    query_line = query_ind % sim_chunksize
    row_bytes = nnearest * bytes_per_dtype
    skip_bytes = query_line * row_bytes

    dist = _read_similarity_row(url_dist, skip_bytes, row_bytes, np.float32)
    similar_inds = _read_similarity_row(url_inds, skip_bytes, row_bytes, np.int32)

    return similar_inds, dist