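"""Document QA agent.

Indexes the files under ./docs into a vector store (built with the local
retriever helpers), exposes retrieval, summarization, web search, and code
execution as smolagents tools, and serves the resulting CodeAgent through a
Gradio text interface.
"""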
from smolagents import CodeAgent, InferenceClientModel, DuckDuckGoSearchTool
from tools import DocumentRetrievalTool, DocumentSummarizationTool, CodeExecutionTool
from retriever import load_document, chunk_text, embed_text, vector_store
import torch
from transformers import pipeline
import gradio as gr
import os

# Build the retrieval index: load every document under ./docs, split it into
# chunks, and embed each chunk (helpers come from the local retriever module).
docs = load_document("docs")
chunks = []
for doc in docs:
    chunks.extend(chunk_text(doc["content"]))
embeddings = embed_text(chunks)

# One id and one metadata record per chunk. The source is left as "unknown"
# because the loop above does not track which document each chunk came from.
ids = [f"chunk_{i}" for i in range(len(chunks))]
metadatas = [{"source": "unknown", "chunk_index": i} for i in range(len(chunks))]
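
# A possible refinement (a sketch only -- it assumes each document dict also
# carries a "path" key, which is not verified against retriever.py): record the
# real origin of every chunk instead of "unknown", with chunk_index counted per
# document.
#
#     chunks, metadatas = [], []
#     for doc in docs:
#         doc_chunks = chunk_text(doc["content"])
#         chunks.extend(doc_chunks)
#         metadatas.extend(
#             {"source": doc.get("path", "unknown"), "chunk_index": i}
#             for i in range(len(doc_chunks))
#         )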

# Store the chunks in the vector store; collection=None is assumed to let
# retriever.vector_store create (or reuse) its default Chroma collection.
chroma_collection = vector_store(
    collection=None,
    ids=ids,
    documents=chunks,
    metadatas=metadatas,
    embeddings=embeddings
)

# Retrieval tool backed by the collection built above.
doc_tool = DocumentRetrievalTool(
    collection=chroma_collection
)

# Pegasus summarization pipeline. fp16 only makes sense on a GPU; half-precision
# ops are generally unsupported on CPU, so fall back to fp32 there.
use_cuda = torch.cuda.is_available()
summarization_pipeline = pipeline(
    task="summarization",
    model="google/pegasus-xsum",
    torch_dtype=torch.float16 if use_cuda else torch.float32,
    device=0 if use_cuda else -1
)

summarize_tool = DocumentSummarizationTool(
    summarization_pipeline=summarization_pipeline
)

model = InferenceClientModel(model_id="naveensharma16/document-based-assistant")  # alternative: google/flan-t5-large
print("Loaded model:", model)

agent = CodeAgent(
    tools=[doc_tool, summarize_tool, DuckDuckGoSearchTool(), CodeExecutionTool()],
    model=model,
    stream_outputs=False  # set to True to stream intermediate agent output
)

def predict(query: str):
    """Run the agent on a single user query (GAIA-style entrypoint)."""
    return agent.run(query)
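
# Quick local smoke test, bypassing Gradio (the query below is illustrative only):
#     print(predict("What are the main topics covered in the indexed documents?"))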

iface = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    title="Document QA Agent"
)
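
# Agent runs can take a while; enabling Gradio's request queue before launching
# usually helps under concurrent use:
#     iface.queue()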

if __name__ == "__main__":
    # if os.getenv("RUN_GAIA", "true") == "true":
    #     run_gaia_test("naveensharma16")
    iface.launch(server_name="0.0.0.0", server_port=7860)