| import pandas as pd |
| import nltk |
# Fetch NLTK's Punkt sentence-tokenizer models, which sent_tokenize (imported
# below) relies on.
# NOTE(review): no sent_tokenize call is visible in this file, and newer NLTK
# releases may look for the "punkt_tab" resource instead — confirm this
# download is still needed.
nltk.download(info_or_id="punkt")
| from nltk.tokenize import sent_tokenize |
| import chromadb |
| from chromadb.utils import embedding_functions |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
|
|
| import gradio as gr |
|
|
| import re |
|
|
| |
|
|
| |
# --- Data loading, vector index and summarization model ---------------------

# How many rows of the CSV are indexed into the vector store.
MAX_INDEXED_EMAILS = 10000
# Chroma caps how many records a single add/upsert call may carry (the limit
# has been as low as ~5.4k records in some releases), so index in chunks.
INDEX_BATCH_SIZE = 5000
# Hugging Face Hub checkpoint used for summarization.
SUMMARIZER_CHECKPOINT = "varl42/modello42"


def _index_emails(store, frame, batch_size=INDEX_BATCH_SIZE):
    """Upsert every row of *frame* into *store* in chunks of *batch_size*.

    The ``body`` column becomes the document text, ``file`` the document id,
    and each record is tagged with ``{"source": "enron_emails"}``.
    """
    documents = frame["body"].tolist()
    ids = frame["file"].tolist()
    for start in range(0, len(ids), batch_size):
        batch_ids = ids[start:start + batch_size]
        store.upsert(
            documents=documents[start:start + batch_size],
            ids=batch_ids,
            metadatas=[{"source": "enron_emails"}] * len(batch_ids),
        )


emails = pd.read_csv("./cleaned_data.csv")

# One persistent client serves both indexing and querying. The previous code
# created a PersistentClient, immediately replaced it with an in-memory
# chromadb.Client(), and later re-fetched the collection through a third client.
client = chromadb.PersistentClient(path="./content")
# get_or_create_* lets the app restart against an existing ./content store
# instead of failing because "enron_emails" already exists; upsert (inside
# _index_emails) refreshes ids that are already stored.
collection = client.get_or_create_collection("enron_emails")
# NOTE(review): the first MAX_INDEXED_EMAILS rows are re-embedded on every
# launch (as before); consider skipping when the store is already populated.
_index_emails(collection, emails.head(MAX_INDEXED_EMAILS))

model = AutoModelForSeq2SeqLM.from_pretrained(SUMMARIZER_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(SUMMARIZER_CHECKPOINT)
|
|
| |
|
|
def query_collection(query_text, n_results=2, store=None):
    """Return the text of the stored e-mails most similar to *query_text*.

    Args:
        query_text: Free-text query passed to the collection's ``query``.
        n_results: Number of nearest documents to request (default 2, the
            previously hard-coded value).
        store: Collection-like object exposing
            ``query(query_texts=..., n_results=...)``. Defaults to the
            module-level ``collection``.

    Returns:
        The matching documents joined by blank lines, a "no documents" notice
        when nothing comes back, or an error message. Errors are returned
        rather than raised because the result is displayed directly in the UI.
    """
    try:
        target = collection if store is None else store
        response = target.query(query_texts=[query_text], n_results=n_results)
        # ``documents`` holds one inner list per query text. It may be missing
        # or None, and the inner list is empty when nothing matched (e.g. an
        # empty collection) — that case used to return "" instead of a notice.
        per_query = response.get("documents") or []
        documents = per_query[0] if per_query else None
        if not documents:
            return "No documents found or the response structure is not as expected."
        return "\n\n".join(documents)
    except Exception as e:
        return f"An error occurred while querying: {e}"
|
|
|
|
def summarize_documents(text_input):
    """Summarize *text_input* with the module-level seq2seq ``model``.

    The input is truncated to 512 tokens; beam search with 4 beams generates
    a summary of between 125 and 512 tokens.

    Args:
        text_input: Text to summarize (in the UI, the joined query results).

    Returns:
        The decoded summary, ending in ".", "?" or "!"; a notice when the
        input is empty; or an error message. Errors are returned rather than
        raised because the result is displayed directly in the UI.
    """
    # With min_length=125 the model would otherwise invent a summary of nothing.
    if text_input is None or (isinstance(text_input, str) and not text_input.strip()):
        return "Nothing to summarize."
    try:
        inputs = tokenizer(text_input, return_tensors="pt", truncation=True, max_length=512)
        summary_ids = model.generate(
            input_ids=inputs["input_ids"],
            # Pass the mask explicitly so generation does not have to infer
            # which positions are padding.
            attention_mask=inputs.get("attention_mask"),
            max_length=512,
            min_length=125,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
        )
        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True).strip()
    except Exception as e:
        return f"An error occurred while summarizing: {e}"
    # The former regex pass appended "." after every "?"/"!" (yielding "?.")
    # and inserted "." before every capitalised word ("Enron and I" became
    # "Enron and. I"). Only a missing final punctuation mark is added now.
    if summary and summary[-1] not in ".?!":
        summary += "."
    return summary
|
|
def query_then_summarize(query_text, _):
    """Gradio handler for the "Query" button.

    Despite the name, this step only runs the similarity search: it returns
    the retrieved text for the results box and an empty string that clears
    the summary box. The second argument is ignored.
    """
    cleared_summary = ""
    try:
        results_text = query_collection(query_text)
    except Exception as exc:
        return f"An error occurred: {exc}", cleared_summary
    return results_text, cleared_summary
|
|
def summarize_from_query(_, query_results):
    """Gradio handler for the "Summarize" button.

    Summarizes the current results text and echoes that text back unchanged
    as the first output, so the results box keeps its contents. The first
    argument is ignored.
    """
    try:
        summary_text = summarize_documents(query_results)
    except Exception as exc:
        summary_text = f"An error occurred while summarizing: {exc}"
    return query_results, summary_text
|
|
|
|
| |
| |
| |
# --- Gradio UI --------------------------------------------------------------
with gr.Blocks() as app:
    # NOTE(review): indentation was lost in this copy of the file; grouping the
    # query box and its button in one Row is a reconstruction — confirm layout.
    with gr.Row():
        query_input = gr.Textbox(label="Enter your query")
        query_button = gr.Button("Query")
    query_results = gr.Text(
        label="Query Results",
        placeholder="Query results will appear here...",
        interactive=True,
    )
    summarize_button = gr.Button("Summarize")
    summary_output = gr.Textbox(label="Summary", placeholder="Summary will appear here...")

    # "Query" fills the results box and clears any previous summary.
    query_button.click(
        query_then_summarize,
        inputs=[query_input, query_results],
        outputs=[query_results, summary_output],
    )
    # "Summarize" summarizes the (user-editable) results box. The handler
    # ignores its first input; it used to receive the Query *button* itself
    # (whose value is just its label), so the query box is wired instead,
    # matching the Query handler.
    summarize_button.click(
        summarize_from_query,
        inputs=[query_input, query_results],
        outputs=[query_results, summary_output],
    )


# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    app.launch()
|
|
|
|