|
|
|
|
|
import gradio as gr |
|
|
import requests |
|
|
from typing import List, Dict |
|
|
import time |
|
|
import os |
|
|
|
|
|
|
|
|
# Set Gradio's analytics switch to "False" for this process.
# NOTE(review): this runs after `import gradio`; presumably Gradio reads the
# variable lazily — confirm against the installed version.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

# Root URL of the backend HTTP API; used as DwaniClient's default base URL.
API_BASE = "https://sahana31-rag-backend.hf.space"
|
|
|
|
|
|
|
|
class DwaniClient:
    """Thin HTTP client for the document-chat backend.

    Every method issues a single request with ``requests``, raises
    ``requests.HTTPError`` for an error status (via ``raise_for_status``)
    and returns the decoded JSON body.
    """

    def __init__(self, base_url=API_BASE):
        """Remember the backend root URL.

        Args:
            base_url: Root URL of the backend; trailing slashes are removed so
                endpoint paths can be appended with a single ``/``.
        """
        self.base_url = base_url.rstrip('/')

    def upload_file(self, file) -> dict:
        """Upload a file (HF temp file compatible).

        Args:
            file: Either an object whose ``name`` attribute is the local path
                of the file, or the path itself.

        Returns:
            Decoded JSON body of ``POST /files/upload``.

        Raises:
            OSError: If the local file cannot be opened.
            requests.RequestException: On transport errors or an error status.
        """
        path = getattr(file, "name", file)
        # The context manager guarantees the handle is released even when the
        # request raises, so repeated uploads do not leak file descriptors.
        with open(path, "rb") as handle:
            files = {
                "file": (os.path.basename(path), handle, "application/pdf"),
            }
            resp = requests.post(
                f"{self.base_url}/files/upload", files=files, timeout=300
            )
        resp.raise_for_status()
        return resp.json()

    def get_file_status(self, file_id: str) -> dict:
        """Return the decoded JSON body of ``GET /files/{file_id}``."""
        resp = requests.get(f"{self.base_url}/files/{file_id}", timeout=60)
        resp.raise_for_status()
        return resp.json()

    def list_files(self) -> List[dict]:
        """Return the decoded JSON body of ``GET /files/``."""
        resp = requests.get(f"{self.base_url}/files/", timeout=60)
        resp.raise_for_status()
        return resp.json()

    def chat(self, file_ids: List[str], messages: List[Dict]) -> dict:
        """Post a conversation about the given files.

        Args:
            file_ids: Identifiers of the files to chat with.
            messages: Conversation so far, sent as the ``messages`` field.

        Returns:
            Decoded JSON body of ``POST /chat-with-document``.
        """
        payload = {"file_ids": file_ids, "messages": messages}
        resp = requests.post(
            f"{self.base_url}/chat-with-document", json=payload, timeout=300
        )
        resp.raise_for_status()
        return resp.json()
|
|
|
|
|
|
|
|
|
|
|
# Backend client shared by all handlers in this module.
client = DwaniClient()
# Cache of known files keyed by file_id; each value is a dict carrying at
# least 'filename', 'status' and 'file_id'.
uploaded_files = {}
# Conversation sent to the backend: {"role": ..., "content": ...} dicts.
chat_history: List[Dict] = []
# Values currently ticked in the file checkbox group (the file_ids).
selected_files = []
# NOTE(review): these are module-level objects, so they look shared by every
# browser session served by this process — confirm that is intended.
|
|
|
|
|
|
|
|
def poll_file_status(file_id: str, max_wait=60, interval=2):
    """Wait for file processing.

    Repeatedly asks the backend for the file's status until it reports
    ``completed`` or ``failed``, or the attempts run out.

    Args:
        file_id: Backend identifier of the file to watch.
        max_wait: Maximum number of polling attempts (a count, not seconds).
        interval: Seconds to sleep between two attempts.

    Returns:
        A ``(status, ok)`` tuple. ``status`` is the backend payload of the
        final poll and ``ok`` is True only when it reported ``completed``.
        When every attempt is used up the result is
        ``({'status': 'timeout'}, False)``.
    """
    for attempt in range(max_wait):
        try:
            status = client.get_file_status(file_id)
        except Exception:
            # Best effort: a failed request just means "not ready yet".
            # Catching Exception (not everything) lets Ctrl+C and SystemExit
            # interrupt a long wait.
            status = None

        state = status.get('status') if isinstance(status, dict) else None
        if state == 'completed':
            return status, True
        if state == 'failed':
            return status, False

        # No point sleeping once the last attempt has been made.
        if attempt < max_wait - 1:
            time.sleep(interval)

    return {'status': 'timeout'}, False
