|
|
import html
import json
import time
from pathlib import Path

import gradio as gr
from google import genai
from google.genai import types
|
|
|
|
|
# Gemini model id used for question answering: ``ask`` falls back to it when
# no model id is given, and the UI passes it as the model input to ``ask``.
DEFAULT_MODEL = "gemini-2.5-flash"
|
|
|
|
|
def _require_client(client_obj): |
|
|
if client_obj is None: |
|
|
raise RuntimeError("Set your Gemini API key first.") |
|
|
return client_obj |
|
|
|
|
|
def _progress_html(pct: float, text: str) -> str: |
|
|
pct = max(0, min(100, pct)) |
|
|
return ( |
|
|
"<div class=\"progress-card\">" |
|
|
f"<div class=\"progress-head\">{text}</div>" |
|
|
f"<div class=\"pbar\"><div class=\"pbar-fill\" style=\"width:{pct:.0f}%\"></div></div>" |
|
|
f"<div class=\"pbar-foot\">{pct:.0f}%</div>" |
|
|
"</div>" |
|
|
) |
|
|
|
|
|
def ui_set_api_key(api_key: str):
    """Create a Gemini client from the key entered in the UI.

    Returns:
        ``(client, status_markdown)`` on success, or ``(None, status_markdown)``
        when the key is blank or ``genai.Client`` construction raises.
    """
    cleaned = api_key.strip() if api_key else ""
    if cleaned == "":
        return None, "❌ API key required."

    try:
        new_client = genai.Client(api_key=cleaned)
    except Exception as exc:
        return None, f"❌ Failed to set API key: {exc}"
    return new_client, "✅ API key set — good to go."
|
|
|
|
|
def upload_and_index(client_state, file_obj, progress=gr.Progress(track_tqdm=True)):
    """Upload a file, import it into a new File Search store, and stream progress.

    Generator used as a Gradio event handler. Every yielded value is a 3-tuple
    ``(store_name, status_markdown, progress_html)``:

    * ``store_name`` stays ``None`` until the import operation has finished
      without error, so the "Ask" handler cannot query a half-built store.
    * ``status_markdown`` is a short user-facing status line.
    * ``progress_html`` is a progress card from ``_progress_html``.

    Failures (API exceptions, an operation that finishes with ``error``, or a
    wait longer than ``max_wait_s``) are reported as a final ❌ status instead
    of raising.

    Args:
        client_state: ``genai.Client`` created by ``ui_set_api_key``, or ``None``.
        file_obj: value of the ``gr.File`` input — an object with a ``.name``
            path attribute or a plain path string; falsy when nothing is uploaded.
        progress: Gradio progress tracker (injected by Gradio).
    """
    if client_state is None:
        yield None, "❌ Set API key first.", _progress_html(0, "Waiting for API key")
        return
    if not file_obj:
        yield None, "⚠️ Please upload a file to index.", _progress_html(0, "Waiting")
        return

    client = _require_client(client_state)
    # Accept both a tempfile-like object exposing ``.name`` and a bare path.
    file_path = str(getattr(file_obj, "name", file_obj))
    fname = Path(file_path).name

    max_wait_s = 600.0   # upper bound on polling the import operation
    poll_interval_s = 0.5
    spinner = "⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"

    try:
        # NOTE(review): a new store is created per upload and nothing in this
        # file deletes old ones — confirm whether cleanup is needed.
        store = client.file_search_stores.create(config={"display_name": "uploaded-store"})
        store_name = store.name

        progress(0.05, desc=f"Uploading {fname}")
        yield None, f"Uploading **{fname}** …", _progress_html(5, f"Uploading {fname}")

        uploaded = client.files.upload(file=file_path, config={"display_name": fname})
        op = client.file_search_stores.import_file(
            file_search_store_name=store_name,
            file_name=uploaded.name,
            config=types.ImportFileConfig(custom_metadata=[]),
        )

        deadline = time.monotonic() + max_wait_s
        tick = 0
        while not op.done:
            if time.monotonic() > deadline:
                yield (
                    None,
                    f"❌ Indexing **{fname}** timed out after {max_wait_s:.0f}s.",
                    _progress_html(0, "Timed out"),
                )
                return
            time.sleep(poll_interval_s)
            tick += 1
            # Fake, monotonically increasing progress: 5% + 3% per poll, capped
            # at 95% until the operation actually reports completion.
            pct = min(95, 5 + 3 * tick)
            progress(pct / 100, desc=f"Indexing {fname} {spinner[tick % len(spinner)]}")
            yield None, f"Indexing **{fname}** …", _progress_html(pct, f"Indexing {fname}")
            op = client.operations.get(op)
    except Exception as exc:  # UI boundary: surface any API failure as a status
        yield None, f"❌ Indexing failed: {exc}", _progress_html(0, "Failed")
        return

    # A long-running operation can finish ("done") with an error payload.
    op_error = getattr(op, "error", None)
    if op_error:
        yield None, f"❌ Indexing failed: {op_error}", _progress_html(0, "Failed")
        return

    yield store_name, f"✅ File indexed into store `{store_name}`", _progress_html(100, "Completed!")
