import os
import sqlite3

import faiss
import fitz  # PyMuPDF
import numpy as np
import google.generativeai as genai
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
from transformers import pipeline  # used for context summarization


class CodingAgent:
    def __init__(self):
        load_dotenv()
        self.api_key = os.getenv("GEMINI_API_KEY")
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY not found in environment or .env file.")
        genai.configure(api_key=self.api_key)
        self.model = genai.GenerativeModel("gemini-1.5-flash")

        # Embedding model and FAISS index; all-MiniLM-L6-v2 produces 384-dimensional vectors.
        self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
        self.index = faiss.IndexFlatL2(384)
        self.docs = []

        # Persistent query/response memory in SQLite.
        self.conn = sqlite3.connect("memory.db", check_same_thread=False)
        self.conn.execute(
            """CREATE TABLE IF NOT EXISTS memory (id INTEGER PRIMARY KEY, query TEXT, response TEXT)"""
        )
        self.summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

    def embed_chunks(self, texts):
        return self.embedder.encode(texts, convert_to_numpy=True)

    def ingest_file(self, filepath):
        chunks = []
        if filepath.endswith(".pdf"):
            doc = fitz.open(filepath)
            for page in doc:
                text = page.get_text()
                words = text.split()
                # Split each page into ~300-word chunks; skip fragments under 100 characters.
                for i in range(0, len(words), 300):
                    chunk = " ".join(words[i:i+300])
                    if len(chunk) > 100:
                        chunks.append(chunk)
        elif filepath.endswith(".py"):
            with open(filepath, "r", encoding="utf-8") as f:
                code = f.read()
            lines = code.splitlines()
            # Split source files into 20-line chunks.
            for i in range(0, len(lines), 20):
                chunk = "\n".join(lines[i:i+20])
                chunks.append(chunk)
        else:
            return "Unsupported file format."
        if not chunks:
            return "No usable text found in the file."
        embeddings = self.embed_chunks(chunks)
        self.index.add(np.array(embeddings))
        self.docs.extend(chunks)
        return f"Added {len(chunks)} chunks."

    def retrieve_context(self, query, top_k=2):
        if self.index.ntotal == 0:
            return ""
        query_emb = self.embed_chunks([query])[0]
        D, I = self.index.search(np.array([query_emb]), top_k)
        # FAISS pads results with -1 when fewer than top_k vectors are indexed.
        return "\n\n".join(self.docs[i] for i in I[0] if i != -1)

    def compress_context(self, context, token_limit=2000):
        """Summarize the context if its word count exceeds the limit."""
        if len(context.split()) < token_limit:
            return context
        # truncation=True keeps the input within the summarizer's maximum sequence length.
        summary = self.summarizer(
            context, max_length=200, min_length=50, do_sample=False, truncation=True
        )[0]["summary_text"]
        return summary

    def answer(self, query):
        # Check memory first
        cursor = self.conn.execute(
            "SELECT response FROM memory WHERE query = ?", (query,)
        )
        result = cursor.fetchone()
        if result:
            return f"[From memory] {result[0]}"

        context = self.retrieve_context(query)
        compressed_context = self.compress_context(context)
        prompt = (
            "You are a helpful coding assistant.\n\n"
            f"Context (from uploaded docs):\n{compressed_context}\n\n"
            f"User question: {query}\n\n"
            "Answer with code or explanation where needed."
        )
        response = self.model.generate_content(prompt)
        answer = response.text.strip()
        self.conn.execute(
            "INSERT INTO memory (query, response) VALUES (?, ?)",
            (query, answer),
        )
        self.conn.commit()
        return answer

    def clear_context(self):
        self.conn.execute("DELETE FROM memory")
        self.conn.commit()
        return "Cleared memory."

    def get_stats(self):
        cursor = self.conn.execute("SELECT COUNT(*) FROM memory")
        count = cursor.fetchone()[0]
        return f"Stored answers: {count}\nDocuments: {len(self.docs)}"