# Coding-Agent / agent.py
# Author: SakshamSna
# Commit 09c52fb: added prompt compression agent
import os
import fitz
import faiss
import sqlite3
import numpy as np
import google.generativeai as genai
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
from transformers import pipeline # added for summarization
class CodingAgent:
    """RAG-style coding assistant.

    Ingests PDF / Python files into a FAISS index, retrieves and (if
    needed) summarizes the most relevant chunks, answers questions with
    Gemini, and caches query/answer pairs in a local SQLite database.
    """

    # Embedding dimension of sentence-transformers/all-MiniLM-L6-v2.
    _EMBED_DIM = 384

    def __init__(self):
        """Load API keys, models, the FAISS index, and the memory DB.

        Raises:
            ValueError: if GEMINI_API_KEY is missing from the environment.
        """
        load_dotenv()
        self.api_key = os.getenv("GEMINI_API_KEY")
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY not found in environment or .env file.")
        genai.configure(api_key=self.api_key)
        self.model = genai.GenerativeModel("gemini-1.5-flash")
        self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
        self.index = faiss.IndexFlatL2(self._EMBED_DIM)
        self.docs = []  # chunk texts, parallel to the FAISS index rows
        # check_same_thread=False: the connection is shared across UI worker
        # threads; sqlite3 serializes statements but callers must not
        # interleave transactions.
        self.conn = sqlite3.connect("memory.db", check_same_thread=False)
        self.conn.execute(
            """CREATE TABLE IF NOT EXISTS memory (id INTEGER PRIMARY KEY, query TEXT, response TEXT)"""
        )
        # Summarizer used by compress_context to shrink oversized prompts.
        self.summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

    def embed_chunks(self, texts):
        """Return a (len(texts), 384) numpy array of sentence embeddings."""
        return self.embedder.encode(texts, convert_to_numpy=True)

    def ingest_file(self, filepath):
        """Chunk a .pdf or .py file, embed the chunks, add them to the index.

        Returns a human-readable status string.
        """
        chunks = []
        if filepath.endswith(".pdf"):
            doc = fitz.open(filepath)
            try:
                for page in doc:
                    words = page.get_text().split()
                    # ~300-word chunks; drop near-empty trailing fragments.
                    for i in range(0, len(words), 300):
                        chunk = " ".join(words[i:i + 300])
                        if len(chunk) > 100:
                            chunks.append(chunk)
            finally:
                # Fix: the document handle was never closed (resource leak).
                doc.close()
        elif filepath.endswith(".py"):
            with open(filepath, 'r', encoding='utf-8') as f:
                lines = f.read().splitlines()
            # 20-line chunks keep whole statements together most of the time.
            for i in range(0, len(lines), 20):
                chunks.append("\n".join(lines[i:i + 20]))
        else:
            return "Unsupported file format."
        if not chunks:
            # Fix: embedding/adding an empty batch raises inside faiss;
            # report zero chunks instead of crashing.
            return "Added 0 chunks."
        embeddings = self.embed_chunks(chunks)
        self.index.add(np.asarray(embeddings))
        self.docs.extend(chunks)
        return f"Added {len(chunks)} chunks."

    def retrieve_context(self, query, top_k=2):
        """Return the top_k most similar chunks, joined by blank lines."""
        if self.index.ntotal == 0:
            return ""
        query_emb = self.embed_chunks([query])[0]
        D, I = self.index.search(np.array([query_emb]), top_k)
        # Fix: FAISS pads the result with -1 when fewer than top_k vectors
        # are indexed; self.docs[-1] would then silently return the wrong
        # chunk. Keep only valid indices.
        return "\n\n".join(
            self.docs[i] for i in I[0] if 0 <= i < len(self.docs)
        )

    def compress_context(self, context, token_limit=2000):
        """Summarize context if it exceeds ~token_limit whitespace tokens."""
        if len(context.split()) < token_limit:
            return context
        # Fix: bart-large-cnn accepts at most 1024 model tokens, and this
        # branch only runs for inputs *longer* than that; without
        # truncation=True the pipeline raises on exactly the inputs it is
        # meant to handle.
        summary = self.summarizer(
            context, max_length=200, min_length=50, do_sample=False,
            truncation=True,
        )[0]['summary_text']
        return summary

    def answer(self, query):
        """Answer a query, serving exact repeats from the SQLite cache."""
        # Exact-match cache lookup first — avoids a model call entirely.
        cursor = self.conn.execute(
            "SELECT response FROM memory WHERE query = ?", (query,)
        )
        result = cursor.fetchone()
        if result:
            return f"[From memory] {result[0]}"
        context = self.retrieve_context(query)
        compressed_context = self.compress_context(context)
        prompt = (
            f"You are a helpful coding assistant.\n\n"
            f"Context (from uploaded docs):\n{compressed_context}\n\n"
            f"User question: {query}\n\n"
            f"Answer with code or explanation where needed."
        )
        response = self.model.generate_content(prompt)
        answer = response.text.strip()
        self.conn.execute(
            "INSERT INTO memory (query, response) VALUES (?, ?)",
            (query, answer)
        )
        self.conn.commit()
        return answer

    def clear_context(self):
        """Delete all cached query/answer pairs. Returns a status string."""
        self.conn.execute("DELETE FROM memory")
        self.conn.commit()
        return "Cleared memory."

    def get_stats(self):
        """Return a summary of cached answers and ingested chunks."""
        cursor = self.conn.execute("SELECT COUNT(*) FROM memory")
        count = cursor.fetchone()[0]
        return f"Stored answers: {count}\nDocuments: {len(self.docs)}"