File size: 3,389 Bytes
326bd5f
 
 
 
67779e2
 
 
 
 
 
 
326bd5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import logging
import pickle
from pathlib import Path

from .logger import setup_logging
from .download_vtt import Download
from .vtt_to_txt import Clean
from .embed_transcripts import Embed
from .retrieve_context import Context
from .tokenizer import Tokenizer
from .generate_response import Response
from config import CHANNEL_IDS, MAX_CONTEXT_TOKENS, VTT_DIR, TXT_DIR, CHUNK_FAISS, RETRIEVED_TRANSCRIPTS_FILE, RESPONSE_FILE, \
FILE_PKL, TRANSCRIPTS_PKL, CHUNK_PKL, MODEL, GROQ_API_KEY, SYSTEM_PROMPT, ENCODER

# Configure project-wide logging once, at import time, then grab a
# module-level logger for this file (standard logging idiom).
setup_logging()
logger = logging.getLogger(__name__)

def load_text_corpus(txt_dir: Path) -> tuple[list[Path], list[str]]:
    """Read every ``*.txt`` transcript file under *txt_dir*.

    Files are visited in sorted order so the result is deterministic
    across runs (the i-th path corresponds to the i-th transcript).

    Args:
        txt_dir: Directory containing plain-text transcript files.

    Returns:
        A pair ``(file_paths, transcripts)`` of parallel lists.
    """
    file_paths: list[Path] = []
    transcripts: list[str] = []
    for file_path in sorted(txt_dir.glob("*.txt")):
        # Path.glob() already yields Path objects — the original
        # re-wrapped each one in Path(), which was a no-op.
        file_paths.append(file_path)
        transcripts.append(file_path.read_text(encoding="utf-8"))
    logger.info("Collected %d transcripts", len(file_paths))
    return file_paths, transcripts

def write_retrieved_transcripts(retrieved_transcripts: list[str], file_paths: list[Path]) -> None:
    """Write each retrieved transcript to ``RETRIEVED_TRANSCRIPTS_FILE``.

    Every entry is emitted as a ``Video id`` header (derived from the
    file stem, truncated at the first dot) followed by the numbered
    transcript body.
    """
    pairs = zip(file_paths, retrieved_transcripts)
    with RETRIEVED_TRANSCRIPTS_FILE.open("w", encoding="utf-8") as out:
        for idx, (file_path, text) in enumerate(pairs, start=1):
            vid = file_path.stem.split(".")[0]
            out.write(f"Video id: {vid}\nTranscript {idx}:\n{text}\n")

def main() -> None:
    """Run the end-to-end pipeline: download subtitles, clean them to
    text, embed, retrieve context for a user query, and generate a
    response which is written to ``RESPONSE_FILE``.
    """
    # Query
    query = input("Enter query:\n").strip()
    if not query:
        logger.error("Query cannot be empty")
        return

    # Download subtitles for every configured channel.
    for channel_id in CHANNEL_IDS:
        downloader = Download(channel_id, VTT_DIR, language="en")
        downloader.download_channel_subtitles()

    # Convert downloaded .vtt subtitle files into plain text.
    cleaner = Clean(VTT_DIR, TXT_DIR)
    cleaner.vtt_to_txt()

    # Collect transcripts and cache both lists on disk for later stages.
    file_paths, transcripts = load_text_corpus(TXT_DIR)
    with open(FILE_PKL, "wb") as f:
        pickle.dump(file_paths, f)
    with open(TRANSCRIPTS_PKL, "wb") as f:
        pickle.dump(transcripts, f)
    # NOTE: the original immediately re-loaded FILE_PKL back into
    # `file_paths`; the in-memory list is identical, so that redundant
    # round-trip was removed.

    # Build embeddings over the cached transcripts.
    embedder = Embed(TRANSCRIPTS_PKL, CHUNK_FAISS, CHUNK_PKL)
    embedder.embedding()

    # Retrieve the chunks most relevant to the query.
    context = Context(CHUNK_FAISS, CHUNK_PKL)
    retrieved = context.retrieve_chunks(query, top_k=20, retrieve_k=25)
    if not retrieved:
        return
    write_retrieved_transcripts(retrieved, file_paths)

    # Trim the concatenated context to the model's token budget.
    tokenizer = Tokenizer(MODEL, ENCODER)
    full_context = " ".join(retrieved)
    limit_context = tokenizer.trim_to_token_limit(full_context, MAX_CONTEXT_TOKENS)
    # Flatten newlines so the context is a single line in the output file.
    context_str = " ".join(limit_context.split("\n"))
    logger.info("Full_context: %d tokens, %d words", tokenizer.count_tokens(full_context), len(full_context.split(" ")))
    logger.info("Limit_context: %d tokens, %d words", tokenizer.count_tokens(limit_context), len(limit_context.split(" ")))

    # Generate the response and persist query/context/response together.
    responder = Response(GROQ_API_KEY, MODEL, ENCODER, SYSTEM_PROMPT)
    response = responder.generate_response(query, limit_context)
    # Use the module logger for consistency (the original called
    # logging.info on the root logger here).
    logger.info("Total number of tokens in prompt: %s", tokenizer.count_tokens(query + SYSTEM_PROMPT + limit_context))
    output = "\n".join([f"Received query: {query}", f"Context: {context_str}", f"Response: {response}"])
    RESPONSE_FILE.write_text(output, encoding="utf-8")

if __name__ == "__main__":
    # NOTE(review): setup_logging() already ran at module import time.
    # basicConfig is a no-op when the root logger already has handlers,
    # so this line is likely redundant — confirm whether both
    # configurations are intended before removing either.
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s [%(name)s]: %(message)s", )
    main()