File size: 6,033 Bytes
f59fa01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f08be97
 
 
 
 
 
 
f59fa01
 
 
 
b5921ff
f59fa01
 
5b9af15
f59fa01
b5921ff
f59fa01
f08be97
f59fa01
b5921ff
f59fa01
 
5b9af15
f59fa01
 
b5921ff
f59fa01
f08be97
 
b5921ff
f59fa01
 
b5921ff
f59fa01
b5921ff
 
 
f59fa01
b5921ff
 
f59fa01
b5921ff
 
f59fa01
5b9af15
f59fa01
 
 
 
 
b5921ff
 
 
 
 
f59fa01
 
5b9af15
f59fa01
f08be97
f59fa01
f08be97
f59fa01
f08be97
f59fa01
 
f08be97
 
 
f59fa01
 
 
b5921ff
f59fa01
 
 
 
 
 
 
b5921ff
f08be97
f59fa01
 
 
f08be97
f59fa01
5b9af15
f08be97
5b9af15
 
 
 
 
 
 
 
 
 
 
 
f08be97
f59fa01
5b9af15
bdf53fa
f59fa01
5b9af15
 
 
 
f59fa01
46b5770
f59fa01
 
5b9af15
 
 
f59fa01
 
fcb31a0
 
 
 
 
 
f59fa01
 
fcb31a0
5b9af15
 
 
f08be97
f59fa01
fcb31a0
b5921ff
fcb31a0
f59fa01
 
b5921ff
 
 
f59fa01
fcb31a0
f59fa01
b5921ff
46b5770
b5921ff
 
f59fa01
 
5b9af15
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import html
import json
import time
from pathlib import Path

import gradio as gr
from google import genai
from google.genai import types

# Model id used by ask() when no model id is supplied; the UI also passes this
# value into ask() through a gr.State.
DEFAULT_MODEL = "gemini-2.5-flash"

def _require_client(client_obj):
    if client_obj is None:
        raise RuntimeError("Set your Gemini API key first.")
    return client_obj

def _progress_html(pct: float, text: str) -> str:
    pct = max(0, min(100, pct))
    return (
        "<div class=\"progress-card\">"
        f"<div class=\"progress-head\">{text}</div>"
        f"<div class=\"pbar\"><div class=\"pbar-fill\" style=\"width:{pct:.0f}%\"></div></div>"
        f"<div class=\"pbar-foot\">{pct:.0f}%</div>"
        "</div>"
    )

def ui_set_api_key(api_key: str):
    """Build a Gemini client from the key typed into the UI.

    Returns:
        A ``(client, status_markdown)`` pair. ``client`` is ``None`` when the
        key is blank (after stripping whitespace) or when constructing the
        client raised; the status text explains which case occurred.
    """
    cleaned = (api_key or "").strip()
    if cleaned:
        try:
            return genai.Client(api_key=cleaned), "✅ API key set — good to go."
        except Exception as err:
            return None, f"❌ Failed to set API key: {err}"
    return None, "❌ API key required."

def upload_and_index(client_state, file_obj, progress=gr.Progress(track_tqdm=True)):
    """Upload *file_obj* and index it into a newly created File Search store.

    This is a generator bound to a Gradio event; every ``yield`` pushes a
    ``(store_name, status_markdown, progress_html)`` triple to the UI.

    Args:
        client_state: ``genai.Client`` held in a ``gr.State``, or ``None`` when
            no API key has been set.
        file_obj: Gradio file object whose ``.name`` is the local path of the
            uploaded file, or ``None`` when nothing was chosen.
        progress: Gradio progress tracker (injected by Gradio).

    Yields:
        ``(store_name or None, status markdown, progress HTML)``. The store
        name is yielded only while/after the import runs successfully; any
        failure or timeout yields ``None`` so the chat step never targets a
        store whose indexing did not complete.
    """
    if client_state is None:
        yield None, "❌ Set API key first.", _progress_html(0, "Waiting for API key")
        return
    if file_obj is None:
        yield None, "⚠️ Please upload a file to index.", _progress_html(0, "Waiting")
        return

    client = _require_client(client_state)
    fname = Path(file_obj.name).name
    # Upper bound on polling so a stuck operation cannot hang the UI forever.
    max_wait_s = 600
    poll_interval_s = 0.5
    spinner = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]

    try:
        # Create a new store
        store = client.file_search_stores.create(config={"display_name": "uploaded-store"})
        store_name = store.name

        progress(0.05, desc=f"Uploading {fname}")
        yield None, f"Uploading **{fname}** …", _progress_html(5, f"Uploading {fname}")

        uploaded = client.files.upload(file=str(file_obj.name), config={"display_name": fname})

        op = client.file_search_stores.import_file(
            file_search_store_name=store_name,
            file_name=uploaded.name,
            config=types.ImportFileConfig(custom_metadata=[]),
        )

        deadline = time.monotonic() + max_wait_s
        tick = 0
        while not op.done:
            if time.monotonic() > deadline:
                yield (
                    None,
                    f"❌ Indexing **{fname}** timed out after {max_wait_s} s.",
                    _progress_html(0, "Timed out"),
                )
                return
            time.sleep(poll_interval_s)
            tick += 1
            # Synthetic progress: 5% after upload, +3% per poll, capped at 95%
            # until the operation reports done.
            pct = min(95, 5 + 3 * tick)
            progress(pct / 100, desc=f"Indexing {fname} {spinner[tick % len(spinner)]}")
            yield store_name, f"Indexing **{fname}** …", _progress_html(pct, f"Indexing {fname}")
            op = client.operations.get(op)
    except Exception as e:  # UI boundary: surface API failures instead of aborting the event
        yield None, f"❌ Indexing failed: {e}", _progress_html(0, "Failed")
        return

    # A finished operation may still carry an error payload.
    error = getattr(op, "error", None)
    if error:
        yield None, f"❌ Indexing failed: {error}", _progress_html(0, "Failed")
        return

    yield store_name, f"✅ File indexed into store `{store_name}`", _progress_html(100, "Completed!")

