Spaces:
Runtime error
Runtime error
| import pandas as pd | |
| import torch | |
| from sentence_transformers import SentenceTransformer | |
| import faiss | |
| from transformers import pipeline | |
| import gradio as gr | |
# Load the injury history table. The rest of the app relies on its
# free-text "Notes" column.
injury_data = pd.read_csv("Injury_History.csv")

# pandas reads empty cells as NaN (a float), which the sentence encoder
# cannot tokenize; normalise the column to plain strings before encoding.
injury_data['Notes'] = injury_data['Notes'].fillna("").astype(str)

# Embedding model shared by the corpus (here) and by user queries later on.
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# Encode every note in a single batched call instead of one call per row,
# and force float32 because that is the dtype FAISS indexes operate on.
embeddings = embedding_model.encode(
    injury_data['Notes'].tolist(),
    convert_to_numpy=True,
).astype('float32')

# Keep a per-row copy of each vector on the frame (one NumPy vector per row).
injury_data['embedding'] = list(embeddings)

# Flat L2-distance index over the note vectors. Row i of `embeddings`
# corresponds to positional row i of `injury_data`, so the labels returned
# by a search can be fed straight to `injury_data.iloc`.
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
def retrieve_injuries(query, k=3):
    """Return the injury records whose notes are closest to *query*.

    Args:
        query: Free-text question or description to search for.
        k: Maximum number of records to return. Defaults to 3.

    Returns:
        A slice of ``injury_data`` holding at most ``k`` rows, ordered as
        returned by the FAISS search. Empty when the index holds no
        vectors or ``k`` is not positive.
    """
    # When fewer than k vectors are indexed, FAISS pads the result with -1
    # labels; `iloc[-1]` would silently return the *last* row instead of
    # "no match", so never ask for more neighbours than exist.
    k = min(k, index.ntotal)
    if k <= 0:
        return injury_data.iloc[0:0]

    # Embed the query with the same model used for the corpus and shape it
    # as a (1, dim) float32 matrix, which is what `index.search` expects.
    query_embedding = embedding_model.encode(query, convert_to_numpy=True)
    query_matrix = query_embedding.reshape(1, -1).astype('float32')

    _distances, indices = index.search(query_matrix, k)

    # Defensive filter: drop any remaining -1 padding labels.
    hits = [int(i) for i in indices[0] if i >= 0]
    return injury_data.iloc[hits]
# Text-generation pipeline (GPT-2 checkpoint) used to phrase the final answer.
generator = pipeline(task="text-generation", model="gpt2")
def injury_query(player_query):
    """Answer a user question using retrieved injury notes as context.

    Looks up the most similar injury records, joins their notes into a
    context string, and asks the text-generation pipeline to continue a
    prompt built from that context.

    Args:
        player_query: The question typed by the user.

    Returns:
        The generated text (prompt followed by the model's continuation),
        or a short hint when the query is blank.
    """
    # Nothing to search for: skip retrieval and generation entirely.
    if not player_query or not player_query.strip():
        return "Please enter a question about a player's injury history."

    retrieved_injuries = retrieve_injuries(player_query)

    # Coerce to str and drop missing values so a NaN note cannot break join().
    notes = retrieved_injuries['Notes'].dropna().astype(str).tolist()
    injury_details = ". ".join(notes)
    context = f"Injury history: {injury_details}"
    prompt = f"Answer based on data: {context}"

    # `max_length` counts the prompt tokens as well, so a prompt built from
    # several notes can already exceed 100 tokens and leave no room to
    # generate. `max_new_tokens` budgets only the continuation.
    outputs = generator(prompt, max_new_tokens=100)
    return outputs[0]['generated_text']
# Gradio front end: one text box in, one text box out, backed by injury_query.
interface = gr.Interface(
    title="NBA Player Injury Q&A",
    description="Ask about a player's injury history, or inquire about common injuries.",
    fn=injury_query,
    inputs="text",
    outputs="text",
)

# Start serving the app.
interface.launch()