Anshul Prasad committed on
Commit ·
326bd5f
1
Parent(s): cc7a7b6
name change.
Browse files
Structural change for compatibility with OOP.
Easy logic is collapsed inside main().
- main.py → src/__init__.py +49 -70
main.py → src/__init__.py
RENAMED
|
@@ -2,107 +2,86 @@ import logging
|
|
| 2 |
import pickle
|
| 3 |
from pathlib import Path
|
| 4 |
|
| 5 |
-
from
|
| 6 |
-
from
|
| 7 |
-
from
|
| 8 |
-
from
|
| 9 |
-
from
|
| 10 |
-
from
|
| 11 |
-
from
|
| 12 |
-
from config import
|
| 13 |
-
|
| 14 |
|
| 15 |
setup_logging()
|
| 16 |
-
|
| 17 |
logger = logging.getLogger(__name__)
|
| 18 |
|
| 19 |
-
def stage_download() -> None:
|
| 20 |
-
for channel_url in CHANNEL_URLS:
|
| 21 |
-
try:
|
| 22 |
-
download_channel_subtitles(channel_url, VTT_DIR, language="en")
|
| 23 |
-
except Exception:
|
| 24 |
-
logger.exception("Failed to download subtitles for %s", channel_url)
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def stage_persist(file_paths, transcripts) -> None:
|
| 28 |
-
with open(FILE_PATHS, "wb") as f:
|
| 29 |
-
pickle.dump(file_paths, f)
|
| 30 |
-
with open(TRANSCRIPTS, "wb") as f:
|
| 31 |
-
pickle.dump(transcripts, f)
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def stage_retrieve(query: str, k: int = 20) -> list[str]:
|
| 35 |
-
results = retrieve_transcripts(query, retrieve_k=k)
|
| 36 |
-
if not results:
|
| 37 |
-
logger.warning("No relevant transcripts found")
|
| 38 |
-
return results
|
| 39 |
-
|
| 40 |
def load_text_corpus(txt_dir: Path) -> tuple[list[Path], list[str]]:
|
| 41 |
-
"""Load a corpus of text files from a directory."""
|
| 42 |
transcripts = []
|
| 43 |
file_paths = []
|
| 44 |
-
|
| 45 |
for file_path in sorted(txt_dir.glob("*.txt")):
|
| 46 |
text = file_path.read_text(encoding="utf-8")
|
| 47 |
transcripts.append(text)
|
| 48 |
-
file_paths.append(file_path)
|
| 49 |
-
|
| 50 |
logger.info("Collected %d transcripts", len(file_paths))
|
| 51 |
return file_paths, transcripts
|
| 52 |
|
| 53 |
def write_retrieved_transcripts(retrieved_transcripts: list[str], file_paths: list[Path]) -> None:
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
f.write(f"Video id: {video_id}\nTranscript {i}:\n{transcript}\n")
|
| 59 |
-
except Exception:
|
| 60 |
-
logger.exception("Failed to write retrieved transcripts")
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
def write_response(response):
|
| 64 |
-
try:
|
| 65 |
-
RESPONSE_FILE.write_text(response, encoding="utf-8")
|
| 66 |
-
logger.info("Response written to %s", RESPONSE_FILE)
|
| 67 |
-
|
| 68 |
-
except Exception:
|
| 69 |
-
logger.exception("Failed to write response")
|
| 70 |
-
|
| 71 |
|
| 72 |
def main() -> None:
|
|
|
|
| 73 |
query = input("Enter query:\n").strip()
|
| 74 |
if not query:
|
| 75 |
logger.error("Query cannot be empty")
|
| 76 |
return
|
| 77 |
|
| 78 |
-
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
file_paths, transcripts = load_text_corpus(TXT_DIR)
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
-
with open(
|
| 84 |
file_paths = pickle.load(f)
|
| 85 |
-
with open(TRANSCRIPTS, "rb") as f:
|
| 86 |
-
transcripts = pickle.load(f)
|
| 87 |
-
file_paths = [Path(p) for p in file_paths]
|
| 88 |
|
| 89 |
-
|
|
|
|
|
|
|
| 90 |
|
| 91 |
-
|
|
|
|
|
|
|
| 92 |
if not retrieved:
|
| 93 |
return
|
| 94 |
-
|
| 95 |
write_retrieved_transcripts(retrieved, file_paths)
|
|
|
|
|
|
|
|
|
|
| 96 |
full_context = " ".join(retrieved)
|
| 97 |
-
limit_context = trim_to_token_limit(full_context, MAX_CONTEXT_TOKENS)
|
| 98 |
context_str = " ".join(limit_context.split("\n"))
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
|
|
|
|
|
|
| 106 |
|
| 107 |
if __name__ == "__main__":
|
| 108 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s [%(name)s]: %(message)s", )
|
|
|
|
| 2 |
import pickle
|
| 3 |
from pathlib import Path
|
| 4 |
|
| 5 |
+
from logger import setup_logging
|
| 6 |
+
from download_vtt import Download
|
| 7 |
+
from vtt_to_txt import Clean
|
| 8 |
+
from embed_transcripts import Embed
|
| 9 |
+
from retrieve_context import Context
|
| 10 |
+
from tokenizer import Tokenizer
|
| 11 |
+
from generate_response import Response
|
| 12 |
+
from config import CHANNEL_IDS, MAX_CONTEXT_TOKENS, VTT_DIR, TXT_DIR, CHUNK_FAISS, RETRIEVED_TRANSCRIPTS_FILE, RESPONSE_FILE, \
|
| 13 |
+
FILE_PKL, TRANSCRIPTS_PKL, CHUNK_PKL, MODEL, GROQ_API_KEY, SYSTEM_PROMPT, ENCODER
|
| 14 |
|
| 15 |
# Install the project-wide logging configuration (handlers/formatters
# come from the local `logger` module) before any log call is made.
setup_logging()

# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
def load_text_corpus(txt_dir: Path) -> tuple[list[Path], list[str]]:
    """Load every ``*.txt`` file in *txt_dir* as one transcript string.

    Files are read in sorted (deterministic) order so the two returned
    lists stay index-aligned: ``file_paths[i]`` is the source file of
    ``transcripts[i]``.

    Args:
        txt_dir: Directory scanned (non-recursively) for ``*.txt`` files.

    Returns:
        Tuple ``(file_paths, transcripts)``, index-aligned.
    """
    file_paths: list[Path] = []
    transcripts: list[str] = []
    for file_path in sorted(txt_dir.glob("*.txt")):
        # Path.glob() already yields Path objects — no re-wrapping needed.
        file_paths.append(file_path)
        transcripts.append(file_path.read_text(encoding="utf-8"))
    logger.info("Collected %d transcripts", len(file_paths))
    return file_paths, transcripts
|
| 27 |
|
| 28 |
def write_retrieved_transcripts(retrieved_transcripts: list[str], file_paths: list[Path]) -> None:
    """Persist retrieved transcripts to ``RETRIEVED_TRANSCRIPTS_FILE``.

    Each entry is written as a ``Video id: .../Transcript i:`` record.
    The video id is the file stem up to the first dot (e.g.
    ``<id>.en.txt`` -> ``<id>``). Pairs beyond the shorter of the two
    input lists are dropped by ``zip``.

    A write failure is logged with traceback rather than raised, so a
    reporting problem does not abort the pipeline (matches the
    pre-refactor behaviour, which this rename had silently dropped).
    """
    try:
        with RETRIEVED_TRANSCRIPTS_FILE.open("w", encoding="utf-8") as f:
            for i, (path, transcript) in enumerate(zip(file_paths, retrieved_transcripts), start=1):
                video_id = path.stem.split(".")[0]
                f.write(f"Video id: {video_id}\nTranscript {i}:\n{transcript}\n")
    except OSError:
        logger.exception("Failed to write retrieved transcripts")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
def main() -> None:
    """Run the full pipeline: query -> download -> clean -> embed -> retrieve -> respond.

    Reads the user query from stdin; writes the retrieved transcripts and
    the final response to the files configured in ``config``.
    """
    # Query
    query = input("Enter query:\n").strip()
    if not query:
        logger.error("Query cannot be empty")
        return

    # Download subtitles for every configured channel.
    for channel_id in CHANNEL_IDS:
        downloader = Download(channel_id, VTT_DIR, language="en")
        downloader.download_channel_subtitles()

    # Clean: convert raw .vtt subtitle files into plain-text transcripts.
    cleaner = Clean(VTT_DIR, TXT_DIR)
    cleaner.vtt_to_txt()

    # Collect the corpus and persist it for the embedding stage.
    # (The previous dump-then-immediately-reload of FILE_PKL was a no-op
    # round trip and has been removed; `file_paths` is used as-is.)
    file_paths, transcripts = load_text_corpus(TXT_DIR)
    with open(FILE_PKL, "wb") as f:
        pickle.dump(file_paths, f)
    with open(TRANSCRIPTS_PKL, "wb") as f:
        pickle.dump(transcripts, f)

    # Embed: build the FAISS index from the persisted transcripts.
    embedder = Embed(TRANSCRIPTS_PKL, CHUNK_FAISS, CHUNK_PKL)
    embedder.embedding()

    # Retrieve the chunks most relevant to the query.
    context = Context(CHUNK_FAISS, CHUNK_PKL)
    retrieved = context.retrieve_chunks(query, top_k=20, retrieve_k=25)
    if not retrieved:
        logger.warning("No relevant transcripts found")
        return
    write_retrieved_transcripts(retrieved, file_paths)

    # Trim the combined context to the model's token budget.
    tokenizer = Tokenizer(MODEL, ENCODER)
    full_context = " ".join(retrieved)
    limit_context = tokenizer.trim_to_token_limit(full_context, MAX_CONTEXT_TOKENS)
    context_str = " ".join(limit_context.split("\n"))
    logger.info("Full_context: %d tokens, %d words", tokenizer.count_tokens(full_context), len(full_context.split(" ")))
    logger.info("Limit_context: %d tokens, %d words", tokenizer.count_tokens(limit_context), len(limit_context.split(" ")))

    # Response: generate the answer and persist query/context/response.
    responder = Response(GROQ_API_KEY, MODEL, ENCODER, SYSTEM_PROMPT)
    response = responder.generate_response(query, limit_context)
    # Use the module logger (the original called root `logging.info` here).
    logger.info("Total number of tokens in prompt: %s", tokenizer.count_tokens(query + SYSTEM_PROMPT + limit_context))
    output = "\n".join([f"Received query: {query}", f"Context: {context_str}", f"Response: {response}"])
    RESPONSE_FILE.write_text(output, encoding="utf-8")
|
| 85 |
|
| 86 |
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s [%(name)s]: %(message)s", )
    # BUG FIX: the guard configured logging but never invoked the entry
    # point, so running the script did nothing.
    main()
|