Anshul Prasad committed on
Commit
326bd5f
·
1 Parent(s): cc7a7b6

name change.

Browse files

Structural change for compatibility with OOP.
Simple logic is collapsed inside main().

Files changed (1) hide show
  1. main.py → src/__init__.py +49 -70
main.py → src/__init__.py RENAMED
@@ -2,107 +2,86 @@ import logging
2
  import pickle
3
  from pathlib import Path
4
 
5
- from src.logger import setup_logging
6
- from src.download_vtt import download_channel_subtitles
7
- from src.vtt_to_txt import vtt_to_txt
8
- from src.retrieve_context import retrieve_transcripts
9
- from src.generate_response import generate_response
10
- from src.embed_transcripts import embedding
11
- from src.tokenizer import trim_to_token_limit, count_tokens
12
- from config import CHANNEL_URLS, MAX_CONTEXT_TOKENS, VTT_DIR, TXT_DIR, TRANSCRIPT_INDEX, RETRIEVED_TRANSCRIPTS_FILE, RESPONSE_FILE, \
13
- FILE_PATHS, TRANSCRIPTS, CHUNKS_PKL
14
 
15
  setup_logging()
16
-
17
  logger = logging.getLogger(__name__)
18
 
19
- def stage_download() -> None:
20
- for channel_url in CHANNEL_URLS:
21
- try:
22
- download_channel_subtitles(channel_url, VTT_DIR, language="en")
23
- except Exception:
24
- logger.exception("Failed to download subtitles for %s", channel_url)
25
-
26
-
27
- def stage_persist(file_paths, transcripts) -> None:
28
- with open(FILE_PATHS, "wb") as f:
29
- pickle.dump(file_paths, f)
30
- with open(TRANSCRIPTS, "wb") as f:
31
- pickle.dump(transcripts, f)
32
-
33
-
34
- def stage_retrieve(query: str, k: int = 20) -> list[str]:
35
- results = retrieve_transcripts(query, retrieve_k=k)
36
- if not results:
37
- logger.warning("No relevant transcripts found")
38
- return results
39
-
40
  def load_text_corpus(txt_dir: Path) -> tuple[list[Path], list[str]]:
41
- """Load a corpus of text files from a directory."""
42
  transcripts = []
43
  file_paths = []
44
-
45
  for file_path in sorted(txt_dir.glob("*.txt")):
46
  text = file_path.read_text(encoding="utf-8")
47
  transcripts.append(text)
48
- file_paths.append(file_path)
49
-
50
  logger.info("Collected %d transcripts", len(file_paths))
51
  return file_paths, transcripts
52
 
53
  def write_retrieved_transcripts(retrieved_transcripts: list[str], file_paths: list[Path]) -> None:
54
- try:
55
- with RETRIEVED_TRANSCRIPTS_FILE.open("w", encoding="utf-8") as f:
56
- for i, (path, transcript) in enumerate(zip(file_paths, retrieved_transcripts), start=1):
57
- video_id = path.stem.split(".")[0]
58
- f.write(f"Video id: {video_id}\nTranscript {i}:\n{transcript}\n")
59
- except Exception:
60
- logger.exception("Failed to write retrieved transcripts")
61
-
62
-
63
- def write_response(response):
64
- try:
65
- RESPONSE_FILE.write_text(response, encoding="utf-8")
66
- logger.info("Response written to %s", RESPONSE_FILE)
67
-
68
- except Exception:
69
- logger.exception("Failed to write response")
70
-
71
 
72
  def main() -> None:
 
73
  query = input("Enter query:\n").strip()
74
  if not query:
75
  logger.error("Query cannot be empty")
76
  return
77
 
78
- stage_download()
79
- vtt_to_txt(VTT_DIR, TXT_DIR)
 
 
 
 
 
 
 
 
80
  file_paths, transcripts = load_text_corpus(TXT_DIR)
81
- stage_persist(file_paths, transcripts)
 
 
 
82
 
83
- with open(FILE_PATHS, "rb") as f:
84
  file_paths = pickle.load(f)
85
- with open(TRANSCRIPTS, "rb") as f:
86
- transcripts = pickle.load(f)
87
- file_paths = [Path(p) for p in file_paths]
88
 
89
- embedding(transcripts, TRANSCRIPT_INDEX, CHUNKS_PKL)
 
 
90
 
91
- retrieved = stage_retrieve(query)
 
 
92
  if not retrieved:
93
  return
94
-
95
  write_retrieved_transcripts(retrieved, file_paths)
 
 
 
96
  full_context = " ".join(retrieved)
97
- limit_context = trim_to_token_limit(full_context, MAX_CONTEXT_TOKENS)
98
  context_str = " ".join(limit_context.split("\n"))
99
-
100
- response = generate_response(query, limit_context)
101
- write_response("\n".join([f"Received query: {query}", f"Context: {context_str}", f"Response: {response}"]))
102
-
103
- logger.info("Full_context: %d tokens, %d words", count_tokens(full_context), len(full_context.split(" ")), )
104
- logger.info("Limit_context: %d tokens, %d words", count_tokens(limit_context), len(limit_context.split(" ")))
105
-
 
 
106
 
107
  if __name__ == "__main__":
108
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s [%(name)s]: %(message)s", )
 
2
  import pickle
3
  from pathlib import Path
