# -*- coding: utf-8 -*-
import os
import json
import re
from itertools import chain, islice

import numpy as np
import faiss
import gradio as gr
from gensim.models import Word2Vec
from tqdm import tqdm
from huggingface_hub import hf_hub_download, login
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf, monotonically_increasing_id, collect_list, concat_ws
from pyspark.sql.types import StringType

# Load token from Hugging Face Secrets and authenticate if it is present
HF_TOKEN = os.environ.get("RedditSemanticSearch")
if HF_TOKEN:
    login(token=HF_TOKEN)

# Define target subreddits
target_subreddits = ["askscience", "gaming", "technology", "todayilearned", "programming"]

# Stream JSONL Reddit files from the HF Hub, yielding one comment dict at a time
def load_reddit_split(subreddit_name):
    file_path = hf_hub_download(
        repo_id="HuggingFaceGECLM/REDDIT_comments",
        filename=f"{subreddit_name}.jsonl",
        repo_type="dataset",  # required so the file is fetched from the dataset repo
    )
    with open(file_path, "r") as f:
        for line in f:
            yield json.loads(line)

# Combine subreddit data into a single lazy iterator
combined_dataset = chain(*(load_reddit_split(sub) for sub in target_subreddits))

if "JAVA_HOME" not in os.environ:
    os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-11-openjdk-amd64"  # Common path on HF Spaces/Debian

# PySpark session
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([{"body": ex["body"]} for ex in islice(combined_dataset, 100000)])

# Clean text function: lowercase, strip URLs, keep letters and whitespace only
def clean_body(text):
    text = text.lower()
    text = re.sub(r"http\S+|www\S+|https\S+", "", text)
    text = re.sub(r"[^a-zA-Z\s]", "", text)
    return re.sub(r"\s+", " ", text).strip()

clean_udf = udf(clean_body, StringType())
df_clean = df.withColumn("clean", clean_udf(col("body")))

# Chunking: group every `chunk_size` consecutive comments into one text chunk.
# Coalesce to a single partition so monotonically_increasing_id() yields consecutive row numbers.
chunk_size = 5
df_indexed = df_clean.coalesce(1).withColumn("row_num", monotonically_increasing_id())
df_indexed = df_indexed.withColumn("chunk_id", (col("row_num") / chunk_size).cast("int"))
df_chunked = (
    df_indexed.groupBy("chunk_id")
    .agg(concat_ws(" ", collect_list("clean")).alias("chunk_text"))
    .orderBy("chunk_id")  # keep chunks in comment order so they line up with the labels below
)

# Collect chunks for embedding
chunked_comments = df_chunked.select("chunk_text").rdd.map(lambda x: x[0]).collect()

# Create subreddit labels: one label per comment, then one label per chunk.
# A chunk is labelled by its first comment; chunks rarely span subreddit boundaries
# because the comments are concatenated subreddit by subreddit.
combined_dataset = chain(*(load_reddit_split(sub) for sub in target_subreddits))
comment_labels = [ex["subreddit_name_prefixed"] for ex in islice(combined_dataset, 100000)]
subreddit_labels = [comment_labels[i * chunk_size] for i in range(len(chunked_comments))]

# Tokenize
def clean_text(text):
    text = text.lower()
    text = re.sub(r"http\S+|www\S+|https\S+", "", text, flags=re.MULTILINE)
    text = re.sub(r"[^a-zA-Z\s]", "", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text

tokenized_chunks = []
for chunk in tqdm(chunked_comments):
    cleaned = clean_text(chunk)
    tokens = cleaned.split()
    tokenized_chunks.append(tokens)

# Train Word2Vec (skip-gram) on the tokenized chunks
model = Word2Vec(sentences=tokenized_chunks, vector_size=100, window=5, min_count=2, workers=4, sg=1)
model.save("reddit_word2vec.model")

# Embed a chunk as the mean of its word vectors (zero vector if no token is in the vocabulary)
def get_chunk_embedding(chunk_tokens, model):
    vectors = [model.wv[token] for token in chunk_tokens if token in model.wv]
    if not vectors:
        return np.zeros(model.vector_size)
    return np.mean(vectors, axis=0)

chunk_embeddings = [get_chunk_embedding(tokens, model) for tokens in tokenized_chunks]
embedding_matrix = np.array(chunk_embeddings).astype("float32")

# Build FAISS index (exact L2 search over the chunk embeddings)
index = faiss.IndexFlatL2(model.vector_size)
index.add(embedding_matrix)
faiss.write_index(index, "reddit_faiss.index")

# Load model and index for the search API
model = Word2Vec.load("reddit_word2vec.model")
index = faiss.read_index("reddit_faiss.index")
subreddit_map = {i: label for i, label in enumerate(subreddit_labels)}
unique_subreddits = sorted(set(subreddit_labels))
original_chunks = [" ".join(tokens) for tokens in tokenized_chunks]

# Search function: embed the query the same way as the chunks (mean word vector)
def embed_text(text):
    tokens = text.lower().split()
    vectors = [model.wv[token] for token in tokens if token in model.wv]
    return np.mean(vectors, axis=0) if vectors else np.zeros(model.vector_size)

def search_reddit(query, selected_subreddit, top_k=5):
    query_vec = embed_text(query).astype("float32").reshape(1, -1)
    # Over-fetch so enough candidates remain after filtering by subreddit
    D, I = index.search(query_vec, top_k * 2)
    results = []
    for idx in I[0]:
        # FAISS returns -1 for missing neighbours, so guard the lower bound as well
        if 0 <= idx < len(chunked_comments) and subreddit_map[idx] == selected_subreddit:
            results.append(f"🔸 {chunked_comments[idx]}")
        if len(results) >= top_k:
            break
    return "\n\n".join(results) if results else "⚠️ No relevant results found."

# Gradio UI
with gr.Blocks(theme=gr.themes.Base(primary_hue="orange")) as demo:
    gr.Image(
        value="https://1000logos.net/wp-content/uploads/2017/05/Reddit-Logo.png",
        show_label=False,
        height=100,
    )
    gr.Markdown(
        "## Reddit Semantic Search (Powered by Word2Vec + FAISS)\n"
        "_Disclaimer: Prototype, not affiliated with Reddit Inc._"
    )
    with gr.Row():
        query = gr.Textbox(label="Enter Reddit-style query")
        subreddit_dropdown = gr.Dropdown(choices=unique_subreddits, label="Choose Subreddit")
    output = gr.Textbox(label="Matching Comments", lines=10)
    search_btn = gr.Button("🔍 Search")
    search_btn.click(fn=search_reddit, inputs=[query, subreddit_dropdown], outputs=output)

demo.launch(share=True)