File size: 5,174 Bytes
20cfdf2
09d7751
 
e99a67e
09d7751
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20cfdf2
 
 
09d7751
20cfdf2
 
 
09d7751
 
 
20cfdf2
09d7751
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
624cf2a
09d7751
 
 
 
 
 
 
 
 
 
 
20cfdf2
 
 
 
 
 
09d7751
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20cfdf2
09d7751
20cfdf2
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import gradio as gr
from transformers import AutoTokenizer
import os
import spaces
import torch
from llama_index.llms.huggingface import HuggingFaceLLM

# Optional quantization to 4bit
from transformers import BitsAndBytesConfig
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import Settings

import faiss
from llama_index.core import (
    load_index_from_storage,
    StorageContext,
)
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.core.tools import QueryEngineTool, ToolMetadata

import json
from typing import Sequence, List

from llama_index.core.llms import ChatMessage
from llama_index.core.tools import BaseTool, FunctionTool
from llama_index.core.agent import ReActAgent

import nest_asyncio

from llama_index.core.tools import QueryEngineTool, ToolMetadata

# Hugging Face access token read from the environment; None when the variable
# is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# HTML banner describing the demo. The heading names the model that this
# script actually loads (google/gemma-1.1-2b-it) rather than a different model.
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">Gemma 1.1 2B Instruct</h1>
<p>This Space demonstrates the Agent based RAG on multiple documents using Gemma 2b it and llama index</p>
</div>
'''

# Tokenizer for the same checkpoint that is loaded as the LLM further down.
tokenizer = AutoTokenizer.from_pretrained(
    "google/gemma-1.1-2b-it",
    token=HF_TOKEN,
)

# Token ids at which generation should stop. Gemma's chat format closes each
# turn with "<end_of_turn>"; the Llama-3 marker "<|eot_id|>" is not in Gemma's
# vocabulary, so it would not map to a real end-of-turn id.
stopping_ids = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<end_of_turn>"),
]


# Optional 4-bit (NF4, double-quantized) bitsandbytes settings. They only take
# effect if the "quantization_config" entry in model_kwargs below is enabled.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# LlamaIndex wrapper around the Gemma checkpoint, sampling with
# temperature 0.6 / top-p 0.9 and halting on the ids in `stopping_ids`.
llm = HuggingFaceLLM(
    model_name="google/gemma-1.1-2b-it",
    tokenizer_name="google/gemma-1.1-2b-it",
    tokenizer_kwargs=dict(token=HF_TOKEN),
    model_kwargs=dict(
        token=HF_TOKEN,
        # To run in 4-bit instead, drop torch_dtype and pass
        # quantization_config=quantization_config here.
        torch_dtype=torch.bfloat16,
    ),
    generate_kwargs=dict(
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    ),
    stopping_ids=stopping_ids,
)



# Allow nested asyncio event loops before any llama_index machinery runs.
nest_asyncio.apply()

# bge embedding model, registered as the global default for llama_index.
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")
Settings.embed_model = embed_model

# Global default LLM. Original note: Llama-3-8B-Instruct was intended for GPU,
# Gemma 1.1 2B it instruct for CPU; `llm` here is the Gemma model.
Settings.llm = llm

# Embedding width of bge-large-en-v1.5, taken from
# https://huggingface.co/BAAI/bge-large-en-v1.5
d = 1024
# Empty flat L2 index of that width.
faiss_index = faiss.IndexFlatL2(d)

def _load_persisted_index(persist_dir):
    """Rebuild a FAISS-backed index from a persisted directory.

    Args:
        persist_dir: Directory holding the persisted vector store and the
            accompanying storage-context files.

    Returns:
        A ``(vector_store, storage_context, index)`` tuple, all built from
        ``persist_dir`` so the three objects always belong together.
    """
    vector_store = FaissVectorStore.from_persist_dir(persist_dir)
    storage_context = StorageContext.from_defaults(
        vector_store=vector_store, persist_dir=persist_dir)
    index = load_index_from_storage(storage_context=storage_context)
    return vector_store, storage_context, index


# Rebuild each storage context from disk. Both textbooks go through the same
# helper so the biology index can no longer be loaded from the geography
# storage context by mistake (the previous copy-paste bug).
geoVectorStore, geoStorageContext, geoindex = _load_persisted_index("./geoindex/")
bioVectorStore, bioStorageContext, bioindex = _load_persisted_index("./bioindex/")

# One query engine per textbook, each retrieving the 3 most similar chunks.
geo_engine = geoindex.as_query_engine(similarity_top_k=3)
bio_engine = bioindex.as_query_engine(similarity_top_k=3)

# Wrap each textbook query engine as a named tool; the name and description
# are the metadata attached to the tool handed to the agent.
query_engine_tools = [
    QueryEngineTool(
        query_engine=engine,
        metadata=ToolMetadata(name=tool_name, description=tool_description),
    )
    for engine, tool_name, tool_description in (
        (
            geo_engine,
            "geography",
            "This is a geography textbook, it provides information about geography. "
            "Use a detailed plain text question as input to the tool.",
        ),
        (
            bio_engine,
            "biology",
            "This is a biology textbook it provides information about biology. "
            "Use a detailed plain text question as input to the tool.",
        ),
    )
]

# ReAct agent over the two textbook tools, driven by the module-level LLM.
agent = ReActAgent.from_tools(query_engine_tools, llm=llm, verbose=False)

@spaces.GPU(duration=120)
def respond(
    message: str,
    history=None,
) -> str:
    """Answer a chat message by delegating to the module-level ReAct agent.

    Args:
        message: The user's question as typed into the chat box.
        history: Prior conversation turns supplied by ``gr.ChatInterface``.
            Accepted so the callback matches the ``fn(message, history)``
            calling convention; it is not used here.

    Returns:
        The agent's response rendered as a string, which the chat UI displays.
    """
    prompt = f'''Analyze the question: {message} and use appropriate tool to get the relevant context and answer the question, do not answer on your own and output only Observation'''
    response = agent.chat(prompt)
    answer = str(response)
    # Keep echoing the answer to stdout as before, but return the text itself:
    # `return print(...)` always yielded None, leaving the chat reply empty.
    print(answer)
    return answer

# For information on how to customize the ChatInterface, peruse the gradio
# docs: https://www.gradio.app/docs/chatinterface
#
# No additional inputs (system message, max new tokens, temperature, top-p)
# are configured; the interface only forwards the chat message to `respond`.
demo = gr.ChatInterface(
    fn=respond,
    # Each example is a single-element row holding the question text.
    examples=[
        [question]
        for question in (
            "What are different types of rural settlement?",
            "Explain Urbanisation in India?",
            "What was the level of urbanisation in India in 2011?",
            "List the religious and cultural towns in India?",
        )
    ],
    cache_examples=False,
)


# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()