|
|
|
|
|
def ask(client_state, store_name: str, history, question: str, model_id: str):
    """Answer *question* with Gemini, grounding it via File Search on *store_name*.

    The UI wires this to outputs ``[chatbot, chat_state]``, so BOTH returned
    values must be message lists of ``{"role": ..., "content": ...}`` dicts.

    Args:
        client_state: ``genai.Client`` or ``None`` if no key was set.
        store_name: File Search store to query; falsy if nothing is indexed.
        history: prior chat messages (may be ``None``); never mutated.
        question: user's question; surrounding whitespace is ignored.
        model_id: Gemini model id; falls back to ``DEFAULT_MODEL`` when falsy.

    Returns:
        ``(shown, kept)``: ``shown`` is what the chatbot displays and ``kept``
        is the new conversation state. On success both are the history plus
        the user/assistant turn. For validation or API failures the notice is
        appended only to ``shown`` — visible to the user but not persisted.
    """
    # Copy so the caller's state list is never mutated in place.
    history = list(history or [])

    def _notice(message):
        # Show *message* as an assistant bubble without storing it in state.
        return history + [{"role": "assistant", "content": message}], history

    if client_state is None:
        return _notice("❌ Set API key first.")
    if not store_name:
        return _notice("⚠️ Upload & index a file first.")
    q = (question or "").strip()
    if not q:
        return _notice("⚠️ Please type a question.")

    client = _require_client(client_state)
    tool = types.Tool(
        file_search=types.FileSearch(file_search_store_names=[store_name])
    )
    try:
        resp = client.models.generate_content(
            model=model_id or DEFAULT_MODEL,
            contents=q,
            config=types.GenerateContentConfig(tools=[tool]),
        )
    except Exception as exc:  # UI boundary: report instead of crashing the event
        return _notice(f"❌ Request failed: {exc}")

    answer = resp.text or "No answer."
    history.extend([
        {"role": "user", "content": q},
        {"role": "assistant", "content": answer},
    ])
    return history, history
|
|
|
|
|
|
|
|
# Page-level CSS, passed to ``demo.launch(css=...)`` below. The
# ``.progress-card`` / ``.pbar*`` rules style the HTML built by ``_progress_html``.
custom_css = """
body {background-color: #f5f7fa;}
.gradio-container {max-width: 800px; margin: auto; padding-top: 20px;}
.header {text-align: center; margin-bottom: 30px;}
h1 {color: #4F46E5; margin-bottom: 5px;}
h2 {color: #334155;}
.progress-card { background: #fff; border: 1px solid #e5e7eb; border-radius: 8px; padding: 6px 10px; margin: 4px 0; }
.pbar {height: 8px; background: #e5e7eb; border-radius: 4px; overflow: hidden; margin: 6px 0;}
.pbar-fill {height: 100%; background: #4F46E5; transition: width .3s ease;}
.gr-box, .gr-panel { border-radius: 10px !important; }
.gr-button.primary { background-color: #4F46E5 !important; color: white !important; }
.gr-button { border-radius: 6px !important; }
.gr-file, .gr-textbox { border-radius: 6px !important; }
"""

with gr.Blocks() as demo:
    gr.Markdown("# 📄 Gemini File-Chat Demo", elem_classes="header")
    gr.Markdown(
        "Upload a document and ask questions — get grounded answers from the file.",
        elem_classes="header",
    )

    # Per-session state: the genai client, the File Search store of the last
    # indexed file, the chat history (role/content dicts), and the model id.
    client_state = gr.State(value=None)
    store_state = gr.State(value="")
    chat_state = gr.State(value=[])
    model_state = gr.State(value=DEFAULT_MODEL)

    with gr.Accordion("🔑 API Key (required)", open=True):
        api_tb = gr.Textbox(
            label="Gemini API key",
            placeholder="Paste your API key here…",
            type="password",
        )
        api_btn = gr.Button("Set API Key", variant="primary")
        api_status = gr.Markdown()

    gr.Markdown("### 1) Upload & Index your file")
    with gr.Row():
        file_uploader = gr.File(
            label="Choose file to upload",
            file_types=[".txt", ".pdf", ".docx"],
        )
        upload_btn = gr.Button("Upload & Index", variant="primary")
    upload_status = gr.Markdown()
    # Receives the progress-card HTML that upload_and_index yields.
    upload_progress = gr.HTML()

    gr.Markdown("---")

    gr.Markdown("### 2) Ask questions about your file")
    question_tb = gr.Textbox(label="Your question", placeholder="Type a question…")
    ask_btn = gr.Button("Ask", variant="primary")
    chatbot = gr.Chatbot()

    api_btn.click(ui_set_api_key, [api_tb], [client_state, api_status])

    # upload_and_index yields (store_name, status_markdown, progress_html);
    # each element maps positionally onto one of these outputs.
    upload_btn.click(
        upload_and_index,
        [client_state, file_uploader],
        [store_state, upload_status, upload_progress],
        show_progress="full",
    )

    ask_inputs = [client_state, store_state, chat_state, question_tb, model_state]
    ask_outputs = [chatbot, chat_state]
    ask_btn.click(ask, ask_inputs, ask_outputs)
    # Pressing Enter in the question box behaves like clicking "Ask".
    question_tb.submit(ask, ask_inputs, ask_outputs)


if __name__ == "__main__":
    demo.launch(theme=gr.themes.Soft(), css=custom_css)
|
|
|