import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
import os
# ========================
# CONFIG
# ========================
model_hub_id = "hackergeek98/harrisongpt"
vector_store_hub_id = "hackergeek98/harrisons-rag-vectorstore"  # your uploaded vector store
device = "cuda" if torch.cuda.is_available() else "cpu"
# ========================
# LOAD MODEL
# ========================
tokenizer = AutoTokenizer.from_pretrained(model_hub_id)
# fp16 generation is not supported on CPU, so only use it when CUDA is available.
dtype = torch.float16 if device == "cuda" else torch.float32
base_model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium", torch_dtype=dtype).to(device)
model = PeftModel.from_pretrained(base_model, model_hub_id)
model.eval()
# ========================
# LOAD VECTOR STORE
# ========================
# The vector store must exist locally on Spaces, so download it on first start.
local_vector_path = "./vectorstore"
if not os.path.exists(local_vector_path):
    from huggingface_hub import snapshot_download
    snapshot_download(vector_store_hub_id, local_dir=local_vector_path)
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": device},
)
vectorstore = FAISS.load_local(
    local_vector_path,
    embeddings,
    # Recent langchain-community versions refuse to unpickle a FAISS index
    # without this flag; only enable it for indexes you built yourself.
    allow_dangerous_deserialization=True,
)
retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
# ========================
# HELPER FUNCTIONS
# ========================
def generate_text(prompt, max_new_tokens=200, temperature=0.8):
    input_ids = tokenizer.encode(prompt + tokenizer.eos_token, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens; slicing the decoded string by
    # len(prompt) is fragile because detokenization does not round-trip exactly.
    return tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()
def rag_query(question, chat_history=None):
    chat_history = chat_history or []  # avoid a mutable default argument
    relevant_docs = retriever.invoke(question)
    context = "\n".join(d.page_content for d in relevant_docs)
    # Keep only the last 3 turns so the prompt stays within the context window.
    conversation_context = "\n".join(f"Q: {q}\nA: {a}" for q, a in chat_history[-3:])
    prompt = f"{conversation_context}\nContext:\n{context}\n\nQuestion: {question}\nAnswer:"
    return generate_text(prompt)
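
# Example call (hypothetical question, for illustration only):
#   rag_query("What are the classic findings in aortic stenosis?")
# retrieves the top-3 passages from the vector store and generates an
# answer conditioned on them plus the most recent chat turns.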
# ========================
# GRADIO INTERFACE
# ========================
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center'>HarrisonGPT - Medical RAG Assistant</h1>")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Ask a medical question:")
    clear = gr.Button("Clear Chat")

    def user_interaction(message, history):
        # Use Gradio's per-session history rather than a module-level list,
        # which would be shared across every visitor of the Space.
        reply = rag_query(message, history)
        return "", history + [(message, reply)]

    # Clear the textbox and append the new turn to the chat window.
    msg.submit(user_interaction, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: [], None, chatbot)

demo.launch()
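
# ------------------------------------------------------------------
# NOTE: a Space also needs a requirements.txt next to this file. A
# minimal sketch based on the imports above (an assumption, not part of
# the original repo; pin the versions you actually tested):
#   gradio
#   torch
#   transformers
#   peft
#   langchain-community
#   faiss-cpu
#   sentence-transformers
#   huggingface_hub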