# NOTE(review): the original capture began with Hugging Face Spaces page
# chrome ("Spaces: Sleeping") that is not part of the program; kept here
# as a comment so the file parses as Python.
| # ========================================== | |
| # 1. INITIAL SETUP & LIBRARIES | |
| # ========================================== | |
| import os | |
| import json | |
| import uuid | |
| import base64 | |
| import whisper | |
| import pymupdf4llm | |
| import gradio as gr | |
| from datetime import datetime | |
| from huggingface_hub import InferenceClient | |
| from langchain_text_splitters import MarkdownTextSplitter | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_community.vectorstores import Chroma | |
# ==========================================
# 2. CONNECT TO AI APIS (OpenAI-Compatible)
# ==========================================
print("โณ Connecting to Hugging Face APIs...")
# Read the API token from the environment; without it the InferenceClient is
# still constructed, but generation calls will fail at request time.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    print("โ ๏ธ WARNING: HF_TOKEN not found! The AI will not be able to generate responses.")
# Initialize a single, generic client
# We do NOT bind the model name here to prevent the "text-generation" tagging error
hf_client = InferenceClient(api_key=HF_TOKEN)
# Local Embeddings & Whisper — both run on this machine, not via the API.
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
whisper_model = whisper.load_model("base")
print("โ APIs and Local Models Loaded Successfully!")
# ==========================================
# 3. GLOBAL STATE & HELPERS
# ==========================================
# Retriever over the single "main" paper (Stage 1); set by process_main_paper.
main_paper_retriever = None
# Retriever over the reference-paper knowledge base (Stages 2-4).
brainstorm_retriever = None
# Paths of images extracted from the main paper's PDF, in sorted order.
main_extracted_images = []
# JSON file that accumulates every Q/A exchange across sessions.
chat_history_file = "research_lab_history.json"
# Seed the history file with an empty list so later json.load calls succeed.
if not os.path.exists(chat_history_file):
    with open(chat_history_file, "w") as f: json.dump([], f)
