codeboosterstech commited on
Commit
a486fa6
·
verified ·
1 Parent(s): be56221

Create agents.py

Browse files
Files changed (1) hide show
  1. agents.py +131 -0
agents.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Individual agent implementations using Groq API
3
+ """
4
+
5
+ import os
6
+ import json
7
+ import requests
8
+ from groq import Groq
9
+ from typing import Dict, List, Any, Optional
10
+
11
+ class BaseAgent:
12
+ """Base class for all agents"""
13
+
14
+ def __init__(self, model_name: str):
15
+ self.client = Groq(api_key=os.getenv("GROQ_API_KEY"))
16
+ self.model_name = model_name
17
+
18
+ def call_api(self, prompt: str, max_tokens: int = 4000) -> str:
19
+ """Make API call to Groq"""
20
+ try:
21
+ completion = self.client.chat.completions.create(
22
+ model=self.model_name,
23
+ messages=[{"role": "user", "content": prompt}],
24
+ temperature=0.3,
25
+ max_tokens=max_tokens,
26
+ top_p=0.95
27
+ )
28
+ return completion.choices[0].message.content
29
+ except Exception as e:
30
+ print(f"API Error in {self.model_name}: {e}")
31
+ return ""
32
+
33
+ class GeneratorAgent(BaseAgent):
34
+ """Agent 1: Question Paper Generator using Llama 3.1 70B"""
35
+
36
+ def __init__(self):
37
+ super().__init__("llama-3.1-70b-versatile")
38
+
39
+ def generate_question_paper(self, prompt: str) -> Dict[str, Any]:
40
+ """Generate structured question paper"""
41
+ response = self.call_api(prompt)
42
+
43
+ # Extract JSON from response
44
+ try:
45
+ # Find JSON block in response
46
+ start_idx = response.find('{')
47
+ end_idx = response.rfind('}') + 1
48
+ if start_idx != -1 and end_idx != -1:
49
+ json_str = response[start_idx:end_idx]
50
+ return json.loads(json_str)
51
+ except json.JSONDecodeError:
52
+ print("Failed to parse JSON response")
53
+
54
+ return {"error": "Failed to generate valid question paper"}
55
+
56
+ class VerifierAgent(BaseAgent):
57
+ """Agent 2: Quality Verifier using Gemma 2 27B"""
58
+
59
+ def __init__(self):
60
+ super().__init__("gemma2-27b-it")
61
+
62
+ def verify_content(self, prompt: str) -> Dict[str, Any]:
63
+ """Verify and validate generated content"""
64
+ response = self.call_api(prompt)
65
+
66
+ try:
67
+ start_idx = response.find('{')
68
+ end_idx = response.rfind('}') + 1
69
+ if start_idx != -1 and end_idx != -1:
70
+ json_str = response[start_idx:end_idx]
71
+ return json.loads(json_str)
72
+ except json.JSONDecodeError:
73
+ print("Failed to parse verification JSON")
74
+
75
+ return {"status": "error", "corrections": []}
76
+
77
+ class FormatterAgent(BaseAgent):
78
+ """Agent 3: Output Formatter using Mixtral-8x7B"""
79
+
80
+ def __init__(self):
81
+ super().__init__("mixtral-8x7b-32768")
82
+
83
+ def format_final_output(self, prompt: str) -> Dict[str, Any]:
84
+ """Create final structured output"""
85
+ response = self.call_api(prompt, max_tokens=6000)
86
+
87
+ try:
88
+ start_idx = response.find('{')
89
+ end_idx = response.rfind('}') + 1
90
+ if start_idx != -1 and end_idx != -1:
91
+ json_str = response[start_idx:end_idx]
92
+ return json.loads(json_str)
93
+ except json.JSONDecodeError:
94
+ print("Failed to parse formatter JSON")
95
+
96
+ return {"error": "Failed to format final output"}
97
+
98
+ class SearchAgent:
99
+ """Realtime search using SerpAPI"""
100
+
101
+ def __init__(self):
102
+ self.api_key = os.getenv("SERPAPI_KEY")
103
+
104
+ def get_realtime_updates(self, subject: str) -> str:
105
+ """Get recent developments for subject"""
106
+ if not self.api_key:
107
+ return "No API key configured for realtime search"
108
+
109
+ try:
110
+ params = {
111
+ "q": f"{subject} recent developments 2024 2025",
112
+ "api_key": self.api_key,
113
+ "engine": "google",
114
+ "num": 5
115
+ }
116
+
117
+ response = requests.get("https://serpapi.com/search", params=params)
118
+ data = response.json()
119
+
120
+ # Extract snippets from results
121
+ snippets = []
122
+ if "organic_results" in data:
123
+ for result in data["organic_results"][:3]:
124
+ if "snippet" in result:
125
+ snippets.append(result["snippet"])
126
+
127
+ return "\n".join(snippets) if snippets else "No recent updates found"
128
+
129
+ except Exception as e:
130
+ print(f"Search error: {e}")
131
+ return "Error fetching realtime data"