odyssey-rag / RAG_core /ner_and_summarizer.py
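"""Enrich Odyssey episode records with main characters (via spaCy NER) and a
short LLM-generated summary, then write the enriched list to config.OUTPUT_JSON."""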
import json
import torch
from transformers import pipeline
import spacy
import config
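
# The project-local config module is expected to define SPACY_MODEL,
# SUMMARIZER_MODEL_NAME, and OUTPUT_JSON (all referenced below).
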
def add_characters_and_summary(episodes):
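    """Annotate each episode dict in-place with 'main_characters', 'summary',
    and 'metadata_string', then persist the enriched list to config.OUTPUT_JSON."""
    # Load the spaCy model, downloading it on first run if it is missing.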
    try:
        nlp = spacy.load(config.SPACY_MODEL)
    except OSError:
        print(f"{config.SPACY_MODEL} not found. Downloading...")
        spacy.cli.download(config.SPACY_MODEL)
        nlp = spacy.load(config.SPACY_MODEL)
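
    # Pass 1: collect candidate character names from named entities and
    # title-cased proper nouns.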
    for episode in episodes:
        doc = nlp(episode['episode_text'])
        ner_names = [ent.text.strip() for ent in doc.ents if ent.label_ in {"PERSON", "ORG", "NORP"} and len(ent.text.strip()) > 1]
        propn_names = [token.text.strip() for token in doc if token.pos_ == "PROPN" and token.is_title and len(token.text.strip()) > 1]
        episode['main_characters'] = list(set(ner_names + propn_names))
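
    # Build the chat-based summarizer pipeline once and reuse it for every episode.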
    device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = pipeline(
        "text-generation",
        model=config.SUMMARIZER_MODEL_NAME,
        device=device,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    )  # max_new_tokens is passed per call in summarize_text below

    def summarize_text(text):
        messages = [
            {"role": "system", "content": "You are a helpful assistant that summarizes ancient Greek epic poetry accurately. Do not begin with 'In the Odyssey.' Output only the summary, in at most 50 tokens."},
            {"role": "user", "content": f"Summarize the following text from Homer's Odyssey in 50 tokens:\n\n{text}"}
        ]
        # Render the messages with the model's own chat template.
        prompt = generator.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        try:
            outputs = generator(
                prompt,
                max_new_tokens=55,
                do_sample=False,
                return_full_text=False,
                pad_token_id=generator.tokenizer.eos_token_id,
            )
            # Strip any role labels the model may echo into its output.
            summary = outputs[0]['generated_text'].strip().replace("User:", "").replace("Assistant:", "").strip()
            return summary
        except Exception as e:
            print(f"Summarization error: {e}")
            return "Summary generation failed."
    for episode in episodes:
        episode['summary'] = summarize_text(episode['episode_text'])

    # Format metadata string
    def format_episode_metadata_string(episode: dict) -> str:
        main_characters_str = ", ".join(episode['main_characters'])
        return f"This scene occurs in Book {episode['book_id']} and involves {main_characters_str}. {episode['book_topic']}"

    for episode in episodes:
        episode['metadata_string'] = format_episode_metadata_string(episode)
        # Clean up fields that are no longer needed once the string is built.
        episode.pop('book_id_roman', None)
        episode.pop('book_topic', None)

    with open(config.OUTPUT_JSON, 'w', encoding='utf-8') as f:
        json.dump(episodes, f, indent=2, ensure_ascii=False)
    return episodes
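

# Minimal usage sketch. The episode dict below is hypothetical sample data;
# the real records come from an upstream parsing step and are assumed to carry
# the 'episode_text', 'book_id', and 'book_topic' keys consumed above.
if __name__ == "__main__":
    sample_episodes = [
        {
            "episode_text": "Tell me, O Muse, of that ingenious hero who travelled far and wide...",
            "book_id": 1,
            "book_topic": "Athena visits Telemachus while the suitors occupy the palace.",
        }
    ]
    enriched = add_characters_and_summary(sample_episodes)
    print(enriched[0]["metadata_string"])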