Spaces:
Sleeping
Sleeping
| """ | |
| Individual agent implementations using Groq API | |
| """ | |
| import os | |
| import json | |
| import requests | |
| from groq import Groq | |
| from typing import Dict, List, Any, Optional | |
class BaseAgent:
    """Base class for agents that send a single-turn prompt to a Groq chat model."""

    def __init__(self, model_name: str, api_key: Optional[str] = None):
        """Create the Groq client used by this agent.

        Args:
            model_name: Groq model identifier used for every completion.
            api_key: Groq API key. When omitted, falls back to the
                ``GROQ_API_KEY`` environment variable.
        """
        self.client = Groq(api_key=api_key or os.getenv("GROQ_API_KEY"))
        self.model_name = model_name

    def call_api(
        self,
        prompt: str,
        max_tokens: int = 4000,
        *,
        temperature: float = 0.3,
        top_p: float = 0.95,
    ) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        Args:
            prompt: Full prompt text sent as the only user message.
            max_tokens: Upper bound on tokens in the completion.
            temperature: Sampling temperature forwarded to the API.
            top_p: Nucleus-sampling cutoff forwarded to the API.

        Returns:
            The model's reply text. Returns ``""`` if the call fails or the
            reply has no text content, so callers can always treat the result
            as a string.
        """
        try:
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
            )
            # message.content may be None. Subclasses call str methods on the
            # result, so normalize None to "".
            return completion.choices[0].message.content or ""
        except Exception as e:
            # Best-effort boundary: report the failure and let callers fall
            # back to their own error payloads.
            print(f"API Error in {self.model_name}: {e}")
            return ""
class GeneratorAgent(BaseAgent):
    """Agent 1: question-paper generator backed by a Groq-hosted Llama model."""

    def __init__(self, model_name: str = "llama-3.1-70b-versatile"):
        """Use ``model_name`` for generation calls.

        NOTE(review): Groq retires model IDs over time; confirm this default
        is still served before relying on it.
        """
        super().__init__(model_name)

    def generate_question_paper(self, prompt: str) -> Dict[str, Any]:
        """Generate a structured question paper from ``prompt``.

        Returns:
            The first JSON object in the model's reply. Returns
            ``{"error": "Failed to generate valid question paper"}`` if the
            reply is empty, contains no ``{``, or the object at the first
            ``{`` is malformed.
        """
        # Guard against a None reply so .find() below cannot raise.
        response = self.call_api(prompt) or ""
        start = response.find("{")
        if start != -1:
            try:
                # raw_decode parses one complete JSON value starting at
                # `start` and ignores anything after it. A first-"{"-to-
                # last-"}" slice breaks when the model appends text that
                # contains a "}".
                paper, _ = json.JSONDecoder().raw_decode(response, start)
                return paper
            except json.JSONDecodeError:
                print("Failed to parse JSON response")
        return {"error": "Failed to generate valid question paper"}
class VerifierAgent(BaseAgent):
    """Agent 2: quality verifier backed by a Groq-hosted Gemma model."""

    def __init__(self, model_name: str = "gemma2-27b-it"):
        """Use ``model_name`` for verification calls.

        NOTE(review): Groq retires model IDs over time; confirm this default
        is still served before relying on it.
        """
        super().__init__(model_name)

    def verify_content(self, prompt: str) -> Dict[str, Any]:
        """Ask the verifier model to check content described by ``prompt``.

        Returns:
            The first JSON object in the model's reply. Returns
            ``{"status": "error", "corrections": []}`` if the reply is empty,
            contains no ``{``, or the object at the first ``{`` is malformed.
        """
        # Guard against a None reply so .find() below cannot raise.
        response = self.call_api(prompt) or ""
        start = response.find("{")
        if start != -1:
            try:
                # Parse exactly one JSON object at `start`; trailing model
                # commentary (even if it contains "}") is ignored.
                verdict, _ = json.JSONDecoder().raw_decode(response, start)
                return verdict
            except json.JSONDecodeError:
                print("Failed to parse verification JSON")
        return {"status": "error", "corrections": []}
class FormatterAgent(BaseAgent):
    """Agent 3: final-output formatter backed by a Groq-hosted Mixtral model."""

    def __init__(self, model_name: str = "mixtral-8x7b-32768"):
        """Use ``model_name`` for formatting calls.

        NOTE(review): Groq retires model IDs over time; confirm this default
        is still served before relying on it.
        """
        super().__init__(model_name)

    def format_final_output(self, prompt: str) -> Dict[str, Any]:
        """Produce the final structured output for ``prompt``.

        Uses a larger token budget (6000) than the other agents, because the
        formatted output is presumably the longest artifact.

        Returns:
            The first JSON object in the model's reply. Returns
            ``{"error": "Failed to format final output"}`` if the reply is
            empty, contains no ``{``, or the object at the first ``{`` is
            malformed.
        """
        # Guard against a None reply so .find() below cannot raise.
        response = self.call_api(prompt, max_tokens=6000) or ""
        start = response.find("{")
        if start != -1:
            try:
                # Parse exactly one JSON object at `start`; trailing model
                # commentary (even if it contains "}") is ignored.
                formatted, _ = json.JSONDecoder().raw_decode(response, start)
                return formatted
            except json.JSONDecodeError:
                print("Failed to parse formatter JSON")
        return {"error": "Failed to format final output"}
class SearchAgent:
    """Realtime search helper that pulls recent web snippets from SerpAPI."""

    # Seconds to wait for SerpAPI before giving up. requests applies no
    # timeout by default, so a stalled connection would otherwise block the
    # caller indefinitely.
    request_timeout: float = 10.0

    def __init__(self):
        # SerpAPI key from the environment. None/empty disables searching.
        self.api_key = os.getenv("SERPAPI_KEY")

    def get_realtime_updates(self, subject: str) -> str:
        """Return recent-development snippets for ``subject`` as text.

        Args:
            subject: Topic to search for.

        Returns:
            Snippets from the first three organic results, joined by
            newlines. Otherwise one of these fixed status strings:
            "No API key configured for realtime search",
            "No recent updates found", or "Error fetching realtime data".
        """
        if not self.api_key:
            return "No API key configured for realtime search"

        from datetime import date

        # Derive the years from today's date instead of hard-coding them, so
        # the query keeps targeting recent results as time passes.
        year = date.today().year
        try:
            params = {
                "q": f"{subject} recent developments {year - 1} {year}",
                "api_key": self.api_key,
                "engine": "google",
                "num": 5,
            }
            response = requests.get(
                "https://serpapi.com/search",
                params=params,
                timeout=self.request_timeout,
            )
            # Treat HTTP error statuses (e.g. an invalid key) as failures
            # instead of reporting them as "No recent updates found".
            response.raise_for_status()
            data = response.json()

            # Extract snippets from the top organic results.
            snippets = [
                result["snippet"]
                for result in data.get("organic_results", [])[:3]
                if "snippet" in result
            ]
            return "\n".join(snippets) if snippets else "No recent updates found"
        except Exception as e:
            # Best-effort boundary. requests' error text can include the full
            # request URL, and its query string carries the API key, so
            # redact the key before printing.
            print(f"Search error: {str(e).replace(self.api_key, '***')}")
            return "Error fetching realtime data"