# AI-Research-Lab / app.py
# Hugging Face Space by Mohit0708 — "Update app.py", commit 74238a2 (verified)
# ==========================================
# 1. INITIAL SETUP & LIBRARIES
# ==========================================
import os
import json
import uuid
import base64
import whisper
import pymupdf4llm
import gradio as gr
from datetime import datetime
from huggingface_hub import InferenceClient
from langchain_text_splitters import MarkdownTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma

# ==========================================
# 2. CONNECT TO AI APIS (OpenAI-Compatible)
# ==========================================
print("โณ Connecting to Hugging Face APIs...")
# Token comes from the Space's secrets; without it the InferenceClient calls below will fail.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    print("โš ๏ธ WARNING: HF_TOKEN not found! The AI will not be able to generate responses.")
# Initialize a single, generic client
# We do NOT bind the model name here to prevent the "text-generation" tagging error
# (the model is passed per-request in ask_mistral / ask_qwen instead).
hf_client = InferenceClient(api_key=HF_TOKEN)
# Local Embeddings & Whisper — both run in-process, not via the HF API.
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
whisper_model = whisper.load_model("base")
print("โœ… APIs and Local Models Loaded Successfully!")
# ==========================================
# 3. GLOBAL STATE & HELPERS
# ==========================================
# Retriever over the single target paper (Mode 1); None until a PDF is indexed.
main_paper_retriever = None
# Retriever over the combined reference papers (Modes 2-4); None until uploaded.
brainstorm_retriever = None
# Paths of images extracted from the main paper, fed to the vision model.
main_extracted_images = []
# On-disk JSON log of every chat exchange.
chat_history_file = "research_lab_history.json"
# Create an empty history file up front so later reads don't fail.
if not os.path.exists(chat_history_file):
    with open(chat_history_file, "w") as f: json.dump([], f)
def save_to_json(user_msg, combined_ans, mode, path=None):
    """Append one chat exchange to the JSON history log.

    Args:
        user_msg: The user's message text.
        combined_ans: The assistant's (mode-tagged) answer text.
        mode: The research-stage label the exchange ran under.
        path: Optional override of the history file path; defaults to the
            module-level ``chat_history_file``.
    """
    if path is None:
        path = chat_history_file
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    entry = {"timestamp": timestamp, "mode": mode, "user": user_msg, "assistant": combined_ans}
    try:
        with open(path, "r") as f:
            history = json.load(f)
    # Only recover from the two expected failures (file missing or corrupt);
    # the original bare `except:` also hid genuine bugs like NameError.
    except (FileNotFoundError, json.JSONDecodeError):
        history = []
    history.append(entry)
    with open(path, "w") as f:
        json.dump(history, f, indent=4)
def process_pdf_to_markdown(pdf_path, extract_images=True):
    """Convert a PDF to Markdown text with pymupdf4llm.

    When ``extract_images`` is True, page images are written to
    ``extracted_images/`` and their paths stored in the module-level
    ``main_extracted_images`` (sorted, for stable ordering).

    Returns the Markdown text, or "" if extraction fails (the error is
    logged instead of being silently swallowed).
    """
    global main_extracted_images
    output_image_dir = "extracted_images"
    if extract_images:
        # Start from a clean directory so images from a previously uploaded
        # paper are not mixed into the new one.
        if os.path.exists(output_image_dir):
            for name in os.listdir(output_image_dir):
                entry = os.path.join(output_image_dir, name)
                # Guard against unexpected subdirectories, which os.remove
                # would crash on.
                if os.path.isfile(entry):
                    os.remove(entry)
        else:
            os.makedirs(output_image_dir, exist_ok=True)
    try:
        if extract_images:
            md_text = pymupdf4llm.to_markdown(pdf_path, write_images=True, image_path=output_image_dir, image_format="png")
            # Case-insensitive match also catches .jpeg / uppercase extensions.
            main_extracted_images = sorted(
                os.path.join(output_image_dir, name)
                for name in os.listdir(output_image_dir)
                if name.lower().endswith((".png", ".jpg", ".jpeg"))
            )
        else:
            md_text = pymupdf4llm.to_markdown(pdf_path, write_images=False)
        return md_text
    except Exception as e:
        # Keep the best-effort contract (return "") but surface the cause
        # in the Space logs rather than hiding it.
        print(f"PDF-to-Markdown extraction failed for {pdf_path}: {e}")
        return ""
