| import os |
| from dotenv import load_dotenv |
| from typing import List |
| from langchain_openai import OpenAIEmbeddings, ChatOpenAI |
| from langchain.text_splitter import RecursiveCharacterTextSplitter |
| from langchain_community.vectorstores import Chroma |
| from langchain_core.prompts import ChatPromptTemplate |
| from langchain_core.output_parsers import StrOutputParser |
| from langchain_core.runnables import RunnableParallel |
| import gradio as gr |
|
|
| |
# Populate os.environ from a local .env file (if one exists) before any of the
# clients below are constructed; presumably this is where OPENAI_API_KEY comes
# from — confirm against the deployment setup.
load_dotenv()
|
|
| |
# Locate rules.txt next to this script rather than relative to the current
# working directory, so the app finds it regardless of where it is launched.
SCRIPT_DIR = os.path.split(os.path.abspath(__file__))[0]
RULES_PATH = os.path.join(SCRIPT_DIR, "rules.txt")
|
|
| |
# Load the full rules text once at import time; the chunking, indexing and
# retrieval steps that follow all operate on it.
try:
    # Read explicitly as UTF-8: without an encoding argument open() uses the
    # platform's locale codec (e.g. cp1252 on Windows), which can fail on or
    # garble non-ASCII characters in the rules text.
    with open(RULES_PATH, "r", encoding="utf-8") as file:
        golf_rules = file.read()
except FileNotFoundError:
    print(f"Error: Could not find rules.txt at {RULES_PATH}")
    golf_rules = ""

# A missing, empty, or whitespace-only file leaves nothing to index, so fail
# fast with a clear message rather than continuing with no rule text.
if not golf_rules.strip():
    raise RuntimeError("Failed to load golf rules. Please ensure rules.txt is present in the repository.")
|
|
| |
# Stage 1 splitter: break the rules text wherever a "***" line is immediately
# followed by a line starting with "Rule" (the separator is a regex, hence
# is_separator_regex=True), so major chunks line up with rule boundaries.
# Only this one separator is given, so there is no finer-grained fallback.
major_splitter = RecursiveCharacterTextSplitter(
    separators=[r"\n\*\*\*\nRule"],
    is_separator_regex=True,
    chunk_size=10000,
    chunk_overlap=0,
    length_function=len,
)

# Stage 2 splitter: general-purpose (library-default separators) for cutting
# long major chunks into ~1000-character pieces that overlap by 200 characters.
detail_splitter = RecursiveCharacterTextSplitter(
    chunk_overlap=200,
    chunk_size=1000,
    length_function=len,
)
|
|
| |
# First pass: split the rules text at rule boundaries.
major_chunks = major_splitter.split_text(golf_rules)
print(f"Created {len(major_chunks)} major rule chunks")

# Second pass: major chunks longer than 1000 characters are re-split with the
# detail splitter; shorter ones are carried over whole.
chunks = []
for major in major_chunks:
    pieces = detail_splitter.split_text(major) if len(major) > 1000 else [major]
    chunks.extend(pieces)

print(f"Created {len(chunks)} total chunks")
|
|
| |
# Embedding model used to index and query the rules, and the chat model that
# writes answers (temperature 0 to keep output as stable as possible).
embeddings = OpenAIEmbeddings()
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

# Vector index over the rule-level (major) chunks. No persist_directory is
# passed, so presumably the index is rebuilt and re-embedded on every start —
# confirm if startup cost matters.
# NOTE(review): the finer-grained `chunks` list is computed but not indexed
# here; confirm that indexing `major_chunks` is intentional (query_golf_rules
# takes the source title from the first line of the retrieved chunk).
vectorstore = Chroma.from_texts(
    embedding=embeddings,
    texts=major_chunks,
)
|
|
| |
# Prompt for the RAG chain. `{context}` receives the formatted retrieved rule
# text (see format_docs) and `{question}` the user's question. The wording is
# sent verbatim to the model, so edit it deliberately.
template = """You are a helpful golf rules assistant. Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
You can only answer questions about the rules of golf. If a question is not about golf, kindly remind them that you only are a golf rules assistant.
Think step by step and remember to use emojis and cheer the golfer on!

Context: {context}

Question: {question}

Answer:"""


