|
|
from smolagents import CodeAgent, InferenceClientModel, DuckDuckGoSearchTool |
|
|
from tools import DocumentRetrievalTool, DocumentSummarizationTool, CodeExecutionTool |
|
|
from retriever import load_document, chunk_text, embed_text, vector_store |
|
|
import torch |
|
|
from transformers import pipeline |
|
|
import gradio as gr |
|
|
from huggingface_hub import InferenceClient |
|
|
import os |
|
|
|
|
|
# Read the source documents from the local "docs" directory and split each
# one's "content" into chunks, preserving document order.
docs = load_document("docs")

chunks = []
for doc in docs:
    chunks += chunk_text(doc["content"])

# Embed every chunk in one call.
embeddings = embed_text(chunks)

# Positional IDs and placeholder metadata, aligned index-for-index with
# `chunks` (the real document source is not tracked here).
chunk_positions = range(len(chunks))
ids = [f"chunk_{pos}" for pos in chunk_positions]
metadatas = [{"source": "unknown", "chunk_index": pos} for pos in chunk_positions]
|
|
|
|
|
# Index the chunks, their embeddings and metadata in a vector store.
# `collection=None` is passed explicitly; presumably vector_store creates a
# fresh collection in that case (TODO confirm in retriever.vector_store).
chroma_collection = vector_store(
    collection=None,
    ids=ids,
    documents=chunks,
    metadatas=metadatas,
    embeddings=embeddings,
)

# Retrieval tool through which the agent queries the indexed chunks.
doc_tool = DocumentRetrievalTool(collection=chroma_collection)
|
|
|
|
|
# Abstractive summarizer backing DocumentSummarizationTool.
# The pipeline is pinned to the CPU (device=-1). Half precision (float16) is
# not reliably supported for CPU inference in PyTorch -- many CPU kernels have
# no Half implementation -- so the weights are loaded in float32 instead.
summarization_pipeline = pipeline(
    task="summarization",
    model="google/pegasus-xsum",
    dtype=torch.float32,
    device=-1,
)

# Tool wrapper exposing the summarizer to the agent.
summarize_tool = DocumentSummarizationTool(
    summarization_pipeline=summarization_pipeline,
)
|
|
|
|
|
# Remote chat model served through the Hugging Face Inference API.
# The repo id must be given as `model_id`: InferenceClientModel sends its own
# `model_id` with every chat-completion request. A model set only on a
# pre-built InferenceClient would be overridden by the class default.
model = InferenceClientModel(model_id="naveensharma16/document-based-assistant")

print("Loaded model:", model.model_id)
|
|
|
|
|
# Code-writing agent over the remote model, equipped (in this order) with
# document retrieval, summarization, DuckDuckGo web search and code execution;
# intermediate step outputs are not streamed.
agent = CodeAgent(
    model=model,
    tools=[
        doc_tool,
        summarize_tool,
        DuckDuckGoSearchTool(),
        CodeExecutionTool(),
    ],
    stream_outputs=False,
)
|
|
|
|
|
def predict(query: str) -> str:
    """GAIA entrypoint: answer *query* with the module-level agent.

    Args:
        query: Free-text question from the UI / evaluation harness.

    Returns:
        The agent's final answer as a string. Returns "" for an empty or
        whitespace-only query (no agent run is started) or when the agent
        returns ``None``.
    """
    # Skip a costly agent run when the textbox was submitted empty.
    if not query or not query.strip():
        return ""
    result = agent.run(query)
    # CodeAgent's final answer can be any Python object; the endpoint feeds a
    # text output, so normalize to str here.
    return "" if result is None else str(result)
|
|
|
|
|
# Minimal Gradio UI: a single text box in, the agent's answer as text out.
iface = gr.Interface(
    title="Document QA Agent",
    fn=predict,
    inputs="text",
    outputs="text",
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Listen on all network interfaces on port 7860 (Gradio's default port).
    iface.launch(server_port=7860, server_name="0.0.0.0")