def process_main_paper(file_obj):
    """Index the uploaded target paper for Mode 1 (text chunks + images)."""
    global main_paper_retriever
    main_paper_retriever = None
    if file_obj is None:
        return "โš ๏ธ No file uploaded."
    try:
        collection = f"main_{uuid.uuid4().hex[:8]}"
        markdown = process_pdf_to_markdown(file_obj.name, extract_images=True)
        # Chunk the Markdown and build a fresh Chroma collection for retrieval.
        splitter = MarkdownTextSplitter(chunk_size=1000, chunk_overlap=200)
        documents = splitter.create_documents([markdown])
        store = Chroma.from_documents(documents=documents, embedding=embedding_model, collection_name=collection)
        main_paper_retriever = store.as_retriever(search_kwargs={"k": 3})
        return f"โœ… Main Paper Ready!\n๐Ÿ“˜ Text: Indexed\n๐Ÿ‘๏ธ Images: {len(main_extracted_images)} Extracted"
    except Exception as e:
        return f"โŒ Error: {str(e)}"
def process_brainstorm_papers(file_list):
    """Build the shared knowledge base (Modes 2-4) from up to three reference PDFs."""
    global brainstorm_retriever
    brainstorm_retriever = None
    if not file_list:
        return "โš ๏ธ No files uploaded."
    if len(file_list) > 3:
        return "โš ๏ธ Limit exceeded: Max 3 PDFs."
    try:
        names = []
        sections = []
        # Concatenate every paper, separated by a header so chunks keep
        # their source attribution.
        for file_obj in file_list:
            base_name = os.path.basename(file_obj.name)
            names.append(base_name)
            text = process_pdf_to_markdown(file_obj.name, extract_images=False)
            sections.append(f"\n\n--- PAPER: {base_name} ---\n{text}\n")
        combined_md = "".join(sections)
        collection = f"brainstorm_{uuid.uuid4().hex[:8]}"
        splitter = MarkdownTextSplitter(chunk_size=1500, chunk_overlap=300)
        documents = splitter.create_documents([combined_md])
        store = Chroma.from_documents(documents=documents, embedding=embedding_model, collection_name=collection)
        brainstorm_retriever = store.as_retriever(search_kwargs={"k": 5})
        return f"โœ… Knowledge Base Ready!\n๐Ÿ“š Papers: {', '.join(names)}"
    except Exception as e:
        return f"โŒ Error: {str(e)}"
def transcribe_audio(audio_path):
    """Return the Whisper transcription of a recorded clip ("" when no audio)."""
    if audio_path is None:
        return ""
    result = whisper_model.transcribe(audio_path)
    return result["text"].strip()
# ==========================================
# 4. INTELLIGENCE LAYERS (STRICT CHAT ROUTING)
# ==========================================
def ask_mistral(prompt):
    """Send a single-turn prompt to Mistral-7B and return the reply text.

    On any API failure the error is returned as a user-visible string
    rather than raised, so the chat UI never crashes.
    """
    try:
        # Using chat.completions explicitly keeps the request on the
        # /v1/chat/completions route.
        reply = hf_client.chat.completions.create(
            model="mistralai/Mistral-7B-Instruct-v0.2",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=1000,
            temperature=0.3,
        )
        return reply.choices[0].message.content
    except Exception as e:
        return f"โš ๏ธ API Error (Mistral): {str(e)}"
