Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -48,11 +48,9 @@ class Utils:
|
|
| 48 |
if not hits or docstore.empty:
|
| 49 |
return "No relevant documents found."
|
| 50 |
lines = []
|
| 51 |
-
# Ensure we don't try to access indices that are out of bounds
|
| 52 |
valid_hits = [h for h in hits if h[0] < len(docstore)]
|
| 53 |
for i, score in valid_hits[:count]:
|
| 54 |
row = docstore.iloc[i]
|
| 55 |
-
# Ensure 'passage_text' and 'id' columns exist
|
| 56 |
txt = str(row.get("passage_text", "Text not available"))
|
| 57 |
doc_id = row.get("id", "N/A")
|
| 58 |
txt = (txt[:max_chars] + "…") if len(txt) > max_chars else txt
|
|
@@ -142,10 +140,8 @@ class RAGSystem:
|
|
| 142 |
def __init__(self, cfg: Config):
|
| 143 |
self.docstore = pd.read_parquet(cfg.docstore_path)
|
| 144 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 145 |
-
|
| 146 |
self.glot_enc = Glot500Encoder(cfg.glot_model_hf)
|
| 147 |
self.glot_ret = Glot500Retriever(self.glot_enc, self.docstore, cfg.glot_index_out)
|
| 148 |
-
|
| 149 |
txt_enc = FaTextEncoder(cfg.mclip_text_model_hf, device, cfg.max_text_len)
|
| 150 |
self.mclip_ret = TextIndexRetriever(txt_enc, self.docstore, cfg.clip_index_out)
|
| 151 |
self.vision = FaVisionEncoder(cfg.clip_vision_model, device)
|
|
@@ -157,10 +153,9 @@ class RAGSystem:
|
|
| 157 |
# --- 1. LOAD MODELS AND INDEXES (This runs only once when the app starts) ---
|
| 158 |
print("Initializing configuration...")
|
| 159 |
cfg = Config()
|
| 160 |
-
print("Loading RAG system
|
| 161 |
rag_system = RAGSystem(cfg)
|
| 162 |
print("Initializing Gemini model...")
|
| 163 |
-
# Securely get the API key from Hugging Face secrets
|
| 164 |
api_key = os.environ.get("GEMINI_API_KEY")
|
| 165 |
vlm = VLM_GenAI(api_key, model_name="models/gemini-1.5-flash")
|
| 166 |
print("System ready.")
|
|
@@ -169,9 +164,7 @@ print("System ready.")
|
|
| 169 |
def run_rag_query(question_text: str, question_image: Optional[Image.Image]) -> Tuple[str, str]:
|
| 170 |
if not question_text.strip():
|
| 171 |
return "Please ask a question.", ""
|
| 172 |
-
|
| 173 |
context_block = ""
|
| 174 |
-
# Decide which retriever to use based on input
|
| 175 |
if question_image:
|
| 176 |
print("Performing multimodal retrieval...")
|
| 177 |
img_vec = rag_system.vision.encode(question_image)
|
|
@@ -181,16 +174,11 @@ def run_rag_query(question_text: str, question_image: Optional[Image.Image]) ->
|
|
| 181 |
print("Performing text retrieval...")
|
| 182 |
hits = rag_system.glot_ret.topk(question_text, k=cfg.per_option_ctx)
|
| 183 |
context_block = Utils.build_context_block(hits, rag_system.docstore, cfg.per_option_ctx)
|
| 184 |
-
|
| 185 |
-
# --- Augment and Generate ---
|
| 186 |
print("Generating response...")
|
|
|
|
| 187 |
if question_image:
|
| 188 |
prompt = f"با توجه به تصویر و اسناد زیر، به سوال پاسخ دهید.\n\nاسناد:\n{context_block}\n\nسوال: {question_text}"
|
| 189 |
-
else:
|
| 190 |
-
prompt = f"با توجه به اسناد زیر، به سوال پاسخ دهید.\n\nاسناد:\n{context_block}\n\nسوال: {question_text}"
|
| 191 |
-
|
| 192 |
content_parts = [question_image, prompt] if question_image else [prompt]
|
| 193 |
-
|
| 194 |
try:
|
| 195 |
resp = vlm.model.generate_content(
|
| 196 |
content_parts,
|
|
@@ -201,11 +189,51 @@ def run_rag_query(question_text: str, question_image: Optional[Image.Image]) ->
|
|
| 201 |
except Exception as e:
|
| 202 |
answer = f"Error during generation: {e}"
|
| 203 |
print(answer)
|
| 204 |
-
|
| 205 |
return answer, context_block
|
| 206 |
|
| 207 |
# --- 3. CREATE THE GRADIO INTERFACE ---
|
| 208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
gr.Markdown("# 🍲 Persian Culinary RAG Demo")
|
| 210 |
gr.Markdown("Ask a question about Iranian food, with or without an image, to see the RAG system in action.")
|
| 211 |
|
|
@@ -220,7 +248,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Persian Culinary RAG") as demo:
|
|
| 220 |
|
| 221 |
gr.Examples(
|
| 222 |
examples=[
|
| 223 |
-
["در مورد
|
| 224 |
["مواد لازم برای تهیه آش رشته چیست؟", None],
|
| 225 |
],
|
| 226 |
inputs=[text_input, image_input]
|
|
@@ -232,6 +260,5 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Persian Culinary RAG") as demo:
|
|
| 232 |
outputs=[output_answer, output_context]
|
| 233 |
)
|
| 234 |
|
| 235 |
-
# Launch the web server
|
| 236 |
demo.launch()
|
| 237 |
|
|
|
|
| 48 |
if not hits or docstore.empty:
|
| 49 |
return "No relevant documents found."
|
| 50 |
lines = []
|
|
|
|
| 51 |
valid_hits = [h for h in hits if h[0] < len(docstore)]
|
| 52 |
for i, score in valid_hits[:count]:
|
| 53 |
row = docstore.iloc[i]
|
|
|
|
| 54 |
txt = str(row.get("passage_text", "Text not available"))
|
| 55 |
doc_id = row.get("id", "N/A")
|
| 56 |
txt = (txt[:max_chars] + "…") if len(txt) > max_chars else txt
|
|
|
|
| 140 |
def __init__(self, cfg: Config):
|
| 141 |
self.docstore = pd.read_parquet(cfg.docstore_path)
|
| 142 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
| 143 |
self.glot_enc = Glot500Encoder(cfg.glot_model_hf)
|
| 144 |
self.glot_ret = Glot500Retriever(self.glot_enc, self.docstore, cfg.glot_index_out)
|
|
|
|
| 145 |
txt_enc = FaTextEncoder(cfg.mclip_text_model_hf, device, cfg.max_text_len)
|
| 146 |
self.mclip_ret = TextIndexRetriever(txt_enc, self.docstore, cfg.clip_index_out)
|
| 147 |
self.vision = FaVisionEncoder(cfg.clip_vision_model, device)
|
|
|
|
| 153 |
# --- 1. LOAD MODELS AND INDEXES (This runs only once when the app starts) ---
|
| 154 |
print("Initializing configuration...")
|
| 155 |
cfg = Config()
|
| 156 |
+
print("Loading RAG system...")
|
| 157 |
rag_system = RAGSystem(cfg)
|
| 158 |
print("Initializing Gemini model...")
|
|
|
|
| 159 |
api_key = os.environ.get("GEMINI_API_KEY")
|
| 160 |
vlm = VLM_GenAI(api_key, model_name="models/gemini-1.5-flash")
|
| 161 |
print("System ready.")
|
|
|
|
| 164 |
def run_rag_query(question_text: str, question_image: Optional[Image.Image]) -> Tuple[str, str]:
|
| 165 |
if not question_text.strip():
|
| 166 |
return "Please ask a question.", ""
|
|
|
|
| 167 |
context_block = ""
|
|
|
|
| 168 |
if question_image:
|
| 169 |
print("Performing multimodal retrieval...")
|
| 170 |
img_vec = rag_system.vision.encode(question_image)
|
|
|
|
| 174 |
print("Performing text retrieval...")
|
| 175 |
hits = rag_system.glot_ret.topk(question_text, k=cfg.per_option_ctx)
|
| 176 |
context_block = Utils.build_context_block(hits, rag_system.docstore, cfg.per_option_ctx)
|
|
|
|
|
|
|
| 177 |
print("Generating response...")
|
| 178 |
+
prompt = f"با توجه به اسناد زیر، به سوال پاسخ دهید.\n\nاسناد:\n{context_block}\n\nسوال: {question_text}"
|
| 179 |
if question_image:
|
| 180 |
prompt = f"با توجه به تصویر و اسناد زیر، به سوال پاسخ دهید.\n\nاسناد:\n{context_block}\n\nسوال: {question_text}"
|
|
|
|
|
|
|
|
|
|
| 181 |
content_parts = [question_image, prompt] if question_image else [prompt]
|
|
|
|
| 182 |
try:
|
| 183 |
resp = vlm.model.generate_content(
|
| 184 |
content_parts,
|
|
|
|
| 189 |
except Exception as e:
|
| 190 |
answer = f"Error during generation: {e}"
|
| 191 |
print(answer)
|
|
|
|
| 192 |
return answer, context_block
|
| 193 |
|
| 194 |
# --- 3. CREATE THE GRADIO INTERFACE ---
|
| 195 |
+
|
| 196 |
+
# Define your custom CSS for the background image
|
| 197 |
+
custom_css = """
|
| 198 |
+
body {
|
| 199 |
+
/* The URL to your background image in the HF Repo */
|
| 200 |
+
background-image: url('/file=background/back.jpg');
|
| 201 |
+
/* Make the image cover the whole background */
|
| 202 |
+
background-size: cover;
|
| 203 |
+
/* Don't repeat the image */
|
| 204 |
+
background-repeat: no-repeat;
|
| 205 |
+
/* Fix the background image so it doesn't scroll with content */
|
| 206 |
+
background-attachment: fixed;
|
| 207 |
+
/* Center the background image */
|
| 208 |
+
background-position: center;
|
| 209 |
+
color: white; /* Set default text color to white for readability */
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
/* Add a semi-transparent overlay to make text more readable */
|
| 213 |
+
body::before {
|
| 214 |
+
content: "";
|
| 215 |
+
position: absolute;
|
| 216 |
+
top: 0; left: 0; right: 0; bottom: 0;
|
| 217 |
+
background-color: rgba(0, 0, 0, 0.5); /* Black overlay with 50% opacity */
|
| 218 |
+
z-index: -1; /* Place it behind the content */
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
/* Style the main container to have a semi-transparent background */
|
| 222 |
+
.gradio-container {
|
| 223 |
+
background: rgba(0, 0, 0, 0.6) !important; /* Darker, semi-transparent background for the app area */
|
| 224 |
+
border-radius: 20px !important;
|
| 225 |
+
border: 1px solid rgba(255, 255, 255, 0.2);
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
/* Make textboxes semi-transparent */
|
| 229 |
+
textarea, input[type="text"] {
|
| 230 |
+
background-color: rgba(255, 255, 255, 0.1) !important;
|
| 231 |
+
color: white !important;
|
| 232 |
+
border: 1px solid rgba(255, 255, 255, 0.3) !important;
|
| 233 |
+
}
|
| 234 |
+
"""
|
| 235 |
+
|
| 236 |
+
with gr.Blocks(css=custom_css, title="Persian Culinary RAG") as demo:
|
| 237 |
gr.Markdown("# 🍲 Persian Culinary RAG Demo")
|
| 238 |
gr.Markdown("Ask a question about Iranian food, with or without an image, to see the RAG system in action.")
|
| 239 |
|
|
|
|
| 248 |
|
| 249 |
gr.Examples(
|
| 250 |
examples=[
|
| 251 |
+
["در مورد دیزی سنگی توضیح بده", None],
|
| 252 |
["مواد لازم برای تهیه آش رشته چیست؟", None],
|
| 253 |
],
|
| 254 |
inputs=[text_input, image_input]
|
|
|
|
| 260 |
outputs=[output_answer, output_context]
|
| 261 |
)
|
| 262 |
|
|
|
|
| 263 |
demo.launch()
|
| 264 |
|