|
|
|
|
|
def create_file_list():
    """Display all files with status.

    Returns:
        ``"No files uploaded"`` when the module-level ``uploaded_files`` cache
        is empty; otherwise a Markdown string with a heading followed by one
        ``<icon> <filename> (<status>)`` line per cached file.
    """
    if not uploaded_files:
        return "No files uploaded"

    # Built once per call; statuses without an entry fall back to the
    # question-mark icon.
    status_icons = {
        'completed': '✅',
        'processing': '🔄',
        'pending': '⏳',
        'failed': '❌',
    }

    lines = ["**📁 Your Files:**"]
    lines.extend(
        f"{status_icons.get(info['status'], '❓')} {info['filename']} ({info['status']})"
        for info in uploaded_files.values()
    )
    return "\n".join(lines)
|
|
|
|
|
|
|
|
def _completed_file_choices():
    """Return (filename, file_id) checkbox choices for completed files."""
    return [
        (info['filename'], info['file_id'])
        for info in uploaded_files.values()
        if info['status'] == 'completed'
    ]


def upload_multiple(files):
    """Handle multiple PDF uploads.

    Uploads each file, records it in ``uploaded_files`` and waits for the
    backend to finish processing it before moving on to the next one.

    Args:
        files: Files handed over by the upload widget; may be empty or None.

    Returns:
        A 3-tuple: a newline-joined per-file status report, a ``gr.update``
        carrying the selectable (completed) files, and the Markdown file list.
    """
    if not files:
        # Nothing was picked: keep already-uploaded files selectable and
        # listed instead of blanking the UI.
        return (
            "No files selected",
            gr.update(choices=_completed_file_choices()),
            create_file_list(),
        )

    status_msgs = []
    for file in files:
        # Short label for error reports; works for path strings as well as
        # objects exposing ``.name``.
        label = os.path.basename(str(getattr(file, "name", file)))
        try:
            result = client.upload_file(file)
            file_id = result['file_id']
            filename = result['filename']

            uploaded_files[file_id] = {
                'filename': filename,
                'status': 'pending',
                'file_id': file_id,
            }

            status, success = poll_file_status(file_id)
            if success:
                uploaded_files[file_id]['status'] = 'completed'
                status_msgs.append(f"✅ {filename} - READY")
            else:
                # Keep the reported outcome so a polling timeout is not
                # mislabelled as a backend failure.
                outcome = status.get('status', 'failed')
                uploaded_files[file_id]['status'] = outcome
                status_msgs.append(f"❌ {filename} - {outcome.upper()}")
        except Exception as e:
            # UI boundary: report the problem for this file and carry on
            # with the remaining ones.
            status_msgs.append(f"❌ {label} - ERROR: {str(e)}")

    return (
        "\n".join(status_msgs),
        gr.update(choices=_completed_file_choices()),
        create_file_list(),
    )
|
|
|
|
|
def refresh_files():
    """Refresh file list from backend.

    Replaces the contents of the module-level ``uploaded_files`` cache with
    the records returned by ``client.list_files()``.

    Returns:
        A 2-tuple of the Markdown file list and a ``gr.update`` with the
        completed files as checkbox choices. On any error the first element
        is a ``"Refresh failed: ..."`` message and the update is a no-op.
    """
    try:
        files = client.list_files()

        # Validate and index the records before touching the cache, so a
        # malformed record cannot leave it half-cleared.
        fresh = {f['file_id']: f for f in files}
        choices = [
            (f['filename'], f['file_id'])
            for f in files
            if f['status'] == 'completed'
        ]

        uploaded_files.clear()
        uploaded_files.update(fresh)
        return create_file_list(), gr.update(choices=choices)
    except Exception as exc:
        # UI boundary: surface the reason rather than hiding it.
        return f"Refresh failed: {exc}", gr.update()
|
|
|
|
|
def update_selected_files(files):
    """Store the current checkbox selection and report its size.

    Args:
        files: Selected values from the checkbox group; a falsy value
            (None or an empty list) is stored as an empty list.

    Returns:
        Number of selected entries.
    """
    global selected_files
    if files:
        selected_files = files
    else:
        selected_files = []
    count = len(selected_files)
    return count
|
|
|
|
|
def format_chat_response(result):
    """Format response with sources.

    Args:
        result: Mapping with an ``'answer'`` string and, optionally, a
            ``'sources'`` list of mappings that may carry ``'filename'``,
            ``'page'`` and ``'excerpt'``.

    Returns:
        The answer alone when there are no sources; otherwise the answer
        followed by a Markdown "Sources" section listing at most the first
        five sources, each with an excerpt cut to 120 characters.

    Raises:
        KeyError: If ``result`` has no ``'answer'`` entry.
    """
    answer = result['answer']
    sources = result.get('sources')
    if not sources:
        return answer

    parts = ["\n\n**📚 Sources:**\n"]
    for i, src in enumerate(sources[:5], 1):
        excerpt = str(src.get('excerpt') or "")
        # Only mark the excerpt as shortened when it really was.
        snippet = excerpt[:120] + ("..." if len(excerpt) > 120 else "")
        # A source missing a field must not make the whole answer unrenderable.
        parts.append(
            f"{i}. **{src.get('filename', 'unknown')}** (Page {src.get('page', '?')})\n"
        )
        parts.append(f" > {snippet}\n\n")
    return answer + "".join(parts)
