# Hugging Face Space metadata (kept as comments so the file stays valid Python):
# Spaces: Sleeping
# File size: 12,033 Bytes
# ==========================================
# 1. INITIAL SETUP & LIBRARIES
# ==========================================
import os
import json
import uuid
import base64
import whisper
import pymupdf4llm
import gradio as gr
from datetime import datetime
from huggingface_hub import InferenceClient
from langchain_text_splitters import MarkdownTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
# ==========================================
# 2. CONNECT TO AI APIS (OpenAI-Compatible)
# ==========================================
print("β³ Connecting to Hugging Face APIs...")

# The token must be configured as a (secret) environment variable in the Space.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    print("β οΈ WARNING: HF_TOKEN not found! The AI will not be able to generate responses.")

# Initialize a single, generic client.
# We do NOT bind the model name here to prevent the "text-generation" tagging error.
hf_client = InferenceClient(api_key=HF_TOKEN)

# Local models: sentence embeddings (feed Chroma) and Whisper "base" (dictation).
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
whisper_model = whisper.load_model("base")
# NOTE: the original string literal was split across two lines (syntax error); rejoined here.
print("β APIs and Local Models Loaded Successfully!")
# ==========================================
# 3. GLOBAL STATE & HELPERS
# ==========================================
# Retrievers are rebuilt from scratch on every new upload.
main_paper_retriever = None
brainstorm_retriever = None
# Figure paths extracted from the currently loaded main paper.
main_extracted_images = []

# Persistent chat log; seed it with an empty JSON list on first run.
chat_history_file = "research_lab_history.json"
if not os.path.exists(chat_history_file):
    with open(chat_history_file, "w") as f:
        json.dump([], f)
def save_to_json(user_msg, combined_ans, mode, path="research_lab_history.json"):
    """Append one chat exchange to the JSON history log.

    Args:
        user_msg: The user's query text.
        combined_ans: The assistant's full (mode-prefixed) answer.
        mode: Which research stage produced the answer.
        path: History file to append to (defaults to the module-wide log file).
    """
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    entry = {"timestamp": timestamp, "mode": mode, "user": user_msg, "assistant": combined_ans}
    try:
        with open(path, "r") as f:
            history = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        # Start a fresh log if the file is missing or corrupted.
        # (Was a bare `except:` that silently swallowed every error.)
        history = []
    history.append(entry)
    with open(path, "w") as f:
        json.dump(history, f, indent=4)
def process_pdf_to_markdown(pdf_path, extract_images=True):
    """Convert a PDF to Markdown text via pymupdf4llm.

    When extract_images is True, page images are written to 'extracted_images/'
    (cleared first so stale figures from a previous paper never leak through)
    and their paths are cached, sorted, in the global main_extracted_images.

    Returns:
        The Markdown string, or "" if conversion fails.
    """
    global main_extracted_images
    output_image_dir = "extracted_images"
    if extract_images:
        if os.path.exists(output_image_dir):
            # Clean out leftovers from the previously processed paper.
            for f in os.listdir(output_image_dir):
                os.remove(os.path.join(output_image_dir, f))
        else:
            os.makedirs(output_image_dir, exist_ok=True)
    try:
        if extract_images:
            md_text = pymupdf4llm.to_markdown(pdf_path, write_images=True, image_path=output_image_dir, image_format="png")
            main_extracted_images = [os.path.join(output_image_dir, f) for f in os.listdir(output_image_dir) if f.endswith(('.png', '.jpg'))]
            main_extracted_images.sort()
        else:
            md_text = pymupdf4llm.to_markdown(pdf_path, write_images=False)
        return md_text
    except Exception as e:
        # Surface the failure in the Space logs instead of dying silently;
        # callers treat "" as "no text extracted".
        print(f"PDF conversion failed for {pdf_path}: {e}")
        return ""
def process_main_paper(file_obj):
    """Index the uploaded target paper (Stage 1) into a fresh Chroma collection.

    Extracts text + figures, chunks the Markdown, embeds it, and stores a
    retriever in the global main_paper_retriever.

    Returns a human-readable status string for the UI.
    """
    global main_paper_retriever
    main_paper_retriever = None
    if file_obj is None:
        return "β οΈ No file uploaded."
    try:
        # Unique collection name so re-uploads never collide with old indexes.
        unique_id = f"main_{uuid.uuid4().hex[:8]}"
        md_content = process_pdf_to_markdown(file_obj.name, extract_images=True)
        if not md_content.strip():
            # Chroma cannot index an empty corpus; fail with a clear message
            # instead of an opaque exception from from_documents.
            return "β Error: no text could be extracted from the PDF."
        splitter = MarkdownTextSplitter(chunk_size=1000, chunk_overlap=200)
        chunks = splitter.create_documents([md_content])
        vectordb = Chroma.from_documents(documents=chunks, embedding=embedding_model, collection_name=unique_id)
        main_paper_retriever = vectordb.as_retriever(search_kwargs={"k": 3})
        # NOTE: original status string was split mid-literal across two lines; rejoined.
        return f"β Main Paper Ready!\nπ Text: Indexed\nποΈ Images: {len(main_extracted_images)} Extracted"
    except Exception as e:
        return f"β Error: {str(e)}"
