Spaces:
Paused
Paused
| import spaces | |
| import torch | |
| import gradio as gr | |
| from transformers import ( | |
| pipeline, | |
| BitsAndBytesConfig, | |
| ) | |
| from duckduckgo_search import DDGS | |
# =====================================================
# MODEL SETUP
# =====================================================
# Probe CUDA once so the quantization decision and the dtype decision below
# are always made from the same answer.
_HAS_CUDA = torch.cuda.is_available()

# 4-bit bitsandbytes quantization is only used when a CUDA device exists.
# bnb's 4-bit compute dtype defaults to float32; using bfloat16 (matching the
# torch_dtype passed to the pipeline) is the documented speed-up.
quantization_config = (
    BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    if _HAS_CUDA
    else None
)

llama3_model_id = "meta-llama/Llama-3.1-8B-Instruct"

llama3_pipe = pipeline(
    "text-generation",
    model=llama3_model_id,
    torch_dtype=torch.bfloat16 if _HAS_CUDA else torch.float32,
    device_map="auto",
    # quantization_config is None on CPU, i.e. no quantization.
    model_kwargs={"quantization_config": quantization_config},
)
print("✅ Model Loaded")
# =====================================================
# SEARCH (HF SPACES SAFE)
# =====================================================
def google_search_results(query: str, max_results: int = 5):
    """Return text snippets from a live DuckDuckGo web search.

    The name is historical: scraping Google does not work from inside
    Hugging Face Spaces, so DuckDuckGo is used instead.

    Args:
        query: Search query. Blank queries return ``[]`` without searching.
        max_results: Maximum number of results requested from DuckDuckGo.

    Returns:
        A list of non-empty, whitespace-stripped ``body`` snippets in result
        order. Search is best-effort: on any error the exception is printed
        and the snippets collected so far are returned (possibly ``[]``).
    """
    outputs = []
    if not (query or "").strip():
        # Nothing to search for; skip the network round-trip.
        return outputs
    try:
        with DDGS() as ddgs:
            for r in ddgs.text(query, max_results=max_results) or []:
                # A result without a usable body is skipped rather than
                # raising KeyError and discarding every later result.
                body = (r.get("body") or "").strip()
                if body:
                    outputs.append(body)
    except Exception as e:  # best-effort: the app still answers without search
        print("Search error:", e)
    return outputs
# =====================================================
# RAG ENRICHMENT
# =====================================================
def RAG_enrichment(input_question: str, search_fn=None):
    """Append live web-search snippets to a question to form a RAG prompt.

    Args:
        input_question: The user's question.
        search_fn: Optional callable ``query -> list[str]`` returning context
            snippets. Defaults to ``google_search_results``; pass another
            callable to use a different search backend.

    Returns:
        ``input_question``, then an instruction header, then each snippet
        followed by a blank line. If the search returns no snippets, the
        question is returned unchanged, so the model is not told to use
        information that is not there.
    """
    search = google_search_results if search_fn is None else search_fn
    enrichment = search(input_question)
    print("Search Results:", enrichment)
    if not enrichment:
        return input_question
    context = "".join(info + "\n\n" for info in enrichment)
    return (
        input_question
        + "\n\nUse the following real-time information to help answer:\n\n"
        + context
    )
# =====================================================
# LLAMA QA
# =====================================================
_SYSTEM_PROMPT = (
    "You are a helpful chatbot assistant. "
    "Answer clearly and concisely. "
    "If real-time info is missing, answer using available knowledge."
)


def llama_QA(input_question: str, pipe, *, max_new_tokens: int = 512,
             temperature: float = 0.7):
    """Answer ``input_question`` with a chat-style text-generation pipeline.

    The request is sent as system/user chat messages, so the pipeline
    applies the instruct model's chat template instead of the model seeing
    a bare completion prompt. ``return_full_text=False`` asks the pipeline
    for only the newly generated text. This replaces the old approach of
    string-replacing the prompt out of the decoded transcript, which leaked
    the whole prompt whenever decoding did not reproduce it exactly.

    Args:
        input_question: Question, optionally already enriched with RAG context.
        pipe: A transformers ``text-generation`` pipeline or compatible callable.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature (sampling is always enabled).

    Returns:
        The model's answer with surrounding whitespace stripped.
    """
    messages = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {"role": "user", "content": input_question},
    ]
    outputs = pipe(
        messages,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        return_full_text=False,
    )
    generated = outputs[0]["generated_text"]
    # Defensive: if the callable returns the whole chat (a list of message
    # dicts) instead of new text, the answer is the final message.
    if isinstance(generated, list):
        generated = generated[-1]["content"]
    return generated.strip()
# =====================================================
# GRADIO WRAPPER
# =====================================================
def gradio_func(input_question):
    """Run one question through both the plain and the RAG pipelines.

    NOTE(review): ``spaces`` is imported at the top of the file but never
    used. If this Space runs on ZeroGPU hardware, this function probably
    needs the ``@spaces.GPU`` decorator. Confirm the Space's hardware.

    Args:
        input_question: Text from the question textbox.

    Returns:
        A 4-tuple for the UI: (non-RAG input, RAG-enriched input,
        non-RAG answer, RAG answer). A blank question short-circuits with a
        prompt to the user instead of running a search and two generations.
    """
    if not (input_question or "").strip():
        message = "Please enter a question."
        return input_question, "", message, message
    print("User Question:", input_question)
    # Baseline: the model answers from its own knowledge only.
    output1 = llama_QA(input_question, llama3_pipe)
    # RAG: enrich the question with live search results, then answer.
    rag_input = RAG_enrichment(input_question)
    output2 = llama_QA(rag_input, llama3_pipe)
    return input_question, rag_input, output1, output2
# =====================================================
# UI
# =====================================================
def create_interface():
    """Build the Gradio Blocks app that compares non-RAG and RAG answers.

    Layout, top to bottom:
    - a title;
    - the question box with an "Ask" button;
    - a row of the two inputs (plain, RAG-enriched);
    - a row of the two answers.

    Clicking "Ask" calls ``gradio_func`` and fills the four textboxes in
    that order.

    Returns:
        The constructed, not yet launched, ``gr.Blocks`` app.
    """
    with gr.Blocks() as app:
        gr.Markdown("# 🔎 Llama3 RAG vs Non-RAG Demo")

        with gr.Row():
            question_box = gr.Textbox(
                label="Enter your question",
                value="what day is today in sydney?",
            )
            ask_button = gr.Button("Ask")

        # What each pipeline was asked.
        with gr.Row():
            plain_prompt_box = gr.Textbox(label="Non-RAG Input")
            rag_prompt_box = gr.Textbox(label="RAG Enriched Input")

        # What each pipeline answered.
        with gr.Row():
            plain_answer_box = gr.Textbox(label="Non-RAG Output")
            rag_answer_box = gr.Textbox(label="RAG Output")

        result_boxes = [
            plain_prompt_box,
            rag_prompt_box,
            plain_answer_box,
            rag_answer_box,
        ]
        ask_button.click(
            fn=gradio_func,
            inputs=[question_box],
            outputs=result_boxes,
        )
    return app
# =====================================================
# LAUNCH
# =====================================================
# Build the app once at import time, keep it bound to the module-level name
# `demo`, and start serving it.
demo = create_interface()
demo.launch()