import os
import re
import streamlit as st
import requests
import json
import uuid
from urllib.parse import quote_plus
from dotenv import load_dotenv
from tavily import TavilyClient
import feedparser
from fuzzywuzzy import fuzz
from fpdf import FPDF
from duckduckgo_search import DDGS
from io import BytesIO
# --- Load API Keys ---
load_dotenv()
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
tavily = TavilyClient(api_key=TAVILY_API_KEY)
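# Expected .env file alongside app.py (placeholder values shown, not real keys;
# only the variable names above are actually assumed by the code):
#   OPENROUTER_API_KEY=sk-or-v1-...
#   TAVILY_API_KEY=tvly-...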
# --- Utility Functions ---
def save_session_data():
    data = {
        "memory_bank": st.session_state.get("memory_bank", []),
        "chat_threads": st.session_state.get("chat_threads", {})
    }
    with open("session_memory.json", "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=4)
def load_session_data():
    if os.path.exists("session_memory.json"):
        with open("session_memory.json", "r", encoding="utf-8") as f:
            data = json.load(f)
        st.session_state.memory_bank = data.get("memory_bank", [])
        st.session_state.chat_threads = data.get("chat_threads", {})
def call_llm(messages, model="deepseek/deepseek-chat-v3-0324:free", max_tokens=4000, temperature=0.7):
    url = "https://openrouter.ai/api/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json",
        "X-Title": "Deep Research Assistant"
    }
    data = {
        "model": model,
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "stream": True
    }
    with requests.post(url, headers=headers, json=data, stream=True) as response:
        response.raise_for_status()  # surface HTTP errors instead of silently yielding nothing
        for line in response.iter_lines():
            if not line:
                continue
            decoded = line.decode("utf-8")
            if decoded.startswith("data: "):
                # Strip only the SSE prefix; str.replace could mangle payloads
                # that themselves contain "data: "
                piece = decoded[len("data: "):].strip()
                if piece == "[DONE]":
                    continue
                try:
                    parsed = json.loads(piece)
                    delta = parsed.get("choices", [{}])[0].get("delta", {})
                    token = delta.get("content", "")
                    if token:
                        yield token
                except json.JSONDecodeError:
                    pass
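# Usage sketch: call_llm is a generator yielding streamed tokens, so callers
# accumulate chunks as they arrive (illustrative prompt, not from the app):
#   report = "".join(call_llm([{"role": "user", "content": "Summarize topic X"}]))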
def get_image_urls(query, max_images=5):
    with DDGS() as ddgs:
        return [img["image"] for img in ddgs.images(query, max_results=max_images)]
def get_sources(topic, domains=None):
    query = topic
    if domains:
        domain_filters = [d.strip() for d in domains.split(",") if d.strip()]
        query += " site:" + " OR site:".join(domain_filters)
    response = tavily.search(query=query, search_depth="advanced", max_results=10)
    results = []
    for r in response.get("results", []):
        image_url = r.get("image_url")
        if not image_url:
            try:
                images = get_image_urls(r["title"], max_images=1)
                image_url = images[0] if images else None
            except Exception:
                image_url = None
        results.append({
            "title": r["title"],
            "url": r["url"],
            "snippet": r.get("content", ""),
            "image_url": image_url,
            "source": "web",
            "year": extract_year_from_text(r.get("content", ""))
        })
    return results
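# Example of the Tavily query built above when domains="forbes.com, mit.edu":
#   "<topic> site:forbes.com OR site:mit.edu"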
def get_arxiv_papers(query):
    url = f"http://export.arxiv.org/api/query?search_query=all:{quote_plus(query)}&start=0&max_results=5"
    feed = feedparser.parse(url)
    return [{
        "title": e.title,
        "summary": e.summary.replace("\n", " ").strip(),
        "url": next((l.href for l in e.links if l.type == "application/pdf"), ""),
        "source": "arxiv",
        "year": int(e.published[:4]) if 'published' in e else 9999
    } for e in feed.entries]
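# Example request URL this builds for query="graph neural networks":
#   http://export.arxiv.org/api/query?search_query=all:graph+neural+networks&start=0&max_results=5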
def get_semantic_papers(query):
    try:
        url = "https://api.semanticscholar.org/graph/v1/paper/search"
        params = {"query": query, "limit": 5, "fields": "title,abstract,url,year"}
        response = requests.get(url, params=params)
        papers = response.json().get("data", [])
        return [{
            "title": p.get("title"),
            # The API can return null abstracts/years; `or` covers that case,
            # which dict.get's default alone would not
            "summary": p.get("abstract") or "No abstract available",
            "url": p.get("url"),
            "source": "semantic",
            "year": p.get("year") or 9999
        } for p in papers]
    except Exception:
        return []
def extract_year_from_text(text):
    # Non-capturing group is required: with a capturing group, re.findall would
    # return only "19"/"20" instead of the full four-digit year
    years = re.findall(r"\b(?:19|20)\d{2}\b", text)
    return int(years[0]) if years else 9999
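# Example: extract_year_from_text("First published in 2019, revised 2021") -> 2019.
# 9999 acts as a "year unknown" sentinel so undated sources sort last.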
def merge_duplicates(entries):
    unique = []
    seen_titles = []
    for entry in entries:
        if all(fuzz.token_set_ratio(entry['title'], seen) < 90 for seen in seen_titles):
            unique.append(entry)
            seen_titles.append(entry['title'])
    return unique
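# Example: fuzz.token_set_ratio("Attention Is All You Need",
# "Attention is all you need (2017)") scores 100, since token-set matching
# ignores case and extra tokens; 90 is the threshold above which two titles
# are treated as the same work.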
def sort_sources_chronologically(sources):
    return sorted(sources, key=lambda s: s.get("year", 9999))
def build_chronological_progression(sources):
    timeline = {}
    for s in sources:
        year = s.get("year", 9999)
        if year != 9999:
            timeline.setdefault(year, []).append(f"- {s['title']}")
    summary = ""
    for year in sorted(timeline.keys()):
        entries = "\n".join(timeline[year])
        summary += f"**{year}**\n{entries}\n\n"
    return summary.strip()
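# Example output (Markdown, fed into the report prompt; titles are placeholders):
#   **2017**
#   - Title of source A
#
#   **2020**
#   - Title of source B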
def download_threads_as_pdf(chat_threads):
    pdf = FPDF()
    pdf.add_page()
    pdf.set_auto_page_break(auto=True, margin=15)
    pdf.set_font("Arial", size=12)
    for tid, chats in chat_threads.items():
        pdf.cell(0, 10, f"Thread {tid[:8]}", ln=True)
        for msg in chats:
            role = "You" if msg["role"] == "user" else "Assistant"
            text = f"{role}: {msg['content']}"
            # FPDF's core fonts are Latin-1 only; replace unencodable characters
            text = text.encode('latin-1', 'replace').decode('latin-1')
            pdf.multi_cell(0, 10, text)
        pdf.ln(5)
    pdf_output = BytesIO()
    # PyFPDF (fpdf 1.x) returns a latin-1 str from output(dest='S')
    pdf_bytes = pdf.output(dest='S').encode('latin-1')
    pdf_output.write(pdf_bytes)
    pdf_output.seek(0)
    return pdf_output
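# Note: with Latin-1 core fonts, emoji and non-Western characters in chat content
# come out as "?". Registering a Unicode TTF via pdf.add_font(..., uni=True)
# (PyFPDF) would be one way around this; it is not wired in here.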
# --- Streamlit UI Setup ---
st.set_page_config(page_title="🧠 Deep Research Assistant 4.0", layout="centered")

if "memory_bank" not in st.session_state:
    st.session_state.memory_bank = []
if "chat_threads" not in st.session_state:
    st.session_state.chat_threads = {}
if "current_thread_id" not in st.session_state:
    st.session_state.current_thread_id = None

load_session_data()
# --- Sidebar ---
with st.sidebar:
    st.image("https://raw.githubusercontent.com/mk-gurucharan/streamlit-deep-research/main/deep_research_logo.png", use_container_width=True)
    st.markdown("## 🔍 Start New Research")
    topic = st.text_input("🧠 Topic")
    report_type = st.selectbox("📄 Report Type", ["Summary", "Detailed Report", "Thorough Academic Research"])
    tone = st.selectbox("🎯 Tone", ["Objective", "Persuasive", "Narrative"])
    source_type = st.selectbox("📚 Sources", ["Web Only", "Academic Only", "Hybrid"])
    custom_domains = st.text_input("🌐 Optional Domains", placeholder="forbes.com, mit.edu")
    research_button = st.button("🚀 Run Deep Research", use_container_width=True)
    st.markdown("---")
    st.markdown("Made with ❤️ by Cutie AI ✨")
# --- Main Title ---
st.title("🌙 Deep Research Assistant 4.0")
st.markdown("Where serious research meets serious style. 🧠💖")
st.divider()
# --- Show Web Images from Topic ---
if topic and research_button:
    st.subheader("🖼 Related Images from the Web")
    try:
        topic_images = get_image_urls(topic, max_images=6)
        if topic_images:
            img_cols = st.columns(3)
            for idx, img_url in enumerate(topic_images):
                with img_cols[idx % 3]:
                    st.image(img_url, use_container_width=True)
        else:
            st.info("No images found for this topic.")
    except Exception as e:
        st.warning(f"Couldn't load topic images. ({e})")
# --- Main Research Section ---
if research_button and topic:
    try:
        with st.status("🔎 Gathering sources..."):
            all_sources = []
            if source_type in ["Web Only", "Hybrid"]:
                all_sources += get_sources(topic, custom_domains) if custom_domains.strip() else get_sources(topic)
            if source_type in ["Academic Only", "Hybrid"]:
                all_sources += get_arxiv_papers(topic)
                all_sources += get_semantic_papers(topic)
            if not all_sources:
                raise ValueError("❌ No sources found.")
            merged = merge_duplicates(all_sources)
            merged = sort_sources_chronologically(merged)
            chronological_progress = build_chronological_progression(merged)

        previous_learnings = "\n\n".join(st.session_state.memory_bank[-5:])
        citations = [f"- {s['title']} ({s['year']}) [{s['source']}]({s['url']})" for s in merged]
        sources_text = "\n".join([
            f"- [{s['title']}]({s['url']}) ({s['year']})\n> {s.get('snippet', s.get('summary', ''))[:300]}..."
            for s in merged
        ])
        length_instruction = {
            "Summary": "Keep it concise, under 300 words.",
            "Detailed Report": "Write 500-700 words with critical insights.",
            "Thorough Academic Research": "Craft a full academic paper >1000 words."
        }[report_type]

        # Create New Thread
        thread_id = str(uuid.uuid4())
        st.session_state.current_thread_id = thread_id
        st.session_state.chat_threads[thread_id] = []

        prompt = f"""
Use past learnings:
{previous_learnings}

New Topic:
{topic}

Writing:
{tone} tone, {length_instruction}

Timeline:
{chronological_progress}

Sources:
{sources_text}

Citations:
{chr(10).join(citations)}
"""

        # --- Generate Report ---
        st.subheader(f"📝 {report_type} on '{topic}'")
        output_placeholder = st.empty()
        final_output = ""
        for chunk in call_llm([{"role": "user", "content": prompt}]):
            final_output += chunk
            output_placeholder.markdown(final_output, unsafe_allow_html=True)

        st.session_state.memory_bank.append(final_output)
        st.session_state.chat_threads[thread_id].append({"role": "assistant", "content": final_output})
        # Keep the methodology recommender and follow-up Q&A in sync with the
        # latest report; last_report was otherwise never populated
        st.session_state["last_report"] = final_output
        save_session_data()
    except Exception as e:
        st.error(f"❌ Error: {e}")
# --- Chat Threads Section ---
st.divider()
st.subheader("📂 Your Research Threads")
user_avatar = "https://cdn-icons-png.flaticon.com/512/9131/9131529.png"
assistant_avatar = "https://cdn-icons-png.flaticon.com/512/4712/4712107.png"

for tid, chats in st.session_state.chat_threads.items():
    with st.container():
        with st.expander(f"🧵 Thread {tid[:8]}", expanded=False):
            for msg in chats:
                avatar = user_avatar if msg['role'] == 'user' else assistant_avatar
                bubble_color = "#DCF8C6" if msg['role'] == 'user' else "#F0F0F0"
                align = "flex-end" if msg['role'] == 'user' else "flex-start"
                st.markdown(f"""
<div style="display: flex; justify-content: {align}; margin-bottom: 10px;">
    <img src="{avatar}" width="30" style="margin-right: 10px; border-radius: 50%;">
    <div style="background-color: {bubble_color}; padding: 10px 15px; border-radius: 10px; max-width: 70%;">
        {msg['content']}
    </div>
</div>
""", unsafe_allow_html=True)
            followup = st.text_input(f"💬 Continue Thread {tid[:8]}:", key=f"followup_{tid}")
            if st.button(f"Ask Follow-up {tid}", key=f"button_{tid}"):
                if followup:
                    with st.spinner("🤖 Assistant is typing..."):
                        response = ""
                        for chunk in call_llm(st.session_state.chat_threads[tid] + [{"role": "user", "content": followup}], max_tokens=2000):
                            response += chunk
                        st.markdown(response)
                        st.session_state.chat_threads[tid].append({"role": "user", "content": followup})
                        st.session_state.chat_threads[tid].append({"role": "assistant", "content": response})
                        save_session_data()
                        st.rerun()
# --- Download All Threads Section ---
if st.session_state.chat_threads:
    st.divider()
    st.subheader("📥 Export Your Work")
    pdf_file = download_threads_as_pdf(st.session_state.chat_threads)
    st.download_button("📥 Download All Threads as PDF", data=pdf_file, file_name="Research_Threads.pdf", mime="application/pdf", use_container_width=True)
# 🧠 Initialize session state
if "last_report" not in st.session_state:
st.session_state["last_report"] = ""
if "follow_up_input" not in st.session_state:
st.session_state["follow_up_input"] = ""
if "methodology_notes" not in st.session_state:
st.session_state["methodology_notes"] = ""
if "chat_history" not in st.session_state:
st.session_state["chat_history"] = []
# --- Methodology Recommender ---
st.divider()
st.subheader("🧪 Methodology Recommender")
if st.button("🧠 Suggest Research Methodologies"):
    if st.session_state["last_report"]:
        try:
            method_prompt = [
                {"role": "system", "content": "You are a research advisor. Based on the following research report, suggest suitable research methodologies (quantitative, qualitative, ML/AI techniques, etc.). Be concise and give bullet-point suggestions."},
                {"role": "user", "content": st.session_state["last_report"]}
            ]
            method_output = ""
            method_box = st.empty()
            for chunk in call_llm(method_prompt):
                method_output += chunk
                method_box.markdown(method_output, unsafe_allow_html=True)
            # ✅ Store methodology context for follow-up
            st.session_state["methodology_notes"] = method_output
        except Exception as e:
            st.error(f"❌ Methodology suggestion failed: {e}")
    else:
        st.warning("⚠️ Generate the research report first.")
# --- Follow-up Q&A (Contextual to Report + Methodology) ---
st.divider()
st.subheader("💬 Follow-up Q&A")
followup = st.text_input("Ask a follow-up question:", key="follow_up_input")
if st.button("Ask"):
    if followup:
        try:
            context_intro = (
                "Below is a research report followed by methodology suggestions.\n"
                "Use both to answer the user's follow-up question."
            )
            combined_context = f"{context_intro}\n\n=== Report ===\n{st.session_state['last_report']}\n\n=== Methodology ===\n{st.session_state['methodology_notes']}"
            chat = st.session_state.chat_history + [
                {"role": "system", "content": "You are an academic research assistant."},
                {"role": "user", "content": combined_context},
                {"role": "user", "content": followup}
            ]
            response = ""
            for chunk in call_llm(chat, max_tokens=1500):
                response += chunk
            st.session_state.chat_history.append({"role": "user", "content": followup})
            st.session_state.chat_history.append({"role": "assistant", "content": response})
            st.markdown(response)
        except Exception as e:
            st.error(f"Follow-up error: {e}")
# --- Paper Upload for Review & Improvement ---
st.divider()
st.subheader("📤 Upload Your Paper for Feedback")
uploaded_file = st.file_uploader("Upload your research paper (.pdf or .txt)", type=["pdf", "txt"])

if uploaded_file and st.button("🧠 Analyze and Suggest Improvements"):
    try:
        def extract_text_from_file(file):
            if file.name.endswith(".pdf"):
                from PyPDF2 import PdfReader  # imported lazily; only needed for PDF uploads
                reader = PdfReader(file)
                return "\n".join([page.extract_text() for page in reader.pages if page.extract_text()])
            elif file.name.endswith(".txt"):
                return file.read().decode("utf-8")
            return ""

        paper_text = extract_text_from_file(uploaded_file)
        if not paper_text or len(paper_text.strip()) < 100:
            st.warning("⚠️ The uploaded paper seems empty or too short to analyze.")
        else:
            feedback_prompt = [
                {"role": "system", "content": "You are an expert academic advisor."},
                {"role": "user", "content": f"""I have written the following research paper. Please analyze it and provide detailed suggestions on:
- Areas where the paper is weak or unclear
- How to improve the novelty or originality
- Structural improvements or better ways to present arguments
Be honest and constructive. Here's the full text:
\"\"\"{paper_text}\"\"\""""}
            ]
            with st.status("🔎 Analyzing your paper..."):
                improvement_output = ""
                feedback_box = st.empty()
                for chunk in call_llm(feedback_prompt, max_tokens=2500):
                    improvement_output += chunk
                    feedback_box.markdown(improvement_output, unsafe_allow_html=True)
    except Exception as e:
        st.error(f"❌ Error while analyzing paper: {e}")
# --- Optional: View Chat History ---
with st.expander("📜 View Full Chat History", expanded=False):
    for msg in st.session_state.chat_history:
        role = msg["role"]
        prefix = "👤 You" if role == "user" else "🤖 Assistant"
        st.markdown(f"**{prefix}:** {msg['content']}")