def process_brainstorm_papers(file_list):
    """Index up to three reference PDFs into one shared knowledge base
    used by Stages 2-4 (novelty / setup / draft).

    Returns a human-readable status string for the UI.
    """
    global brainstorm_retriever
    brainstorm_retriever = None
    if not file_list:
        return "β οΈ No files uploaded."
    if len(file_list) > 3:
        return "β οΈ Limit exceeded: Max 3 PDFs."
    try:
        combined_md = ""
        names = []
        got_text = False  # Track whether ANY paper yielded real text.
        for file_obj in file_list:
            base = os.path.basename(file_obj.name)
            names.append(base)
            text = process_pdf_to_markdown(file_obj.name, extract_images=False)
            got_text = got_text or bool(text.strip())
            # Per-paper header so retrieved chunks stay attributable.
            combined_md += f"\n\n--- PAPER: {base} ---\n{text}\n"
        if not got_text:
            # Only headers, no content: Chroma would fail on an empty corpus.
            return "β Error: no text could be extracted from the PDFs."
        unique_id = f"brainstorm_{uuid.uuid4().hex[:8]}"
        splitter = MarkdownTextSplitter(chunk_size=1500, chunk_overlap=300)
        chunks = splitter.create_documents([combined_md])
        vectordb = Chroma.from_documents(documents=chunks, embedding=embedding_model, collection_name=unique_id)
        brainstorm_retriever = vectordb.as_retriever(search_kwargs={"k": 5})
        # NOTE: original status string was split mid-literal across two lines; rejoined.
        return f"β Knowledge Base Ready!\nπ Papers: {', '.join(names)}"
    except Exception as e:
        return f"β Error: {str(e)}"
def transcribe_audio(audio_path):
    """Transcribe a recorded clip to text with the local Whisper model.

    Returns "" when no audio file was provided.
    """
    if audio_path is None:
        return ""
    result = whisper_model.transcribe(audio_path)
    return result["text"].strip()
# ==========================================
# 4. INTELLIGENCE LAYERS (STRICT CHAT ROUTING)
# ==========================================
def ask_mistral(prompt):
    """Send a single-turn prompt to Mistral-7B-Instruct and return the reply.

    Uses the OpenAI-compatible /v1/chat/completions route explicitly (the raw
    text-generation route trips over this model's task tagging). On failure
    the error is returned as a user-visible string rather than raised.
    """
    try:
        reply = hf_client.chat.completions.create(
            model="mistralai/Mistral-7B-Instruct-v0.2",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=1000,
            temperature=0.3,
        )
        return reply.choices[0].message.content
    except Exception as e:
        return f"β οΈ API Error (Mistral): {str(e)}"
def ask_qwen(prompt, image_paths):
    """Ask the Qwen2-VL vision model about a set of local images.

    Each image is inlined as a base64 data URI; the text prompt is appended
    last. Returns the model's reply, or an error string on failure.

    Args:
        prompt: The question to ask about the images.
        image_paths: Paths of local image files (.png or .jpg).
    """
    try:
        content = []
        for img_path in image_paths:
            with open(img_path, "rb") as image_file:
                b64_img = base64.b64encode(image_file.read()).decode('utf-8')
            # Bug fix: the MIME type was hardcoded to image/png even though
            # extracted images may also be .jpg — derive it from the extension.
            mime = "image/jpeg" if img_path.lower().endswith((".jpg", ".jpeg")) else "image/png"
            content.append({
                "type": "image_url",
                "image_url": {"url": f"data:{mime};base64,{b64_img}"}
            })
        content.append({"type": "text", "text": prompt})
        messages = [{"role": "user", "content": content}]
        # Enforce the chat-completions route for the vision model too.
        response = hf_client.chat.completions.create(
            model="Qwen/Qwen2-VL-7B-Instruct",
            messages=messages,
            max_tokens=150
        )
        return response.choices[0].message.content
    except Exception as e:
        return f"β οΈ API Error (Qwen - Server busy): {str(e)}"
# MODE 1: CHAT WITH MAIN PAPER
def get_main_paper_response(question):
    """Stage 1: answer a question about the uploaded main paper, fusing
    RAG text context with Qwen's reading of the extracted figures."""
    global main_paper_retriever, main_extracted_images
    vision_context = ""
    if main_extracted_images:
        # Cap at three figures to keep the vision request small.
        selected = main_extracted_images[:3]
        vision_context = ask_qwen(f"Relate these images to: {question}", selected)
    if main_paper_retriever is None:
        return "β οΈ Please upload Main Paper."
    retrieved = main_paper_retriever.invoke(question)
    text_context = "\n\n".join(doc.page_content for doc in retrieved)
    prompt = f"""[INST] Use the context to answer. Integrate visual insights if available.
Markdown Context: {text_context}
Visual Insight: {vision_context}
Question: {question} [/INST]"""
    return ask_mistral(prompt)