def ask_qwen(prompt, image_paths):
    """Ask the Qwen2-VL vision model about a set of local images.

    Args:
        prompt: The text question to pair with the images.
        image_paths: Paths of local image files to send (base64 data URLs).

    Returns the model's reply text, or an error string on API failure.
    """
    try:
        content = []
        for img_path in image_paths:
            with open(img_path, "rb") as image_file:
                b64_img = base64.b64encode(image_file.read()).decode('utf-8')
            # Fix: declare the correct MIME type — the original hard-coded
            # image/png even for .jpg files picked up by the extractor.
            mime = "image/jpeg" if img_path.lower().endswith((".jpg", ".jpeg")) else "image/png"
            content.append({
                "type": "image_url",
                "image_url": {"url": f"data:{mime};base64,{b64_img}"}
            })
        content.append({"type": "text", "text": prompt})
        # Enforce chat completions route for Vision model too
        response = hf_client.chat.completions.create(
            model="Qwen/Qwen2-VL-7B-Instruct",
            messages=[{"role": "user", "content": content}],
            max_tokens=150
        )
        return response.choices[0].message.content
    except Exception as e:
        return f"โš ๏ธ API Error (Qwen - Server busy): {str(e)}"
# MODE 1: CHAT WITH MAIN PAPER
def get_main_paper_response(question):
    """Answer a question about the indexed main paper (text + visual context).

    Retrieves the top text chunks, optionally summarizes up to three
    extracted images via the vision model, and asks Mistral to integrate both.
    """
    global main_paper_retriever, main_extracted_images
    # Fix: guard first (consistent with the other modes) — the original
    # spent a vision API call on the images even when no paper was indexed,
    # then discarded the result and returned this same warning.
    if not main_paper_retriever:
        return "โš ๏ธ Please upload Main Paper."
    vision_context = ""
    if main_extracted_images:
        # Cap at 3 images to keep the vision request small.
        images_to_process = main_extracted_images[:3]
        vision_context = ask_qwen(f"Relate these images to: {question}", images_to_process)
    docs = main_paper_retriever.invoke(question)
    text_context = "\n\n".join(d.page_content for d in docs)
    prompt = f"""[INST] Use the context to answer. Integrate visual insights if available.
Markdown Context: {text_context}
Visual Insight: {vision_context}
Question: {question} [/INST]"""
    return ask_mistral(prompt)
# MODE 2: BRAINSTORM NOVELTY
def get_novelty_response(question):
    """Mine the reference-paper knowledge base for gaps and a novel direction."""
    global brainstorm_retriever
    if not brainstorm_retriever:
        return "โš ๏ธ Upload Reference Papers."
    retrieved = brainstorm_retriever.invoke(question)
    context = "\n\n".join(doc.page_content for doc in retrieved)
    prompt = f"""[INST] You are a Senior Research Scientist.
Analyze these papers to find gaps and novelty.
Context: {context}
Task: Identify limitations in these methodologies and suggest a NOVEL approach or gap.
Query: {question} [/INST]"""
    return ask_mistral(prompt)
# MODE 3: BRAINSTORM SETUP
def get_setup_response(question):
    """Design an experimental setup grounded in the reference papers."""
    global brainstorm_retriever
    if not brainstorm_retriever:
        return "โš ๏ธ Upload Reference Papers."
    retrieved = brainstorm_retriever.invoke(question)
    context = "\n\n".join(doc.page_content for doc in retrieved)
    prompt = f"""[INST] You are a Research Architect.
Based on the methodologies in the context, design a robust EXPERIMENTAL SETUP.
Context: {context}
Task: Propose Datasets, Evaluation Metrics, Baselines, and Hardware requirements to validate the proposed novelty.
Query: {question} [/INST]"""
    return ask_mistral(prompt)
