# golf_swing_rag.py — Par-ity_Project / app
# (Hugging Face commit 6baea5b by chenemii: "Add Git LFS and update with new logo and improvements")
print("=== golf_swing_rag.py Import Debug ===")

import importlib

# Default every dependency name to None so the availability checks further
# down in this module (`pd is None`, etc.) work even when an import fails.
pd = None
np = None
faiss = None
SentenceTransformer = None
st = None
openai = None
load_dotenv = None
os = None
json = None
pickle = None
List = None
Dict = None
Tuple = None
re = None
datetime = None


def _debug_import(label, module_name, attr_name=None):
    """Import a module (or one attribute from it), logging the outcome.

    Args:
        label: Human-readable name used in the log lines.
        module_name: Importable module name passed to importlib.
        attr_name: Optional attribute to pull out of the module
            (e.g. "SentenceTransformer", "load_dotenv").

    Returns:
        The module (or attribute) on success, or None when unavailable, so
        the caller can detect missing dependencies without crashing at
        import time (important on constrained hosts such as HF Spaces).
    """
    print(f"Importing {label}...")
    try:
        module = importlib.import_module(module_name)
        value = getattr(module, attr_name) if attr_name else module
        print(f"✓ {label} imported successfully")
        return value
    except (ImportError, AttributeError) as e:
        print(f"✗ {label} import failed: {e}")
        return None


pd = _debug_import("pandas", "pandas")
np = _debug_import("numpy", "numpy")
faiss = _debug_import("faiss", "faiss")
SentenceTransformer = _debug_import(
    "sentence_transformers", "sentence_transformers", "SentenceTransformer"
)
if SentenceTransformer is None:
    # Extra diagnostics: report whether the distribution is installed at all.
    # importlib.metadata replaces the deprecated pkg_resources API.
    try:
        import importlib.metadata
        importlib.metadata.version("sentence-transformers")
        print("✓ sentence-transformers package is installed")
    except Exception:
        print("✗ sentence-transformers package not found in installed packages")
st = _debug_import("streamlit", "streamlit")
openai = _debug_import("openai", "openai")
load_dotenv = _debug_import("dotenv", "dotenv", "load_dotenv")
os = _debug_import("os", "os")
json = _debug_import("json", "json")
pickle = _debug_import("pickle", "pickle")
_typing_mod = _debug_import("typing", "typing")
if _typing_mod is not None:
    List, Dict, Tuple = _typing_mod.List, _typing_mod.Dict, _typing_mod.Tuple
re = _debug_import("re", "re")
datetime = _debug_import("datetime", "datetime", "datetime")
print("=== End golf_swing_rag.py Import Debug ===")
# Check if critical dependencies are available
_critical_dependencies = (
    ("pandas", pd),
    ("numpy", np),
    ("faiss", faiss),
    ("sentence_transformers", SentenceTransformer),
    ("streamlit", st),
    ("openai", openai),
    ("os", os),
)
missing_deps = [label for label, module in _critical_dependencies if module is None]
if missing_deps:
    print(f"✗ Critical dependencies missing: {missing_deps}")
    raise ImportError(f"Missing required dependencies: {', '.join(missing_deps)}")
else:
    print("✓ All critical dependencies available")
    print("")
# Load environment variables if available
if load_dotenv:
    load_dotenv()
