Spaces:
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from transformers import AutoTokenizer, AutoModel | |
| import torch | |
| import requests | |
| from datasets import load_dataset | |
# Set page configuration (must be the first Streamlit call in the script).
st.set_page_config(page_title="Repository Recommender", layout="wide")
# Load model and tokenizer
def load_model():
    """Load the CodeT5 tokenizer and encoder model onto the best device.

    Returns:
        tuple: (tokenizer, model, device) where device is "cuda" when a GPU
        is available, otherwise "cpu".
    """
    # NOTE(review): consider wrapping this in @st.cache_resource so Streamlit
    # does not re-download/reload the model on every script rerun.
    model_name = "Salesforce/codet5-small"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Check if GPU is available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModel.from_pretrained(model_name).to(device)
    # Inference only: switch off dropout so embeddings are deterministic.
    # (Models load in train mode by default, and T5 contains dropout layers.)
    model.eval()
    return tokenizer, model, device
def generate_embedding(text, tokenizer, model, device):
    """Return a 1-D mean-pooled encoder embedding for ``text``.

    Args:
        text: Input string; blank/whitespace-only input yields a zero vector.
        tokenizer: Hugging Face tokenizer matching ``model``.
        model: Seq2seq model exposing ``.encoder`` (e.g. CodeT5).
        device: Torch device string the model lives on ("cuda"/"cpu").

    Returns:
        numpy.ndarray of shape (hidden_size,), always on CPU.
    """
    if not text.strip():
        # Size the zero vector from the model config instead of hard-coding
        # 512, so a different checkpoint still gets matching dimensions.
        hidden = getattr(model.config, "hidden_size", 512)
        return np.zeros(hidden)  # Handle empty input gracefully
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model.encoder(**inputs)
    # Mean-pool token states into one fixed-size sentence vector.
    return outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
| # Load dataset | |
| def load_data(_tokenizer, _model, _device): | |
| dataset = load_dataset("frankjosh/filtered_dataset", split="train") | |
| df = pd.DataFrame(dataset).head(500) # Limit to 500 repositories | |
| # Fill missing values to avoid errors | |
| df['docstring'] = df.get('docstring', "").fillna("") | |
| df['summary'] = df.get('summary', "").fillna("") | |
| # Generate embeddings for each row | |
| def compute_embedding(row): | |
| text = f"{row['docstring']} {row['summary']}" | |
| return generate_embedding(text, _tokenizer, _model, _device) | |
| df['embedding'] = df.apply(compute_embedding, axis=1) | |
| return df | |
| def fetch_readme(repo_url): | |
| """Fetch README file from GitHub repository.""" | |
| try: | |
| readme_url = repo_url.rstrip("/") + "/blob/main/README.md" | |
| response = requests.get(readme_url, timeout=10) | |
| if response.status_code == 200: | |
| return response.text | |
| else: | |
| return "README not available." | |
| except requests.exceptions.RequestException as e: | |
| return f"Error fetching README: {e}" | |
# Main application logic
def main():
    """Streamlit entry point: free-text query → top-5 similar repositories."""
    st.title("Repository Recommender System")
    st.write("Find Python repositories to learn production-level coding practices.")
    # Load resources
    tokenizer, model, device = load_model()
    with st.spinner("Loading dataset and generating embeddings. This may take a moment..."):
        try:
            data = load_data(tokenizer, model, device)
        except Exception as e:
            st.error(f"Error loading dataset: {e}")
            return
    # Input user query
    user_query = st.text_input("Describe your project or learning goal:",
                               "I am working on a project to recommend music using pandas and numpy.")
    if not user_query:
        return  # Nothing to recommend against yet.
    with st.spinner("Processing your query..."):
        query_embedding = generate_embedding(user_query, tokenizer, model, device)
    try:
        # Compute all similarities with one vectorized call instead of one
        # cosine_similarity invocation per row — identical values, far less
        # per-call overhead.
        embedding_matrix = np.vstack(data['embedding'].to_numpy())
        data['similarity'] = cosine_similarity([query_embedding], embedding_matrix)[0]
        # Filter and sort recommendations
        top_recommendations = (
            data.sort_values(by='similarity', ascending=False)
                .head(5)
        )
        # Display recommendations
        st.subheader("Top Recommendations")
        for _, row in top_recommendations.iterrows():
            st.markdown(f"### {row['repo']}")
            st.write(f"**Path:** {row['path']}")
            st.write(f"**Summary:** {row['summary']}")
            st.write(f"**Similarity Score:** {row['similarity']:.2f}")
            st.markdown(f"[Repository Link]({row['url']})")
            # Fetch and display README
            st.subheader("Repository README")
            readme_content = fetch_readme(row['url'])
            st.code(readme_content)
    except Exception as e:
        st.error(f"Error computing recommendations: {e}")

if __name__ == "__main__":
    main()