Ashmi Banerjee commited on
Commit
b6dfdd7
Β·
1 Parent(s): 4e4eaaf
Files changed (1) hide show
  1. app.py +48 -108
app.py CHANGED
@@ -1,121 +1,61 @@
1
  from typing import Optional
2
- from sentence_transformers import SentenceTransformer
3
- import pymongo
4
- import os
5
- from huggingface_hub import InferenceClient
6
  import gradio as gr
 
 
 
7
 
8
 
9
- HF_token = os.environ["HF_TOKEN"]
 
10
 
11
 
12
- def get_embedding(text: str) -> list[float]:
13
- embedding_model = SentenceTransformer("thenlper/gte-large")
14
-
15
- if not text.strip():
16
- print("Attempted to get embedding for empty text.")
17
- return []
18
-
19
- embedding = embedding_model.encode(text)
20
-
21
- return embedding.tolist()
22
-
23
-
24
- def get_mongo_client(mongo_url):
25
- """Establish connection to the MongoDB."""
26
- if not mongo_url:
27
- print("MONGO_URI not set in environment variables")
28
- try:
29
- client = pymongo.MongoClient(mongo_url)
30
- print("Connection to MongoDB successful")
31
- return client
32
- except pymongo.errors.ConnectionFailure as e:
33
- print(f"Connection failed: {e}")
34
- return None
35
-
36
-
37
- def get_mongo_url():
38
- username = os.environ["MONGO_USERNAME"]
39
- password = os.environ["MONGO_PW"]
40
- mongo_url = f"mongodb+srv://{username}:{password}@cluster0.62unmco.mongodb.net/"
41
- return mongo_url
42
-
43
-
44
- def query_results(query, mongo_url):
45
- mongo_client = get_mongo_client(mongo_url)
46
- db = mongo_client["EU_Cities"]
47
-
48
- query_embedding = get_embedding(query)
49
- results = db.EU_cities_collection.aggregate([
50
- {
51
- "$vectorSearch": {
52
- "index": "vector_index",
53
- "path": "embedding",
54
- "queryVector": query_embedding,
55
- "numCandidates": 150,
56
- "limit": 5
57
- }
58
- }
59
- ])
60
- return results
61
-
62
-
63
- def get_search_result(query, mongo_url):
64
- get_knowledge = query_results(query, mongo_url)
65
- print(get_knowledge)
66
-
67
- search_result = ""
68
- for result in get_knowledge:
69
- search_result += f"City: {result.get('city', 'N/A')}, Abstract: {result.get('combined', 'N/A')}\n"
70
-
71
- return search_result
72
-
73
-
74
- def generate_text(query, model_name: Optional[str] = "google/gemma-2b-it"):
75
- if model_name is None:
76
- model_name = "google/gemma-2b-it"
77
-
78
- mongo_url = get_mongo_url()
79
- source_information = get_search_result(query, mongo_url)
80
- combined_information = (
81
- f"Query: {query}\nContinue to answer the query by using the Search Results:\n{source_information}."
82
- )
83
- client = InferenceClient(model_name, token=HF_token)
84
-
85
- stream = client.text_generation(prompt=combined_information, details=True, stream=True, max_new_tokens=2048,
86
- return_full_text=False)
87
- output = ""
88
-
89
- for response in stream:
90
- output += response.token.text
91
-
92
- if "<eos>" in output:
93
- output = output.split("<eos>")[0]
94
- return output
95
 
96
 
97
  examples = [["I'm planning a vacation to France. Can you suggest a one-week itinerary including must-visit places and "
98
- "local cuisines to try?", None],
99
- ["I want to explore off-the-beaten-path destinations in Europe, any suggestions?", None],
100
- ["Suggest some cities that can be visited from London and are very rich in history and culture.", None],
 
101
  ]
102
 