class GolfSwingRAG:
    """Retrieval-augmented generation (RAG) system for golf swing instruction.

    Pipeline: read articles from a CSV, split them into text chunks, embed
    the chunks with a SentenceTransformer, index the vectors in FAISS, then
    answer user questions by retrieving the most similar chunks and asking
    OpenAI to synthesize a grounded answer.  Every stage degrades gracefully:
    word-overlap matching replaces FAISS when no index exists, and a canned
    excerpt-based reply replaces OpenAI when no API client is available.
    """

    def __init__(self, csv_file_path: str = None):
        """Initialize the Golf Swing RAG system.

        Args:
            csv_file_path: Path to the articles CSV.  When None, several
                known locations are probed (see _locate_csv_file).

        Raises:
            FileNotFoundError: when no articles CSV can be found.
        """
        print("=== GolfSwingRAG Initialization Debug ===")
        print(f"Current working directory: {os.getcwd()}")
        print(f"__file__ location: {__file__}")
        print(f"Directory of this file: {os.path.dirname(__file__)}")
        if csv_file_path is None:
            csv_file_path = self._locate_csv_file()
        print(f"Using CSV file: {csv_file_path}")
        self.csv_file_path = csv_file_path
        print("Initializing SentenceTransformer...")
        # Small general-purpose embedding model; downloads on first use.
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        print("✓ SentenceTransformer initialized")
        self.index = None      # FAISS index; built/loaded by create_embeddings()
        self.chunks = []       # text chunks, aligned 1:1 with index vectors
        self.metadata = []     # per-chunk source info, aligned with self.chunks
        self.openai_client = None
        self._init_openai_client()
        print("=== End GolfSwingRAG Initialization Debug ===")
        print("")

    def _locate_csv_file(self) -> str:
        """Probe known locations for the articles CSV and return the first hit.

        Returns:
            The first existing path from the candidate list.

        Raises:
            FileNotFoundError: when the CSV is absent from every location.
        """
        possible_paths = [
            "golf_swing_articles_complete.csv",  # Same directory
            "../golf_swing_articles_complete.csv",  # Parent directory
            "../../golf_swing_articles_complete.csv",  # Grandparent directory
            "/app/golf_swing_articles_complete.csv",  # Absolute path for Hugging Face
            "/tmp/golf_swing_articles_complete.csv",  # Alternative location
            os.path.join(os.path.dirname(__file__), "..", "golf_swing_articles_complete.csv"),  # Relative to script
        ]
        for path in possible_paths:
            print(f"Checking for CSV at: {path}")
            if os.path.exists(path):
                print(f"✓ Found CSV at: {path}")
                return path
            print(f"✗ Not found at: {path}")
        print("✗ CSV file not found in any expected location!")
        print(f"Files in current directory: {os.listdir('.')}")
        if os.path.exists(".."):
            print(f"Files in parent directory: {os.listdir('..')}")
        raise FileNotFoundError("golf_swing_articles_complete.csv not found in any expected location")

    def _init_openai_client(self):
        """Resolve an OpenAI API key and build/verify self.openai_client.

        Key lookup order: OPENAI_API_KEY environment variable (primary for
        Hugging Face Spaces), then flat and nested Streamlit secrets.  A
        freshly built client is smoke-tested with a tiny completion; if
        every attempt fails, self.openai_client stays None and callers fall
        back to _generate_fallback_response().
        """
        print("Initializing OpenAI client...")
        openai_key = None
        # Method 1: Environment variable (primary for HF Spaces)
        try:
            openai_key = os.getenv("OPENAI_API_KEY", "")
            if openai_key:
                print("✓ Found OpenAI key in environment variable OPENAI_API_KEY")
            else:
                print("No key found in environment variable OPENAI_API_KEY")
        except Exception as e:
            print(f"Error accessing environment variable OPENAI_API_KEY: {e}")
        # Method 2: Direct access to OPENAI_API_KEY secret (fallback for Streamlit)
        if not openai_key:
            try:
                openai_key = st.secrets.get("OPENAI_API_KEY", "")
                if openai_key:
                    print("✓ Found OpenAI key in st.secrets['OPENAI_API_KEY']")
                else:
                    print("No key found in st.secrets['OPENAI_API_KEY']")
            except Exception as e:
                print(f"Error accessing st.secrets['OPENAI_API_KEY']: {e}")
        # Method 3: Try nested openai structure (fallback)
        if not openai_key:
            try:
                openai_key = st.secrets.get("openai", {}).get("api_key", "")
                if openai_key:
                    print("✓ Found OpenAI key in nested st.secrets['openai']['api_key']")
                else:
                    print("No key found in st.secrets['openai']['api_key']")
            except Exception as e:
                print(f"Error accessing nested openai secrets: {e}")
        # Initialize client only for plausible keys (OpenAI keys start "sk-").
        if openai_key and openai_key.startswith("sk-"):
            try:
                # Simple initialization without extra parameters that might cause conflicts
                self.openai_client = openai.OpenAI(api_key=openai_key)
                print("✓ OpenAI client initialized successfully")
                # Verify the client actually works with a minimal request.
                try:
                    self.openai_client.chat.completions.create(
                        model="gpt-4o-mini",
                        messages=[{"role": "user", "content": "Hi"}],
                        max_tokens=5
                    )
                    print("✓ OpenAI client test successful")
                except Exception as test_e:
                    print(f"⚠️ OpenAI client test failed: {test_e}")
                    # Retry with an older model in case gpt-4o-mini is unavailable.
                    try:
                        self.openai_client.chat.completions.create(
                            model="gpt-3.5-turbo",
                            messages=[{"role": "user", "content": "Hi"}],
                            max_tokens=5
                        )
                        print("✓ OpenAI client test successful with gpt-3.5-turbo")
                    except Exception as test_e2:
                        print(f"⚠️ OpenAI client test failed with both models: {test_e2}")
                        self.openai_client = None
            except Exception as e:
                print(f"✗ Error initializing OpenAI client: {e}")
                print(f"Error type: {type(e).__name__}")
                # Try alternative initialization approach
                try:
                    print("Trying alternative OpenAI client initialization...")
                    # Import OpenAI directly to avoid potential conflicts
                    from openai import OpenAI
                    self.openai_client = OpenAI(api_key=openai_key)
                    print("✓ Alternative OpenAI client initialization successful")
                    # Test the alternative client
                    try:
                        self.openai_client.chat.completions.create(
                            model="gpt-3.5-turbo",
                            messages=[{"role": "user", "content": "Test"}],
                            max_tokens=5
                        )
                        print("✓ Alternative OpenAI client test successful")
                    except Exception as alt_test_e:
                        print(f"⚠️ Alternative OpenAI client test failed: {alt_test_e}")
                        self.openai_client = None
                except Exception as alt_e:
                    print(f"✗ Alternative OpenAI client initialization also failed: {alt_e}")
                    self.openai_client = None
        else:
            print("✗ No valid OpenAI API key found (should start with 'sk-')")
            if openai_key:
                print(f"Found key starting with: {openai_key[:10]}...")
            self.openai_client = None

    def load_and_process_data(self):
        """Load the CSV and split every article into retrievable text chunks.

        Populates self.chunks and self.metadata (index-aligned).  Chunks come
        from the 'text_chunks' column when parseable, otherwise from
        'cleaned_text'/'text' split into ~500-word windows.  Chunks of 50
        characters or fewer are discarded as not worth retrieving.
        """
        import ast  # local import: only needed to parse the text_chunks column
        print("Loading golf swing data...")
        # Read CSV file
        df = pd.read_csv(self.csv_file_path)
        print(f"Loaded {len(df)} articles")
        all_chunks = []
        all_metadata = []
        for idx, row in df.iterrows():
            text_chunks = []
            if pd.notna(row['text_chunks']) and row['text_chunks'].strip():
                try:
                    chunks_str = row['text_chunks']
                    # The column holds a Python-list literal in string form.
                    # Prefer a proper literal parse; fall back to the legacy
                    # quote-splitting heuristic for malformed rows.
                    try:
                        parsed = ast.literal_eval(chunks_str)
                    except (ValueError, SyntaxError):
                        parsed = None
                    if isinstance(parsed, list):
                        text_chunks = [str(c).strip() for c in parsed if str(c).strip()]
                    elif chunks_str.startswith('[') and chunks_str.endswith(']'):
                        inner = chunks_str[1:-1]  # Remove outer brackets
                        # Split by quote patterns while preserving content
                        text_chunks = [chunk.strip().strip("'\"") for chunk in inner.split("', '") if chunk.strip()]
                        if not text_chunks and inner:
                            text_chunks = [inner.strip().strip("'\"")]
                except Exception:
                    # Fallback: use cleaned_text if text_chunks parsing fails
                    text_chunks = [row['cleaned_text']] if pd.notna(row['cleaned_text']) else []
            # If no chunks, create chunks from cleaned_text or text
            if not text_chunks:
                text_content = row['cleaned_text'] if pd.notna(row['cleaned_text']) else row['text']
                if pd.notna(text_content):
                    # Split into chunks of ~500 words
                    words = text_content.split()
                    chunk_size = 500
                    text_chunks = [' '.join(words[i:i+chunk_size]) for i in range(0, len(words), chunk_size)]
            # Add each chunk together with its per-article source metadata.
            for chunk_idx, chunk in enumerate(text_chunks):
                if chunk and len(chunk.strip()) > 50:  # Only process substantial chunks
                    all_chunks.append(chunk)
                    all_metadata.append({
                        'title': row['title'],
                        'url': row['url'],
                        'source': row['source'],
                        'publish_date': row['publish_date'],
                        'authors': row['authors'],
                        'chunk_index': chunk_idx,
                        'article_index': idx
                    })
        self.chunks = all_chunks
        self.metadata = all_metadata
        print(f"Created {len(all_chunks)} text chunks")

    def create_embeddings(self, force_recreate: bool = False):
        """Create (or load cached) embeddings and build the FAISS index.

        Args:
            force_recreate: when True, ignore any cached pickle/index on disk.

        On any critical failure the method leaves self.index as None so that
        search_similar_chunks() can fall back to plain text matching.
        """
        try:
            # Determine the correct base directory for embeddings files
            if os.path.exists("golf_swing_articles_complete.csv"):
                # Running from project root
                embeddings_file = "golf_swing_embeddings.pkl"
                index_file = "golf_swing_index.faiss"
            else:
                # Running from app directory
                embeddings_file = "../golf_swing_embeddings.pkl"
                index_file = "../golf_swing_index.faiss"
            if not force_recreate and os.path.exists(embeddings_file) and os.path.exists(index_file):
                print("Loading existing embeddings...")
                try:
                    with open(embeddings_file, 'rb') as f:
                        data = pickle.load(f)
                    self.chunks = data['chunks']
                    self.metadata = data['metadata']
                    self.index = faiss.read_index(index_file)
                    print(f"Loaded {len(self.chunks)} chunks with embeddings")
                    return
                except Exception as e:
                    print(f"Failed to load existing embeddings: {e}")
                    print("Will create new embeddings...")
            print("Creating embeddings...")
            if not self.chunks:
                self.load_and_process_data()
            # Small batches keep peak memory low on constrained hosts.
            batch_size = 16
            all_embeddings = []
            # Track exactly which chunks were embedded: a failed batch is
            # skipped below, and the index row ids must stay aligned with
            # self.chunks/self.metadata for search results to map back.
            embedded_chunks = []
            embedded_metadata = []
            import gc
            for i in range(0, len(self.chunks), batch_size):
                try:
                    batch_chunks = self.chunks[i:i+batch_size]
                    print(f"Processing batch {i//batch_size + 1}/{(len(self.chunks) + batch_size - 1)//batch_size}")
                    batch_embeddings = self.embedding_model.encode(
                        batch_chunks,
                        show_progress_bar=True,
                        convert_to_numpy=True,
                        normalize_embeddings=True  # Normalize during encoding
                    )
                    all_embeddings.append(batch_embeddings)
                    embedded_chunks.extend(batch_chunks)
                    embedded_metadata.extend(self.metadata[i:i+batch_size])
                    print(f"Processed {min(i+batch_size, len(self.chunks))}/{len(self.chunks)} chunks")
                    # Force garbage collection after each batch
                    gc.collect()
                except Exception as e:
                    print(f"Error processing batch {i//batch_size + 1}: {e}")
                    # Continue with next batch instead of failing completely
                    continue
            if not all_embeddings:
                raise RuntimeError("Failed to create any embeddings")
            # Drop chunks whose batch failed so index rows stay aligned.
            self.chunks = embedded_chunks
            self.metadata = embedded_metadata
            # Combine all embeddings
            print("Combining embeddings...")
            embeddings = np.vstack(all_embeddings)
            print("Creating FAISS index...")
            dimension = embeddings.shape[1]
            # Flat L2 index: exact nearest-neighbor search, simple and stable.
            self.index = faiss.IndexFlatL2(dimension)
            self.index.add(embeddings.astype('float32'))
            # Save embeddings and index
            print("Saving embeddings...")
            try:
                with open(embeddings_file, 'wb') as f:
                    pickle.dump({
                        'chunks': self.chunks,
                        'metadata': self.metadata
                    }, f)
                faiss.write_index(self.index, index_file)
                print(f"Created and saved embeddings for {len(self.chunks)} chunks")
            except Exception as e:
                print(f"Failed to save embeddings: {e}")
                print("Embeddings created but not saved to disk")
        except Exception as e:
            print(f"Critical error in create_embeddings: {e}")
            print("RAG system will operate in limited mode")
            # Set up minimal fallback
            self.chunks = self.chunks if hasattr(self, 'chunks') and self.chunks else []
            self.metadata = self.metadata if hasattr(self, 'metadata') and self.metadata else []
            self.index = None

    def search_similar_chunks(self, query: str, top_k: int = 5) -> List[Dict]:
        """Return up to top_k chunks most relevant to the query.

        Uses FAISS semantic search when an index is available; otherwise (or
        on any error) falls back to word-overlap matching.  Each result dict
        has 'chunk', 'metadata' and 'similarity_score' (higher is better).
        """
        try:
            if self.index is None:
                print("FAISS index not available, using simple text matching fallback")
                return self._fallback_search(query, top_k)
            query_embedding = self.embedding_model.encode([query], convert_to_numpy=True, normalize_embeddings=True)
            # Search in FAISS index (L2 distance, so lower scores are better)
            scores, indices = self.index.search(query_embedding.astype('float32'), top_k)
            results = []
            for score, idx in zip(scores[0], indices[0]):
                # FAISS pads with idx == -1 when fewer than top_k vectors
                # exist; guard against that and stale out-of-range ids
                # (a bare `idx < len` would let -1 index the last chunk).
                if 0 <= idx < len(self.chunks):
                    results.append({
                        'chunk': self.chunks[idx],
                        'metadata': self.metadata[idx],
                        'similarity_score': 1.0 / (1.0 + score)  # Convert L2 distance to similarity
                    })
            return results
        except Exception as e:
            print(f"Error in semantic search: {e}")
            print("Falling back to simple text matching")
            return self._fallback_search(query, top_k)

    def _fallback_search(self, query: str, top_k: int = 5) -> List[Dict]:
        """Simple word-overlap search used when semantic search is unavailable.

        Scores each chunk by |query words ∩ chunk words| / |query words| and
        returns the top_k chunks with a nonzero score, best first.
        """
        if not self.chunks:
            return []
        query_words = set(query.lower().split())
        scored_chunks = []
        for i, chunk in enumerate(self.chunks):
            chunk_words = set(chunk.lower().split())
            # Calculate simple word overlap score
            overlap = len(query_words.intersection(chunk_words))
            if overlap > 0:
                scored_chunks.append({
                    'chunk': chunk,
                    'metadata': self.metadata[i] if i < len(self.metadata) else {},
                    'similarity_score': overlap / len(query_words)
                })
        # Sort by score and return top_k
        scored_chunks.sort(key=lambda x: x['similarity_score'], reverse=True)
        return scored_chunks[:top_k]

    def generate_response(self, query: str, context_chunks: List[Dict]) -> str:
        """Generate an answer to `query` grounded in the retrieved chunks.

        Uses the OpenAI chat API when a client is available; otherwise (or on
        API error) returns the excerpt-based fallback response.
        """
        if not self.openai_client:
            return self._generate_fallback_response(query, context_chunks)
        # Concatenate retrieved chunks into a sourced context block.
        context = "\n\n".join([f"Source: {chunk['metadata']['title']}\nContent: {chunk['chunk']}"
                               for chunk in context_chunks])
        # System prompt template; {context} is filled via .format() below.
        system_prompt = """You are a golf swing technique expert assistant. You help golfers improve their swing by providing detailed, accurate advice based on professional golf instruction content.
Instructions:
- Answer questions about golf swing technique, mechanics, common problems, and solutions
- Provide specific, actionable advice when possible
- Reference relevant technical concepts when appropriate
- Be encouraging and supportive
- If asked about physical limitations or injuries, recommend consulting with a TPI certified professional
- Always base your answers on the provided context from golf instruction materials
Context from golf instruction database:
{context}"""
        user_prompt = f"""Based on the golf instruction content provided, please answer this question about golf swing technique:
Question: {query}
Please provide a helpful, detailed response that addresses the specific question while drawing from the relevant information in the context."""
        try:
            response = self.openai_client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[
                    {"role": "system", "content": system_prompt.format(context=context)},
                    {"role": "user", "content": user_prompt}
                ],
                max_tokens=1000,
                temperature=0.7
            )
            return response.choices[0].message.content
        except Exception as e:
            print(f"OpenAI API error: {e}")
            return self._generate_fallback_response(query, context_chunks)

    def _generate_fallback_response(self, query: str, context_chunks: List[Dict]) -> str:
        """Build a canned response from the best chunk when OpenAI is unavailable."""
        if not context_chunks:
            return "I couldn't find specific information about that topic in the golf swing database. Could you try rephrasing your question or being more specific?"
        # Quote the opening of the most relevant chunk and cite its article.
        best_chunk = context_chunks[0]
        chunk_content = best_chunk['chunk']
        title = best_chunk['metadata']['title']
        response = f"Based on the article '{title}', here's what I found:\n\n"
        response += chunk_content[:500] + "..."
        response += f"\n\nFor more detailed information, you can refer to the full article: {title}"
        return response

    def query(self, question: str, top_k: int = 5) -> Dict:
        """Answer a question; return dict with 'response', 'sources', 'query', 'timestamp'."""
        relevant_chunks = self.search_similar_chunks(question, top_k)
        response = self.generate_response(question, relevant_chunks)
        return {
            'response': response,
            'sources': relevant_chunks,
            'query': question,
            'timestamp': datetime.now().isoformat()
        }
def main():
    """Build the RAG system from the bundled CSV and run one smoke-test query."""
    rag_system = GolfSwingRAG()
    rag_system.load_and_process_data()
    rag_system.create_embeddings()
    # Exercise the full retrieve-and-answer path with a sample question.
    demo_question = "What wrist motion happens during the downswing?"
    outcome = rag_system.query(demo_question)
    print(f"Query: {outcome['query']}")
    print(f"Response: {outcome['response']}")
    print(f"Number of sources: {len(outcome['sources'])}")


if __name__ == "__main__":
    main()