import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
import os
# ========================
# CONFIG
# ========================
model_hub_id = "hackergeek98/harrisongpt"
vector_store_hub_id = "hackergeek98/harrisons-rag-vectorstore" # your uploaded vector store
device = "cuda" if torch.cuda.is_available() else "cpu"
# ========================
# LOAD MODEL
# ========================
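# The Hub checkpoint is a PEFT/LoRA adapter, so the DialoGPT-medium base
# model is loaded first and the adapter weights are applied on top of it.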
tokenizer = AutoTokenizer.from_pretrained(model_hub_id)
base_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/DialoGPT-medium",
    # fp16 is only reliable on GPU; fall back to fp32 on CPU.
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)
model = PeftModel.from_pretrained(base_model, model_hub_id)
model.eval()
model.to(device)
# ========================
# LOAD VECTOR STORE
# ========================
# You need to have the vector store locally cloned for Spaces
local_vector_path = "./vectorstore"
if not os.path.exists(local_vector_path):
    from huggingface_hub import snapshot_download
    # If the vector-store repo is private, a token must be available
    # (e.g. the HF_TOKEN secret on Spaces).
    snapshot_download(vector_store_hub_id, local_dir=local_vector_path)
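# NOTE: the index must have been built with the same embedding model used
# here (all-MiniLM-L6-v2); a different model will give meaningless retrieval
# results, or fail outright on an embedding-dimension mismatch.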
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": device},
)
vectorstore = FAISS.load_local(
    local_vector_path,
    embeddings,
    # Recent langchain versions refuse to unpickle a local index without this flag.
    allow_dangerous_deserialization=True,
)
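# Retrieve the 3 most similar chunks per query; a larger k adds more context
# but risks overflowing DialoGPT's 1024-token window.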
retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
# ========================
# HELPER FUNCTIONS
# ========================
def generate_text(prompt, max_new_tokens=200, temperature=0.8):
    """Generate a completion for `prompt` and return only the newly generated text."""
    input_ids = tokenizer.encode(prompt + tokenizer.eos_token, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Decode only the tokens generated after the prompt. Slicing the decoded
    # string by len(prompt) is unreliable because tokenization round-trips do
    # not preserve the input text exactly.
    generated = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(generated, skip_special_tokens=True).strip()
def rag_query(question, chat_history=None):
    """Answer `question` using retrieved context plus the last few chat turns."""
    chat_history = chat_history or []  # avoid a mutable default argument
    relevant_docs = retriever.invoke(question)
    context = "\n".join(d.page_content for d in relevant_docs)
    # Keep only the last 3 turns so the prompt stays within the context window.
    conversation_context = "\n".join(f"Q: {q}\nA: {a}" for q, a in chat_history[-3:])
    prompt = f"{conversation_context}\nContext:\n{context}\n\nQuestion: {question}\nAnswer:"
    return generate_text(prompt)
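# Example usage (hypothetical question, for illustration only):
#   rag_query("What are the diagnostic criteria for rheumatoid arthritis?")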
# ========================
# GRADIO INTERFACE
# ========================
def respond(user_message, chat_history):
    """One chat turn. History is passed in per session rather than kept in a
    module-level global, which would be shared across all users of the Space."""
    return rag_query(user_message, chat_history)
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center'>HarrisonGPT – Medical RAG Assistant</h1>")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Ask a medical question:")
    clear = gr.Button("Clear Chat")

    def user_interaction(message, history):
        reply = respond(message, history)
        return "", history + [(message, reply)]

    # Clear the textbox and append the new turn to the chat display.
    msg.submit(user_interaction, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: [], None, chatbot)

demo.launch()