103
- demo = gr.Interface(
104
- fn=generate_text,
105
- inputs=["text",
106
- gr.Dropdown(
107
- ["google/gemma-2b-it"], label="Models", info="Will "
108
- "add "
109
- "more "
110
- "models "
111
- "later! "
112
- ),
113
- ],
114
- title="πŸ‡ͺπŸ‡Ί Euro TravelBot πŸ‡ͺπŸ‡Ί",
115
- description="Explore Europe with ease using our prototype app! We're testing the compatibility of RAG implementations with Google Gemma-2b-it models to generate travel recommendations. This early version (read quick and dirty implementation) aims to see if functionalities work smoothly. It relies on Wikipedia abstracts from 160 European cities to provide answers to your questions. Please be kind with it as it's a work in progress!",
116
- outputs=["text"],
117
- examples=examples,
118
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  if __name__ == "__main__":
121
- demo.launch()
 
1
  from typing import Optional
 
 
 
 
2
  import gradio as gr
3
+ from build_rag import get_context
4
+ from models.gemma import gemma_predict
5
+ from models.gemini import get_gemini_response
6
 
7
 
8
+ def clear():
9
+ return None, None, None
10
 
11
 
12
+ def generate_text(query_text, model_name: Optional[str] = "google/gemma-2b-it"):
13
+ combined_information = get_context(query_text)
14
+ if model_name is None or model_name == "google/gemma-2b-it":
15
+ return gemma_predict(combined_information, model_name)
16
+ if model_name == "gemini-1.0-pro":
17
+ return get_gemini_response(combined_information, model_name, None)
18
+ return "Sorry, something went wrong! Please try again."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
 
21
  examples = [["I'm planning a vacation to France. Can you suggest a one-week itinerary including must-visit places and "
22
+ "local cuisines to try?", "google/gemma-2b-it"],
23
+ # ["I want to explore off-the-beaten-path destinations in Europe, any suggestions?", "gemini-1.0-pro"],
24
+ ["Suggest some cities that can be visited from London and are very rich in history and culture.",
25
+ "google/gemma-2b-it"],
26
  ]
27
 
28
+ with gr.Blocks() as demo:
29
+ gr.HTML("""<center><h1 style='font-size:xx-large;'>πŸ‡ͺπŸ‡Ί Euro City Recommender using Gemini & Gemma πŸ‡ͺπŸ‡Ί</h1><br><h3>Gemini
30
+ & Gemma Sprints 2024 submissions by Ashmi Banerjee. </h3></center> <br><p>We're testing the compatibility of
31
+ Retrieval Augmented Generation (RAG) implementations with Google's <b>Gemma-2b-it</b> & <b>Gemini 1.0 Pro</b>
32
+ models through HuggingFace and VertexAI respectively to generate travel recommendations. This early version (read
33
+ quick and dirty implementation) aims to see if functionalities work smoothly. It relies on Wikipedia abstracts
34
+ from 160 European cities to provide answers to your questions. Please be kind with it as it's a work in progress!
35
+ </p> <br>Google Cloud credits are provided for this project. """)
36
+
37
+ with gr.Group():
38
+ query = gr.Textbox(label="Query", placeholder="Ask for your city recommendation here!")
39
+ model = gr.Dropdown(
40
+ ["google/gemma-2b-it", "gemini-1.0-pro"], label="Model", info="Select your model. Will add more models "
41
+ "later!",
42
+ )
43
+ output = gr.Textbox(label="Generated Results", lines=4)
44
+
45
+ with gr.Group():
46
+ with gr.Row():
47
+ submit_btn = gr.Button("Submit", variant="primary")
48
+ clear_btn = gr.Button("Clear", variant="secondary")
49
+ cancel_btn = gr.Button("Cancel", variant="stop")
50
+ submit_btn.click(generate_text, inputs=[query, model], outputs=[output])
51
+ clear_btn.click(clear, inputs=[], outputs=[query, model, output])
52
+ cancel_btn.click(clear, inputs=[], outputs=[query, model, output])
53
+
54
+ gr.Markdown("## Examples")
55
+ gr.Examples(
56
+ examples, inputs=[query, model], label="Examples", fn=generate_text, outputs=[output],
57
+ cache_examples=True,
58
+ )
59
 
60
  if __name__ == "__main__":
61
+ demo.launch(show_api=False)