|
|
|
|
|
def send_message(message, history):
    """Send chat message.

    Args:
        message: Text typed by the user; None or blank text is ignored.
        history: Messages currently shown in the chatbot, as
            ``{"role": ..., "content": ...}`` dicts; None is treated as empty.

    Returns:
        A 2-tuple ``(history, textbox_value)``: the chatbot messages to
        display and the new content of the input box (empty on success, a
        notice or error text otherwise).
    """
    if not (message or "").strip():
        return history, ""

    if not selected_files:
        return history, "⚠️ Please select files first!"

    shown = list(history or [])
    user_message = {"role": "user", "content": message}

    try:
        result = client.chat(selected_files, chat_history + [user_message])
        full_response = format_chat_response(result)
    except Exception as e:
        # UI boundary: show the failure in the chat and in the input box.
        error_response = {"role": "assistant", "content": f"❌ Error: {str(e)}"}
        return shown + [user_message, error_response], f"Error: {str(e)}"

    # Record the exchange only after it fully succeeded, so the history sent
    # to the backend never contains a turn the user saw as an error.
    chat_history.append(user_message)
    chat_history.append({"role": "assistant", "content": result['answer']})

    assistant_message = {"role": "assistant", "content": full_response}
    return shown + [user_message, assistant_message], ""
|
|
|
|
|
def clear_chat():
    """Start a new conversation.

    Rebinds the module-level ``chat_history`` to a fresh empty list and
    returns a separate empty list for the chatbot display.
    """
    global chat_history
    chat_history = list()
    cleared_display = []
    return cleared_display
|
|
|
|
|
|
|
|
# Page layout and event wiring. Event listeners are registered inside the
# Blocks context, next to the components they connect.
with gr.Blocks(title="Dwani.ai", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📚 Dwani.ai - Document Chat")
    gr.Markdown("**Upload multiple PDFs → Chat with page-accurate citations**")

    with gr.Row():
        # Left column: upload controls and the status/file list.
        with gr.Column(scale=1):
            gr.Markdown("## 📤 Upload Multiple PDFs")
            file_input = gr.File(
                label="Select PDFs (Ctrl+Click for multiple)",
                file_types=[".pdf"],
                file_count="multiple",
            )
            upload_btn = gr.Button("🚀 Upload & Process All", variant="primary")
            status_output = gr.Markdown("Ready to upload...")
            refresh_btn = gr.Button("🔄 Refresh List")
            files_display = gr.Markdown("No files uploaded")

        # Right column: choose which processed files take part in the chat.
        with gr.Column(scale=2):
            gr.Markdown("## 📋 File Manager")
            file_checkboxes = gr.CheckboxGroup(
                label="Select documents to chat with:",
                choices=[],
                value=[],
                info="Only completed files appear here",
            )
            file_count = gr.Number(label="Files selected", value=0, interactive=False)

    # NOTE(review): heading and chatbot are placed in the same row — confirm
    # this is the intended nesting.
    with gr.Row():
        gr.Markdown("## 💬 Chat with Documents")
        chatbot = gr.Chatbot(
            label="Ask questions about your documents",
            height=500,
            avatar_images=("user.png", "assistant.png"),
        )

    with gr.Row():
        msg_input = gr.Textbox(
            label="Your question",
            placeholder="e.g., What are the payment terms? When does the contract expire?",
            scale=4,
        )
        send_btn = gr.Button("Send", variant="primary", scale=1)

    with gr.Row():
        clear_btn = gr.Button("🗑️ New Chat", variant="secondary")

    upload_btn.click(
        upload_multiple,
        inputs=file_input,
        outputs=[status_output, file_checkboxes, files_display],
    )

    refresh_btn.click(
        refresh_files,
        outputs=[files_display, file_checkboxes],
    )

    file_checkboxes.change(
        update_selected_files,
        inputs=file_checkboxes,
        outputs=file_count,
    )

    # The Send button and pressing Enter in the textbox do the same thing.
    for register in (send_btn.click, msg_input.submit):
        register(
            send_message,
            inputs=[msg_input, chatbot],
            outputs=[chatbot, msg_input],
        )

    clear_btn.click(
        clear_chat,
        outputs=chatbot,
    )
|
|
|
|
|
|
|
|
# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    # NOTE(review): show_error=True presumably surfaces handler errors in the
    # browser — confirm against the installed Gradio version.
    demo.launch(show_error=True)
|
|
|