File size: 1,990 Bytes
178f14f 15ce7df 178f14f fbc8ce9 15ce7df 178f14f 15ce7df 178f14f 15ce7df 178f14f 15ce7df 178f14f 15ce7df 178f14f 15ce7df fbc8ce9 15ce7df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
from smolagents import CodeAgent, InferenceClientModel, DuckDuckGoSearchTool
from tools import DocumentRetrievalTool, DocumentSummarizationTool, CodeExecutionTool
from retriever import load_document, chunk_text, embed_text, vector_store
import torch
from transformers import pipeline
import gradio as gr
from huggingface_hub import InferenceClient
import os
# Build the retrieval index once, at import time: load every document under
# ./docs, split each into chunks, embed all chunks and store them (with ids and
# per-chunk metadata) via the project's vector_store helper.
docs = load_document("docs")

chunks = []
ids = []
metadatas = []
for doc in docs:
    # NOTE(review): doc is indexed as a mapping with a "content" key; a "source"
    # key is used when load_document provides one, otherwise "unknown" (the
    # previous hard-coded value) — confirm the key name against load_document.
    source = doc.get("source", "unknown")
    for chunk in chunk_text(doc["content"]):
        # Position in the global chunk list; shared by the id and the metadata so
        # a retrieved chunk can be traced back to its slot.
        index = len(chunks)
        chunks.append(chunk)
        ids.append(f"chunk_{index}")
        metadatas.append({"source": source, "chunk_index": index})

embeddings = embed_text(chunks)

chroma_collection = vector_store(
    collection=None,
    ids=ids,
    documents=chunks,
    metadatas=metadatas,
    embeddings=embeddings,
)
# Agent tool that retrieves relevant chunks from the `chroma_collection` vector store.
doc_tool = DocumentRetrievalTool(collection=chroma_collection)
# Summarizer backing DocumentSummarizationTool. It is pinned to CPU (device=-1),
# so weights are loaded in float32: half precision on CPU is poorly supported
# (many kernels have no float16 implementation or run much slower than float32).
summarization_pipeline = pipeline(
    task="summarization",
    model="google/pegasus-xsum",
    dtype=torch.float32,
    device=-1,
)
summarize_tool = DocumentSummarizationTool(
    summarization_pipeline=summarization_pipeline,
)
# LLM behind the CodeAgent, served through the Hugging Face Inference API.
# Select the model with `model_id` (InferenceClientModel's documented parameter)
# rather than only via a pre-built InferenceClient, so `model.model_id` reflects
# the model actually queried instead of the library default.
# (The author also noted google/flan-t5-large as an alternative model.)
model = InferenceClientModel(model_id="naveensharma16/document-based-assistant")
print("Loaded model:", model)

agent = CodeAgent(
    tools=[doc_tool, summarize_tool, DuckDuckGoSearchTool(), CodeExecutionTool()],
    model=model,
    stream_outputs=False,  # set True to stream model output while the agent runs
)
def predict(query: str):
    """GAIA entrypoint: answer *query* with the document QA agent.

    Blank or whitespace-only queries return an empty string without invoking
    the agent, avoiding a pointless remote model call.

    Args:
        query: The user's question (the Gradio text box value).

    Returns:
        The agent's final answer from ``agent.run``, or ``""`` for a blank query.
    """
    if not query or not query.strip():
        return ""
    return agent.run(query)
# Web UI: a single text box in, the agent's answer (via predict) as text out.
iface = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    title="Document QA Agent",
)
if __name__ == "__main__":
    # NOTE: an optional GAIA test run (gated on the RUN_GAIA env var and calling
    # run_gaia_test("naveensharma16")) was sketched here, but run_gaia_test is not
    # defined or imported in this file; wire it in before re-enabling.

    # Bind to all interfaces (not just localhost) on port 7860.
    iface.launch(server_name="0.0.0.0", server_port=7860)