File size: 1,469 Bytes
2cc98e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import pandas as pd
from pyspark.sql.session import SparkSession


EMBEDDINGS_SQL = """
SELECT DISTINCT
  e.item_id,
  s.seriesMainTitle AS title,
  e.embedding
FROM
  {item_embeddings} e
INNER JOIN
  {s2s_content_entities} s
  ON e.item_id = s.unpackedValue.series.id.id
INNER JOIN
  {series_offering_dim} off
ON
  s.unpackedValue.series.id.id = off.series_id
  AND SIZE(ARRAY_INTERSECT(off.country_codes, ARRAY('US'))) > 0
WHERE
  date = (SELECT MAX(date) FROM bolt_recs_prod.gold.item_embeddings)
  AND n_dimensions = 1536
"""


class EmbeddingsDataLoader:
    def __init__(
        self,
        env: str = "prod",
        spark_session: SparkSession = None,
        item_embeddings: str = "bolt_recs_prod.gold.item_embeddings",
        s2s_content_entities: str = "bolt_cep_prod.gold.s2s_content_entities",
        series_offering_dim: str = "bolt_dai_ckg_prod.gold.series_offering_dim",
    ):
        if not spark_session:
            spark_session = SparkSession.builder.appName(
                "embeddings"
            ).getOrCreate()

        self._env = env
        self._spark_session = spark_session
        self._table_names = {
            "item_embeddings": item_embeddings,
            "s2s_content_entities": s2s_content_entities,
            "series_offering_dim": series_offering_dim,
        }

    def load(self) -> pd.DataFrame:
        query = EMBEDDINGS_SQL.format(**self._table_names)
        return self._spark_session.sql(query).toPandas()