def ask(client_state, store_name: str, history, question: str, model_id: str):
    """Answer *question* with File Search over *store_name* and log it to the chat.

    Args:
        client_state: ``genai.Client`` from a ``gr.State``, or ``None``.
        store_name: File Search store to ground the answer in; falsy until a
            file has been indexed.
        history: Existing chat messages as ``{"role", "content"}`` dicts
            (``None`` is treated as empty). The input list is not mutated.
        question: The user's question; surrounding whitespace is ignored.
        model_id: Gemini model id; falls back to ``DEFAULT_MODEL`` when empty.

    Returns:
        ``(messages, messages)`` — the updated message list, once for the
        ``Chatbot`` and once for the ``chat_state`` ``gr.State``. Validation
        and API errors are appended as an assistant message so both outputs
        always receive a list.
    """
    history = list(history or [])

    def _notice(msg):
        # Report a problem inside the chat rather than replacing the state.
        history.append({"role": "assistant", "content": msg})
        return history, history

    if client_state is None:
        return _notice("❌ Set API key first.")
    if not store_name:
        return _notice("⚠️ Upload & index a file first.")
    q = (question or "").strip()
    if not q:
        return _notice("⚠️ Please type a question.")

    client = _require_client(client_state)

    tool = types.Tool(
        file_search=types.FileSearch(
            file_search_store_names=[store_name]
        )
    )
    history.append({"role": "user", "content": q})
    try:
        resp = client.models.generate_content(
            model=model_id or DEFAULT_MODEL,
            contents=q,
            config=types.GenerateContentConfig(tools=[tool]),
        )
    except Exception as e:  # UI boundary: keep the chat usable after API errors
        return _notice(f"❌ Request failed: {e}")

    answer = resp.text or "No answer."
    history.append({"role": "assistant", "content": answer})
    return history, history

# Custom CSS for nicer look. The .progress-* / .pbar* rules style the markup
# produced by _progress_html.
custom_css = """
body {background-color: #f5f7fa;}
.gradio-container {max-width: 800px; margin: auto; padding-top: 20px;}
.header {text-align: center; margin-bottom: 30px;}
h1 {color: #4F46E5; margin-bottom: 5px;}
h2 {color: #334155;}
.progress-card { background: #fff; border: 1px solid #e5e7eb; border-radius: 8px; padding: 6px 10px; margin: 4px 0; }
.progress-head {font-size: 14px; color: #334155;}
.pbar {height: 8px; background: #e5e7eb; border-radius: 4px; overflow: hidden; margin: 6px 0;}
.pbar-fill {height: 100%; background: #4F46E5; transition: width .3s ease;}
.pbar-foot {font-size: 12px; color: #64748b; text-align: right;}
.gr-box, .gr-panel { border-radius: 10px !important; }
.gr-button.primary { background-color: #4F46E5 !important; color: white !important; }
.gr-button { border-radius: 6px !important; }
.gr-file, .gr-textbox { border-radius: 6px !important; }
"""

with gr.Blocks() as demo:
    # Inject the custom stylesheet into the page.
    gr.HTML(f"<style>{custom_css}</style>")

    # Header
    gr.Markdown("# 📄 Gemini File-Chat Demo", elem_classes="header")
    gr.Markdown("Upload a document and ask questions — get grounded answers from the file.", elem_classes="header")

    client_state = gr.State(value=None)  # genai.Client once the key is set
    store_state = gr.State(value="")     # name of the indexed File Search store
    chat_state = gr.State(value=[])      # chat history as role/content dicts

    with gr.Accordion("🔑 API Key (required)", open=True):
        api_tb = gr.Textbox(label="Gemini API key", placeholder="Paste your API key here…", type="password")
        api_btn = gr.Button("Set API Key", variant="primary", elem_classes=["primary"])
        api_status = gr.Markdown()

    # Upload section
    gr.Markdown("### 1) Upload & Index your file")
    with gr.Row():
        file_uploader = gr.File(label="Choose file to upload", file_types=['.txt', '.pdf', '.docx'])
        upload_btn = gr.Button("Upload & Index", variant="primary", elem_classes=["primary"])
    upload_status = gr.Markdown()
    # Receives the progress-bar HTML yielded as the third value by
    # upload_and_index (previously routed into the button's label).
    upload_progress = gr.HTML()

    gr.Markdown("---")
    # Chat section
    gr.Markdown("### 2) Ask questions about your file")
    question_tb = gr.Textbox(label="Your question", placeholder="Type a question…")
    ask_btn = gr.Button("Ask", variant="primary", elem_classes=["primary"])
    chatbot = gr.Chatbot()

    # Define interactions
    api_btn.click(ui_set_api_key, [api_tb], [client_state, api_status])

    upload_btn.click(
        upload_and_index,
        [client_state, file_uploader],
        [store_state, upload_status, upload_progress],
        show_progress="full",
    )

    ask_btn.click(
        ask,
        [client_state, store_state, chat_state, question_tb, gr.State(DEFAULT_MODEL)],
        [chatbot, chat_state],
    )

if __name__ == "__main__":
    # Start the Gradio app; the theme is passed at launch time here rather than
    # to gr.Blocks().
    demo.launch(theme=gr.themes.Soft())