4
 
5
+ from logger import setup_logging
6
+ from download_vtt import Download
7
+ from vtt_to_txt import Clean
8
+ from embed_transcripts import Embed
9
+ from retrieve_context import Context
10
+ from tokenizer import Tokenizer
11
+ from generate_response import Response
12
+ from config import CHANNEL_IDS, MAX_CONTEXT_TOKENS, VTT_DIR, TXT_DIR, CHUNK_FAISS, RETRIEVED_TRANSCRIPTS_FILE, RESPONSE_FILE, \
13
+ FILE_PKL, TRANSCRIPTS_PKL, CHUNK_PKL, MODEL, GROQ_API_KEY, SYSTEM_PROMPT, ENCODER
14
 
15
  setup_logging()
 
16
  logger = logging.getLogger(__name__)
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
def load_text_corpus(txt_dir: Path) -> tuple[list[Path], list[str]]:
    """Load a corpus of text files from a directory.

    Files are read in sorted order so the result is deterministic.

    Args:
        txt_dir: Directory containing the ``.txt`` transcript files.

    Returns:
        A ``(file_paths, transcripts)`` pair of parallel lists: the path of
        each transcript file and its UTF-8 text content, index-aligned.
    """
    file_paths: list[Path] = []
    transcripts: list[str] = []
    for file_path in sorted(txt_dir.glob("*.txt")):
        # Path.glob already yields Path objects, so no Path(...) re-wrap is needed.
        file_paths.append(file_path)
        transcripts.append(file_path.read_text(encoding="utf-8"))
    logging.getLogger(__name__).info("Collected %d transcripts", len(file_paths))
    return file_paths, transcripts
27
 
28
def write_retrieved_transcripts(retrieved_transcripts: list[str], file_paths: list[Path]) -> None:
    """Write the retrieved transcripts to RETRIEVED_TRANSCRIPTS_FILE.

    Each entry gets a "Video id" header (the file stem up to the first dot)
    followed by a 1-based numbered transcript body.
    NOTE(review): pairs transcripts with file_paths positionally via zip —
    assumes both lists are index-aligned; confirm against the caller.
    """
    with RETRIEVED_TRANSCRIPTS_FILE.open("w", encoding="utf-8") as out:
        entry_no = 0
        for path, transcript in zip(file_paths, retrieved_transcripts):
            entry_no += 1
            video_id, _, _ = path.stem.partition(".")
            out.write(f"Video id: {video_id}\nTranscript {entry_no}:\n{transcript}\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
def main() -> None:
    """Run the full pipeline: query, download, clean, embed, retrieve, respond."""
    # Query
    query = input("Enter query:\n").strip()
    if not query:
        logger.error("Query cannot be empty")
        return

    # Download English subtitles for every configured channel.
    for channel_id in CHANNEL_IDS:
        downloader = Download(channel_id, VTT_DIR, language="en")
        downloader.download_channel_subtitles()

    # Clean: convert the downloaded .vtt subtitle files to plain text.
    cleaner = Clean(VTT_DIR, TXT_DIR)
    cleaner.vtt_to_txt()

    # Load the corpus and persist it for downstream modules.
    file_paths, transcripts = load_text_corpus(TXT_DIR)
    with open(FILE_PKL, "wb") as f:
        pickle.dump(file_paths, f)
    with open(TRANSCRIPTS_PKL, "wb") as f:
        pickle.dump(transcripts, f)
    # Fix: removed the redundant reload of FILE_PKL immediately after dumping it —
    # the in-memory file_paths is identical to what that round-trip produced.

    # Embed the transcripts into the chunk index.
    embedder = Embed(TRANSCRIPTS_PKL, CHUNK_FAISS, CHUNK_PKL)
    embedder.embedding()

    # Retrieve the chunks most relevant to the query.
    retriever = Context(CHUNK_FAISS, CHUNK_PKL)
    retrieved = retriever.retrieve_chunks(query, top_k=20, retrieve_k=25)
    if not retrieved:
        return
    write_retrieved_transcripts(retrieved, file_paths)

    # Trim the concatenated context down to the model's token budget.
    tokenizer = Tokenizer(MODEL, ENCODER)
    full_context = " ".join(retrieved)
    limit_context = tokenizer.trim_to_token_limit(full_context, MAX_CONTEXT_TOKENS)
    context_str = " ".join(limit_context.split("\n"))
    logger.info("Full_context: %d tokens, %d words", tokenizer.count_tokens(full_context), len(full_context.split(" ")))
    logger.info("Limit_context: %d tokens, %d words", tokenizer.count_tokens(limit_context), len(limit_context.split(" ")))

    # Generate the response and write query + context + response to disk.
    responder = Response(GROQ_API_KEY, MODEL, ENCODER, SYSTEM_PROMPT)
    response = responder.generate_response(query, limit_context)
    # Fix: use the module logger, not the root logger (was logging.info).
    logger.info("Total number of tokens in prompt: %s", tokenizer.count_tokens(query + SYSTEM_PROMPT + limit_context))
    write = "\n".join([f"Received query: {query}", f"Context: {context_str}", f"Response: {response}"])
    RESPONSE_FILE.write_text(write, encoding="utf-8")
85
 
86
  if __name__ == "__main__":
87
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s [%(name)s]: %(message)s", )