prompt = ChatPromptTemplate.from_template(template)
|
|
| |
def format_docs(docs):
    """Render retrieved documents as numbered context for the prompt.

    Each document becomes ``[Source N]: <page_content>`` with N counting from
    1; entries are separated by a blank line. An empty input yields ``""``.
    """
    return "\n\n".join(
        f"[Source {number}]: {doc.page_content}"
        for number, doc in enumerate(docs, start=1)
    )
|
|
def format_response(response, doctitle):
    """Append a 50-character divider and the source rule title to an answer."""
    divider = "=" * 50
    return f"{response}\n\n{divider}\nSource used: {doctitle}"
|
|
# Retrieve only the single best-matching chunk; query_golf_rules reports the
# title of that one document as the answer's source.
retriever = vectorstore.as_retriever(search_kwargs={"k": 1})

# prompt -> chat model -> plain string. Built once at import time and reused;
# per-question inputs are supplied through invoke() rather than by rebuilding
# the chain around closures on every call.
_answer_chain = prompt | llm | StrOutputParser()


def rag_chain_with_sources(question):
    """Answer *question* with retrieval-augmented generation.

    Retrieves matching rule text, formats it as numbered context, and runs it
    with the question through the prompt and chat model.

    Returns:
        A ``(response, docs)`` tuple: the model's answer as a string and the
        list of documents the retriever returned for the question.
    """
    docs = retriever.invoke(question)
    response = _answer_chain.invoke(
        {"context": format_docs(docs), "question": question}
    )
    return response, docs
|
|
| |
def query_golf_rules(question: str) -> str:
    """Answer a golf-rules question and append the title of the source used.

    The title is the first line of the top retrieved document that is neither
    blank nor the ``***`` divider. If retrieval returned no documents, or the
    document has no such line, ``"Unknown Rule"`` is reported instead of
    failing with IndexError.
    """
    response, docs = rag_chain_with_sources(question)

    doctitle = "Unknown Rule"
    if docs:
        # Skip blank lines and "***" dividers to reach what is expected to be
        # the rule heading.
        doctitle = next(
            (
                line
                for line in docs[0].page_content.split("\n")
                if line.strip() and line.strip() != "***"
            ),
            "Unknown Rule",
        )

    return format_response(response, doctitle)
|
|
| |
def gradio_interface(question):
    """Gradio callback: answer *question*, or ask for one if it is blank.

    Empty, ``None`` or whitespace-only submissions are answered locally,
    without running retrieval or calling the model; everything else is passed
    to query_golf_rules.
    """
    if not question or not question.strip():
        return "Please enter a question about the rules of golf. ⛳"
    return query_golf_rules(question)
|
|
# Web UI: one question box in, one answer box out, plus clickable examples.
_question_box = gr.Textbox(
    label="Your Question",
    placeholder="What would you like to know?",
    lines=2,
)
_answer_box = gr.Textbox(
    label="GolfGPT Answer",
    lines=10,
)
_example_questions = [
    "What are the rules for taking a drop?",
    "How do I handle a lost ball?",
    "Can I repair ball marks on the green?",
    "What are the rules for playing from a bunker?",
    "How do I handle an unplayable lie?",
]

demo = gr.Interface(
    fn=gradio_interface,
    inputs=_question_box,
    outputs=_answer_box,
    title="GolfGPT Rules Assistant",
    description="Ask questions about golf rules and get accurate answers based on the official rules of golf. The model can make mistakes",
    examples=_example_questions,
    theme=gr.themes.Soft(),
)


# Launch the UI only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()