SNS / agents.py
codeboosterstech's picture
Create agents.py
a486fa6 verified
raw
history blame
4.48 kB
"""
Individual agent implementations: Groq-backed LLM agents plus a SerpAPI search agent
"""
import os
import json
import requests
from groq import Groq
from typing import Dict, List, Any, Optional
class BaseAgent:
    """Base class for all agents: holds a Groq client bound to one model."""

    def __init__(self, model_name: str):
        """Create the Groq client and record the model to query.

        Args:
            model_name: Groq model identifier used for every completion call.
        """
        # GROQ_API_KEY is read from the environment; if it is unset, None is
        # passed and handling is left to the Groq client.
        self.client = Groq(api_key=os.getenv("GROQ_API_KEY"))
        self.model_name = model_name

    def call_api(self, prompt: str, max_tokens: int = 4000) -> str:
        """Send *prompt* as a single user message and return the reply text.

        Args:
            prompt: Text sent as the only (user) message.
            max_tokens: Upper bound on tokens generated in the reply.

        Returns:
            The first choice's message content, or "" if the request fails
            or the reply carries no content. Errors are printed, not raised.
        """
        try:
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.3,
                max_tokens=max_tokens,
                top_p=0.95,
            )
            content = completion.choices[0].message.content
        except Exception as e:
            # Best-effort boundary: callers treat "" as "no usable response".
            print(f"API Error in {self.model_name}: {e}")
            return ""
        # message.content may be None in the chat-completions schema; normalize
        # so the declared `str` return type holds and callers can use str methods.
        return content if content is not None else ""
class GeneratorAgent(BaseAgent):
    """Agent 1: Question Paper Generator (defaults to Llama 3.1 70B on Groq)."""

    def __init__(self, model_name: str = "llama-3.1-70b-versatile"):
        """Bind the agent to *model_name*.

        NOTE(review): Groq retires model IDs over time; confirm the default is
        still served, or pass a current model ID here.
        """
        super().__init__(model_name)

    def generate_question_paper(self, prompt: str) -> Dict[str, Any]:
        """Generate a structured question paper and parse it from JSON.

        Args:
            prompt: Full instruction prompt sent to the model.

        Returns:
            The JSON object found in the reply (text from the first "{" to the
            last "}"), or {"error": ...} when no such object can be parsed.
        """
        # Tolerate a None reply as well as "" so str methods below are safe.
        response = self.call_api(prompt) or ""
        start_idx = response.find("{")
        end_idx = response.rfind("}")
        # Only slice when a closing brace follows the opening one; otherwise
        # there is no candidate JSON object in the reply.
        if start_idx != -1 and end_idx > start_idx:
            try:
                return json.loads(response[start_idx:end_idx + 1])
            except json.JSONDecodeError:
                print("Failed to parse JSON response")
        return {"error": "Failed to generate valid question paper"}
class VerifierAgent(BaseAgent):
    """Agent 2: Quality Verifier (defaults to Gemma 2 27B on Groq)."""

    def __init__(self, model_name: str = "gemma2-27b-it"):
        """Bind the agent to *model_name*.

        NOTE(review): Groq retires model IDs over time; confirm the default is
        still served, or pass a current model ID here.
        """
        super().__init__(model_name)

    def verify_content(self, prompt: str) -> Dict[str, Any]:
        """Verify generated content and parse the verdict from JSON.

        Args:
            prompt: Full verification prompt sent to the model.

        Returns:
            The JSON object found in the reply (text from the first "{" to the
            last "}"), or {"status": "error", "corrections": []} when no such
            object can be parsed.
        """
        # Tolerate a None reply as well as "" so str methods below are safe.
        response = self.call_api(prompt) or ""
        start_idx = response.find("{")
        end_idx = response.rfind("}")
        # Only slice when a closing brace follows the opening one; otherwise
        # there is no candidate JSON object in the reply.
        if start_idx != -1 and end_idx > start_idx:
            try:
                return json.loads(response[start_idx:end_idx + 1])
            except json.JSONDecodeError:
                print("Failed to parse verification JSON")
        return {"status": "error", "corrections": []}
class FormatterAgent(BaseAgent):
    """Agent 3: Output Formatter (defaults to Mixtral-8x7B on Groq)."""

    def __init__(self, model_name: str = "mixtral-8x7b-32768"):
        """Bind the agent to *model_name*.

        NOTE(review): Groq retires model IDs over time; confirm the default is
        still served, or pass a current model ID here.
        """
        super().__init__(model_name)

    def format_final_output(self, prompt: str) -> Dict[str, Any]:
        """Produce the final structured output and parse it from JSON.

        Args:
            prompt: Full formatting prompt sent to the model.

        Returns:
            The JSON object found in the reply (text from the first "{" to the
            last "}"), or {"error": ...} when no such object can be parsed.
        """
        # Larger token budget than the default, since this is the final document.
        # Tolerate a None reply as well as "" so str methods below are safe.
        response = self.call_api(prompt, max_tokens=6000) or ""
        start_idx = response.find("{")
        end_idx = response.rfind("}")
        # Only slice when a closing brace follows the opening one; otherwise
        # there is no candidate JSON object in the reply.
        if start_idx != -1 and end_idx > start_idx:
            try:
                return json.loads(response[start_idx:end_idx + 1])
            except json.JSONDecodeError:
                print("Failed to parse formatter JSON")
        return {"error": "Failed to format final output"}
class SearchAgent:
    """Realtime search for recent developments using SerpAPI."""

    # Seconds to wait for SerpAPI; without a timeout, requests.get can block
    # indefinitely on a stalled connection.
    request_timeout: float = 30.0

    def __init__(self):
        """Read the SerpAPI key from the SERPAPI_KEY environment variable."""
        self.api_key = os.getenv("SERPAPI_KEY")

    @staticmethod
    def _build_query(subject: str, year: Optional[int] = None) -> str:
        """Return the search query for *subject* covering the last two years.

        Args:
            subject: Topic to search for.
            year: Reference year; defaults to the current local year so the
                query does not go stale.
        """
        if year is None:
            import datetime
            year = datetime.date.today().year
        return f"{subject} recent developments {year - 1} {year}"

    @staticmethod
    def _extract_snippets(data: Dict[str, Any], limit: int = 3) -> List[str]:
        """Return snippet texts from the first *limit* organic results.

        Results without a "snippet" key are skipped but still count toward
        *limit*.
        """
        results = data.get("organic_results") or []
        return [item["snippet"] for item in results[:limit] if "snippet" in item]

    def get_realtime_updates(self, subject: str) -> str:
        """Get recent developments for *subject* as newline-joined snippets.

        Returns:
            The joined snippets, or a fixed status message when no API key is
            configured, nothing was found, or the request failed. Errors are
            printed, not raised.
        """
        if not self.api_key:
            return "No API key configured for realtime search"
        try:
            params = {
                "q": self._build_query(subject),
                "api_key": self.api_key,
                "engine": "google",
                "num": 5,
            }
            response = requests.get(
                "https://serpapi.com/search",
                params=params,
                timeout=self.request_timeout,
            )
            # Treat non-2xx responses as failures instead of parsing an error
            # body as if it were search results.
            response.raise_for_status()
            snippets = self._extract_snippets(response.json())
        except Exception as e:
            print(f"Search error: {e}")
            return "Error fetching realtime data"
        return "\n".join(snippets) if snippets else "No recent updates found"