# MODE 2: BRAINSTORM NOVELTY
def get_novelty_response(question):
    """Stage 2: mine the reference-paper knowledge base for gaps/novelty."""
    global brainstorm_retriever
    if brainstorm_retriever is None:
        return "β οΈ Upload Reference Papers."
    matches = brainstorm_retriever.invoke(question)
    context = "\n\n".join(m.page_content for m in matches)
    prompt = f"""[INST] You are a Senior Research Scientist.
Analyze these papers to find gaps and novelty.
Context: {context}
Task: Identify limitations in these methodologies and suggest a NOVEL approach or gap.
Query: {question} [/INST]"""
    return ask_mistral(prompt)
# MODE 3: BRAINSTORM SETUP
def get_setup_response(question):
    """Stage 3: propose an experimental setup grounded in the reference papers."""
    global brainstorm_retriever
    if brainstorm_retriever is None:
        return "β οΈ Upload Reference Papers."
    hits = brainstorm_retriever.invoke(question)
    context = "\n\n".join(h.page_content for h in hits)
    prompt = f"""[INST] You are a Research Architect.
Based on the methodologies in the context, design a robust EXPERIMENTAL SETUP.
Context: {context}
Task: Propose Datasets, Evaluation Metrics, Baselines, and Hardware requirements to validate the proposed novelty.
Query: {question} [/INST]"""
    return ask_mistral(prompt)
# MODE 4: GENERATE PAPER DRAFT
def get_draft_response(question):
    """Stage 4: generate a structured paper draft using the reference corpus."""
    global brainstorm_retriever
    if brainstorm_retriever is None:
        return "β οΈ Upload Reference Papers."
    passages = brainstorm_retriever.invoke(question)
    context = "\n\n".join(p.page_content for p in passages)
    prompt = f"""[INST] You are an Academic Writer.
Write a structured Research Paper Draft (Abstract, Introduction, Methodology, Experiments).
Use the context from reference papers to write the 'Related Work' section effectively.
Context: {context}
Task: Generate a draft for a paper about: {question} [/INST]"""
    return ask_mistral(prompt)
# ==========================================
# 5. GRADIO UI
# ==========================================
def reset_chat():
    """Return an empty history to wipe the chatbot display."""
    return []

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# π¬ AI Research Scientist Lab (Production Version)")
    gr.Markdown("Pipeline: Analyze -> Find Novelty -> Design Setup -> Write Draft")
    with gr.Row():
        with gr.Column(scale=1):
            mode_radio = gr.Radio(
                choices=[
                    "1. Chat with Paper",
                    "2. Brainstorm Novelty",
                    "3. Brainstorm Setup",
                    "4. Generate Paper Draft"
                ],
                value="1. Chat with Paper",
                label="π§ͺ Research Stage"
            )
            gr.Markdown("---")
            gr.Markdown("### π Input Data")
            file_main = gr.File(label="Target Paper (Stage 1)", file_types=[".pdf"])
            status_main = gr.Textbox(label="Status", value="Waiting...", interactive=False)
            file_refs = gr.File(label="Reference Papers (Stages 2-4)", file_types=[".pdf"], file_count="multiple")
            status_refs = gr.Textbox(label="Status", value="Waiting...", interactive=False)
            clear_btn = gr.Button("ποΈ Clear Workspace")
        with gr.Column(scale=2):
            # Bug fix: respond() appends {"role": ..., "content": ...} dicts,
            # so the chatbot must use the "messages" format — the default
            # tuple format does not accept that history shape.
            chatbot = gr.Chatbot(label="Lab Assistant", height=700, type="messages")
            audio_input = gr.Audio(sources=["microphone"], type="filepath", label="π€ Dictate Idea")
            with gr.Row():
                msg_input = gr.Textbox(placeholder="Enter your query or research topic...", scale=4)
                send_btn = gr.Button("π Execute", variant="primary", scale=1)

    # Wire uploads, dictation, and workspace reset.
    file_main.change(fn=process_main_paper, inputs=file_main, outputs=status_main)
    file_refs.change(fn=process_brainstorm_papers, inputs=file_refs, outputs=status_refs)
    audio_input.stop_recording(fn=transcribe_audio, inputs=audio_input, outputs=msg_input)
    clear_btn.click(fn=reset_chat, outputs=chatbot)

    def respond(message, history, mode):
        """Route the query to the selected stage's handler, log the exchange,
        and append the user/assistant turns to the chat history."""
        if not message.strip():
            return "", history
        if history is None:
            history = []
        if mode == "1. Chat with Paper":
            response = get_main_paper_response(message)
        elif mode == "2. Brainstorm Novelty":
            response = get_novelty_response(message)
        elif mode == "3. Brainstorm Setup":
            response = get_setup_response(message)
        elif mode == "4. Generate Paper Draft":
            response = get_draft_response(message)
        else:
            response = "Error: Unknown Mode"
        final_ans = f"**[{mode}]**\n{response}"
        save_to_json(message, final_ans, mode)
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": final_ans})
        return "", history

    msg_input.submit(respond, [msg_input, chatbot, mode_radio], [msg_input, chatbot])
    send_btn.click(respond, [msg_input, chatbot, mode_radio], [msg_input, chatbot])

print("π Launching Production Research Scientist Lab...")
demo.launch()