def save_to_json(user_msg, combined_ans, mode):
    """Append one Q/A exchange to the persistent JSON history file.

    Args:
        user_msg: The user's raw query text.
        combined_ans: The assistant's full (mode-prefixed) answer.
        mode: The research-stage label the answer was generated under.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    entry = {"timestamp": timestamp, "mode": mode, "user": user_msg, "assistant": combined_ans}
    try:
        with open(chat_history_file, "r", encoding="utf-8") as f:
            history = json.load(f)
    # Narrowed from a bare `except:` — only recover from a missing or
    # corrupt history file; any other error should surface, not be hidden.
    except (FileNotFoundError, json.JSONDecodeError):
        history = []
    history.append(entry)
    # encoding="utf-8" keeps the emoji-laden answers portable across platforms.
    with open(chat_history_file, "w", encoding="utf-8") as f:
        json.dump(history, f, indent=4)
def process_pdf_to_markdown(pdf_path, extract_images=True):
    """Convert a PDF to Markdown text, optionally extracting its images.

    Side effect: when extract_images is True, repopulates the module-level
    main_extracted_images list with sorted paths of the extracted PNG/JPGs.

    Returns the Markdown text, or "" if conversion fails (best-effort).
    """
    global main_extracted_images
    output_image_dir = "extracted_images"
    if extract_images:
        # Start from a clean image directory so stale files from a previous
        # paper never leak into this paper's gallery.
        if os.path.exists(output_image_dir):
            for f in os.listdir(output_image_dir):
                os.remove(os.path.join(output_image_dir, f))
        else:
            os.makedirs(output_image_dir, exist_ok=True)
    try:
        if extract_images:
            md_text = pymupdf4llm.to_markdown(pdf_path, write_images=True, image_path=output_image_dir, image_format="png")
            main_extracted_images = sorted(
                os.path.join(output_image_dir, f)
                for f in os.listdir(output_image_dir)
                if f.endswith((".png", ".jpg"))
            )
        else:
            md_text = pymupdf4llm.to_markdown(pdf_path, write_images=False)
        return md_text
    except Exception as e:
        # Callers treat "" as "no text", but log the cause instead of
        # swallowing it silently (the original discarded `e` unused).
        print(f"WARNING: PDF to Markdown conversion failed for {pdf_path}: {e}")
        return ""
def process_main_paper(file_obj):
    """Index the uploaded 'main' paper for Stage-1 chat.

    Builds a Chroma vector store over Markdown chunks of the PDF and stores
    a top-3 retriever in the module-level main_paper_retriever.

    Returns a human-readable status string for the UI.
    """
    global main_paper_retriever
    main_paper_retriever = None
    if file_obj is None: return "โ ๏ธ No file uploaded."
    try:
        # Unique collection name prevents clashes with earlier uploads.
        unique_id = f"main_{uuid.uuid4().hex[:8]}"
        md_content = process_pdf_to_markdown(file_obj.name, extract_images=True)
        # Conversion failures return "" — don't build an empty, useless index.
        if not md_content.strip():
            return "โ Error: No text could be extracted from this PDF."
        splitter = MarkdownTextSplitter(chunk_size=1000, chunk_overlap=200)
        chunks = splitter.create_documents([md_content])
        vectordb = Chroma.from_documents(documents=chunks, embedding=embedding_model, collection_name=unique_id)
        main_paper_retriever = vectordb.as_retriever(search_kwargs={"k": 3})
        return f"โ Main Paper Ready!\n๐ Text: Indexed\n๐๏ธ Images: {len(main_extracted_images)} Extracted"
    except Exception as e:
        return f"โ Error: {str(e)}"
def process_brainstorm_papers(file_list):
    """Build the Stages 2-4 knowledge base from up to three reference PDFs.

    Concatenates the Markdown of every uploaded paper (text only, no image
    extraction), indexes it in a fresh Chroma collection, and stores a top-5
    retriever in the module-level brainstorm_retriever.

    Returns a human-readable status string for the UI.
    """
    global brainstorm_retriever
    brainstorm_retriever = None
    if not file_list: return "โ ๏ธ No files uploaded."
    if len(file_list) > 3: return "โ ๏ธ Limit exceeded: Max 3 PDFs."
    try:
        names = [os.path.basename(f.name) for f in file_list]
        # Tag each paper's Markdown with a header so retrieval hits stay traceable.
        sections = []
        for base_name, file_obj in zip(names, file_list):
            md = process_pdf_to_markdown(file_obj.name, extract_images=False)
            sections.append(f"\n\n--- PAPER: {base_name} ---\n{md}\n")
        combined_md = "".join(sections)
        collection = f"brainstorm_{uuid.uuid4().hex[:8]}"
        chunks = MarkdownTextSplitter(chunk_size=1500, chunk_overlap=300).create_documents([combined_md])
        store = Chroma.from_documents(documents=chunks, embedding=embedding_model, collection_name=collection)
        brainstorm_retriever = store.as_retriever(search_kwargs={"k": 5})
        return f"โ Knowledge Base Ready!\n๐ Papers: {', '.join(names)}"
    except Exception as e:
        return f"โ Error: {str(e)}"
def transcribe_audio(audio_path):
    """Run local Whisper on a recorded clip; return "" when nothing was recorded."""
    if audio_path is None:
        return ""
    result = whisper_model.transcribe(audio_path)
    return result["text"].strip()
# ==========================================
# 4. INTELLIGENCE LAYERS (STRICT CHAT ROUTING)
# ==========================================
def ask_mistral(prompt):
    """Send a single-turn prompt to Mistral-7B-Instruct.

    Goes through chat.completions so the request explicitly hits the
    /v1/chat/completions route (avoids the "text-generation" tagging error).

    Returns the reply text, or a human-readable error string on any failure.
    """
    try:
        completion = hf_client.chat.completions.create(
            model="mistralai/Mistral-7B-Instruct-v0.2",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=1000,
            temperature=0.3,
        )
        return completion.choices[0].message.content
    except Exception as e:
        return f"โ ๏ธ API Error (Mistral): {str(e)}"
def ask_qwen(prompt, image_paths):
    """Ask the Qwen2-VL vision model about a set of image files.

    Each image is inlined as a base64 data URL, followed by the text prompt,
    in a single user message.

    Returns the reply text, or a human-readable error string on any failure.
    """
    try:
        content = []
        for path in image_paths:
            with open(path, "rb") as img:
                encoded = base64.b64encode(img.read()).decode('utf-8')
                content.append({
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{encoded}"}
                })
        content.append({"type": "text", "text": prompt})
        # Vision calls are also forced through the chat-completions route.
        completion = hf_client.chat.completions.create(
            model="Qwen/Qwen2-VL-7B-Instruct",
            messages=[{"role": "user", "content": content}],
            max_tokens=150
        )
        return completion.choices[0].message.content
    except Exception as e:
        return f"โ ๏ธ API Error (Qwen - Server busy): {str(e)}"
# MODE 1: CHAT WITH MAIN PAPER
def get_main_paper_response(question):
    """Answer a question about the main paper with text RAG plus vision.

    Retrieves the top text chunks, asks Qwen about the first few extracted
    figures (if any), and lets Mistral fuse both into one answer.
    """
    global main_paper_retriever, main_extracted_images
    # Bail out BEFORE spending a vision API call when no paper is indexed —
    # the original queried Qwen first and then discarded the result.
    if not main_paper_retriever:
        return "โ ๏ธ Please upload Main Paper."
    vision_context = ""
    if main_extracted_images:
        # Cap at three figures to keep the vision request small and fast.
        images_to_process = main_extracted_images[:3]
        vision_prompt = f"Relate these images to: {question}"
        vision_context = ask_qwen(vision_prompt, images_to_process)
    docs = main_paper_retriever.invoke(question)
    text_context = "\n\n".join(d.page_content for d in docs)
    prompt = f"""[INST] Use the context to answer. Integrate visual insights if available.