# MODE 4: GENERATE PAPER DRAFT
def get_draft_response(question):
    """Generate a structured paper draft using the reference papers as related work."""
    global brainstorm_retriever
    if not brainstorm_retriever:
        return "โš ๏ธ Upload Reference Papers."
    retrieved = brainstorm_retriever.invoke(question)
    context = "\n\n".join(doc.page_content for doc in retrieved)
    prompt = f"""[INST] You are an Academic Writer.
Write a structured Research Paper Draft (Abstract, Introduction, Methodology, Experiments).
Use the context from reference papers to write the 'Related Work' section effectively.
Context: {context}
Task: Generate a draft for a paper about: {question} [/INST]"""
    return ask_mistral(prompt)
# ==========================================
# 5. GRADIO UI
# ==========================================
def reset_chat():
    """Return an empty transcript (wired to the Clear Workspace button)."""
    return []
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ๐Ÿ”ฌ AI Research Scientist Lab (Production Version)")
    gr.Markdown("Pipeline: Analyze -> Find Novelty -> Design Setup -> Write Draft")
    with gr.Row():
        # Left column: stage selector and file uploads.
        with gr.Column(scale=1):
            mode_radio = gr.Radio(
                choices=[
                    "1. Chat with Paper",
                    "2. Brainstorm Novelty",
                    "3. Brainstorm Setup",
                    "4. Generate Paper Draft"
                ],
                value="1. Chat with Paper",
                label="๐Ÿงช Research Stage"
            )
            gr.Markdown("---")
            gr.Markdown("### ๐Ÿ“‚ Input Data")
            file_main = gr.File(label="Target Paper (Stage 1)", file_types=[".pdf"])
            status_main = gr.Textbox(label="Status", value="Waiting...", interactive=False)
            file_refs = gr.File(label="Reference Papers (Stages 2-4)", file_types=[".pdf"], file_count="multiple")
            status_refs = gr.Textbox(label="Status", value="Waiting...", interactive=False)
            clear_btn = gr.Button("๐Ÿ—‘๏ธ Clear Workspace")
        # Right column: chat transcript, dictation, and message entry.
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Lab Assistant", height=700)
            audio_input = gr.Audio(sources=["microphone"], type="filepath", label="๐ŸŽค Dictate Idea")
            with gr.Row():
                msg_input = gr.Textbox(placeholder="Enter your query or research topic...", scale=4)
                send_btn = gr.Button("๐Ÿš€ Execute", variant="primary", scale=1)
    # Uploading a file immediately indexes it and reports status in the UI.
    file_main.change(fn=process_main_paper, inputs=file_main, outputs=status_main)
    file_refs.change(fn=process_brainstorm_papers, inputs=file_refs, outputs=status_refs)
    # Dictation lands in the text box so the user can review before sending.
    audio_input.stop_recording(fn=transcribe_audio, inputs=audio_input, outputs=msg_input)
    clear_btn.click(fn=reset_chat, outputs=chatbot)

    def respond(message, history, mode):
        """Dispatch a chat message to the handler for the selected stage.

        Returns ("", updated_history): the empty string clears the input box.
        """
        if not message.strip(): return "", history
        if history is None: history = []
        if mode == "1. Chat with Paper":
            response = get_main_paper_response(message)
        elif mode == "2. Brainstorm Novelty":
            response = get_novelty_response(message)
        elif mode == "3. Brainstorm Setup":
            response = get_setup_response(message)
        elif mode == "4. Generate Paper Draft":
            response = get_draft_response(message)
        else:
            response = "Error: Unknown Mode"
        # Tag the answer with its mode, then persist the exchange to disk.
        final_ans = f"**[{mode}]**\n{response}"
        save_to_json(message, final_ans, mode)
        # NOTE(review): history entries use the role/content "messages" dict
        # format, but gr.Chatbot above is created without type="messages" —
        # confirm the installed Gradio version accepts this shape.
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": final_ans})
        return "", history

    # Both pressing Enter and clicking Execute submit the message.
    msg_input.submit(respond, [msg_input, chatbot, mode_radio], [msg_input, chatbot])
    send_btn.click(respond, [msg_input, chatbot, mode_radio], [msg_input, chatbot])

print("๐Ÿš€ Launching Production Research Scientist Lab...")
demo.launch()