Markdown Context: {text_context}
Visual Insight: {vision_context}
Question: {question} [/INST]"""
    return ask_mistral(prompt)
# MODE 2: BRAINSTORM NOVELTY
def get_novelty_response(question):
    """Mine the reference-paper knowledge base for gaps and novel directions."""
    global brainstorm_retriever
    if not brainstorm_retriever:
        return "โ ๏ธ Upload Reference Papers."
    retrieved = brainstorm_retriever.invoke(question)
    context = "\n\n".join(doc.page_content for doc in retrieved)
    # Concatenation below yields the exact same prompt bytes as the original
    # triple-quoted literal.
    prompt = (
        "[INST] You are a Senior Research Scientist.\n"
        "Analyze these papers to find gaps and novelty.\n"
        f"Context: {context}\n"
        "Task: Identify limitations in these methodologies and suggest a NOVEL approach or gap.\n"
        f"Query: {question} [/INST]"
    )
    return ask_mistral(prompt)
# MODE 3: BRAINSTORM SETUP
def get_setup_response(question):
    """Design an experimental setup grounded in the reference papers."""
    global brainstorm_retriever
    if not brainstorm_retriever:
        return "โ ๏ธ Upload Reference Papers."
    snippets = [doc.page_content for doc in brainstorm_retriever.invoke(question)]
    context = "\n\n".join(snippets)
    prompt = f"""[INST] You are a Research Architect.
Based on the methodologies in the context, design a robust EXPERIMENTAL SETUP.
Context: {context}
Task: Propose Datasets, Evaluation Metrics, Baselines, and Hardware requirements to validate the proposed novelty.
Query: {question} [/INST]"""
    return ask_mistral(prompt)
# MODE 4: GENERATE PAPER DRAFT
def get_draft_response(question):
    """Draft a structured research paper using the reference knowledge base."""
    global brainstorm_retriever
    if not brainstorm_retriever:
        return "โ ๏ธ Upload Reference Papers."
    matches = brainstorm_retriever.invoke(question)
    context = "\n\n".join(match.page_content for match in matches)
    prompt = f"""[INST] You are an Academic Writer.
Write a structured Research Paper Draft (Abstract, Introduction, Methodology, Experiments).
Use the context from reference papers to write the 'Related Work' section effectively.
Context: {context}
Task: Generate a draft for a paper about: {question} [/INST]"""
    return ask_mistral(prompt)
# ==========================================
# 5. GRADIO UI
# ==========================================
def reset_chat():
    """Return a fresh empty history, wiping the chatbot display."""
    return []
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ๐ฌ AI Research Scientist Lab (Production Version)")
    gr.Markdown("Pipeline: Analyze -> Find Novelty -> Design Setup -> Write Draft")
    with gr.Row():
        with gr.Column(scale=1):
            # Left column: research-stage selector and file inputs.
            mode_radio = gr.Radio(
                choices=[
                    "1. Chat with Paper",
                    "2. Brainstorm Novelty",
                    "3. Brainstorm Setup",
                    "4. Generate Paper Draft"
                ],
                value="1. Chat with Paper",
                label="๐งช Research Stage"
            )
            gr.Markdown("---")
            gr.Markdown("### ๐ Input Data")
            file_main = gr.File(label="Target Paper (Stage 1)", file_types=[".pdf"])
            status_main = gr.Textbox(label="Status", value="Waiting...", interactive=False)
            file_refs = gr.File(label="Reference Papers (Stages 2-4)", file_types=[".pdf"], file_count="multiple")
            status_refs = gr.Textbox(label="Status", value="Waiting...", interactive=False)
            clear_btn = gr.Button("๐๏ธ Clear Workspace")
        with gr.Column(scale=2):
            # Right column: chat transcript, voice dictation, and query box.
            # NOTE(review): respond() appends role/content dicts — this assumes
            # the installed Gradio's Chatbot accepts the "messages" format;
            # confirm (newer versions need type="messages").
            chatbot = gr.Chatbot(label="Lab Assistant", height=700)
            audio_input = gr.Audio(sources=["microphone"], type="filepath", label="๐ค Dictate Idea")
            with gr.Row():
                msg_input = gr.Textbox(placeholder="Enter your query or research topic...", scale=4)
                send_btn = gr.Button("๐ Execute", variant="primary", scale=1)
    # Uploading a file immediately triggers indexing and updates its status box.
    file_main.change(fn=process_main_paper, inputs=file_main, outputs=status_main)
    file_refs.change(fn=process_brainstorm_papers, inputs=file_refs, outputs=status_refs)
    # Whisper transcription drops the dictated text straight into the query box.
    audio_input.stop_recording(fn=transcribe_audio, inputs=audio_input, outputs=msg_input)
    clear_btn.click(fn=reset_chat, outputs=chatbot)
    def respond(message, history, mode):
        # Dispatch the query to the handler for the selected research stage,
        # persist the exchange to disk, and append it to the visible history.
        if not message.strip(): return "", history
        if history is None: history = []
        if mode == "1. Chat with Paper":
            response = get_main_paper_response(message)
        elif mode == "2. Brainstorm Novelty":
            response = get_novelty_response(message)
        elif mode == "3. Brainstorm Setup":
            response = get_setup_response(message)
        elif mode == "4. Generate Paper Draft":
            response = get_draft_response(message)
        else:
            response = "Error: Unknown Mode"
        # Prefix the answer with the stage it was produced under.
        final_ans = f"**[{mode}]**\n{response}"
        save_to_json(message, final_ans, mode)
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": final_ans})
        return "", history
    # Both the Enter key and the Execute button submit the query.
    msg_input.submit(respond, [msg_input, chatbot, mode_radio], [msg_input, chatbot])
    send_btn.click(respond, [msg_input, chatbot, mode_radio], [msg_input, chatbot])
print("๐ Launching Production Research Scientist Lab...")
# Start the Gradio server; blocks until the app is shut down.
demo.launch()