nivakaran committed (verified)
Commit 7ebc997 · Parent(s): 2e9d4b7

Upload folder using huggingface_hub

Files changed (3):
  1. requirements.txt +1 -0
  2. src/llm/groq_llm.py +184 -0
  3. src/rag/pipeline.py +5 -5
requirements.txt CHANGED
@@ -24,6 +24,7 @@ gradio>=4.0.0
  # API
  uvicorn>=0.27.0
  python-multipart>=0.0.6
+ groq>=0.4.0
 
  # Dev
  ipykernel
src/llm/groq_llm.py ADDED
@@ -0,0 +1,184 @@
+ """Groq LLM client with local fallback for FreeRAG."""
+
+ import logging
+ import os
+ from typing import Optional
+
+ logger = logging.getLogger(__name__)
+
+ # Groq API configuration
+ GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
+ GROQ_MODEL = "llama-3.1-8b-instant"  # Fast, free model on Groq
+
+
+ class GroqLLM:
+     """Groq-based LLM with local model fallback.
+
+     Uses the Groq API for fast inference and falls back to a local
+     Phi-3 model if Groq is unavailable or rate limited.
+     """
+
+     def __init__(self):
+         """Initialize the Groq client."""
+         self._groq_client = None
+         self._local_model = None
+         self._groq_available = bool(GROQ_API_KEY)
+
+         if self._groq_available:
+             try:
+                 from groq import Groq
+                 self._groq_client = Groq(api_key=GROQ_API_KEY)
+                 logger.info("✅ Groq client initialized successfully")
+             except Exception as e:
+                 logger.warning(f"⚠️ Groq initialization failed: {e}")
+                 self._groq_available = False
+         else:
+             logger.info("📍 No GROQ_API_KEY found, using local model only")
+
+     @property
+     def local_model(self):
+         """Lazily load the local fallback model."""
+         if self._local_model is None:
+             from src.llm.phi_model import PhiModel
+             from src.config import ModelConfig
+             logger.info("🔄 Loading local fallback model...")
+             self._local_model = PhiModel(ModelConfig())
+         return self._local_model
+
+     def generate(
+         self,
+         prompt: str,
+         system_prompt: Optional[str] = None,
+         max_tokens: int = 256,
+         temperature: float = 0.7
+     ) -> str:
+         """Generate a response using Groq, with local fallback.
+
+         Args:
+             prompt: User prompt/question.
+             system_prompt: Optional system prompt.
+             max_tokens: Maximum tokens to generate.
+             temperature: Sampling temperature.
+
+         Returns:
+             Generated response string.
+         """
+         # Try Groq first if available
+         if self._groq_available and self._groq_client:
+             try:
+                 response = self._call_groq(prompt, system_prompt, max_tokens, temperature)
+                 if response:
+                     return response
+             except Exception as e:
+                 logger.warning(f"⚠️ Groq API error, falling back to local: {e}")
+
+         # Fall back to the local model
+         logger.info("🔄 Using local model for generation")
+         return self._call_local(prompt, system_prompt, max_tokens)
+
+     def _call_groq(
+         self,
+         prompt: str,
+         system_prompt: Optional[str],
+         max_tokens: int,
+         temperature: float
+     ) -> str:
+         """Call the Groq API."""
+         messages = []
+
+         if system_prompt:
+             messages.append({"role": "system", "content": system_prompt})
+
+         messages.append({"role": "user", "content": prompt})
+
+         response = self._groq_client.chat.completions.create(
+             model=GROQ_MODEL,
+             messages=messages,
+             max_tokens=max_tokens,
+             temperature=temperature,
+             stream=False
+         )
+
+         result = response.choices[0].message.content
+         logger.info(f"✅ Groq response generated ({len(result)} chars)")
+         return result
+
+     def _call_local(
+         self,
+         prompt: str,
+         system_prompt: Optional[str],
+         max_tokens: int
+     ) -> str:
+         """Call the local model."""
+         messages = []
+
+         if system_prompt:
+             messages.append({"role": "system", "content": system_prompt})
+
+         messages.append({"role": "user", "content": prompt})
+
+         return self.local_model.chat(messages, max_tokens=max_tokens)
+
+     def chat_with_context(
+         self,
+         query: str,
+         context: str,
+         system_prompt: Optional[str] = None,
+         conversation_history: Optional[str] = None
+     ) -> str:
+         """Generate a response with RAG context.
+
+         Args:
+             query: User's question.
+             context: Retrieved context from documents.
+             system_prompt: Optional system prompt.
+             conversation_history: Optional conversation history.
+
+         Returns:
+             Generated response.
+         """
+         if system_prompt is None:
+             system_prompt = (
+                 "Your name is Dragon. Always respond in ENGLISH, never in any other language. "
+                 "You are a friendly and helpful assistant having a natural conversation. "
+                 "Answer questions based on the provided document context. "
+                 "Be conversational, warm, and helpful - like talking to a knowledgeable friend. "
+                 "If you can find relevant information, explain it clearly and naturally. "
+                 "If the context doesn't have enough information, kindly say so. "
+                 "Keep your responses concise but friendly."
+             )
+
+         # Handle empty context
+         if not context or not context.strip():
+             context = "No relevant documents found."
+
+         # Build the prompt with optional history
+         history_section = ""
+         if conversation_history and conversation_history.strip():
+             history_section = f"""Previous conversation:
+ {conversation_history}
+
+ ---
+ """
+
+         prompt = f"""{history_section}Here's some information from the documents:
+
+ {context}
+
+ User's current question: {query}
+
+ Please respond naturally and helpfully, considering the conversation context:"""
+
+         return self.generate(prompt, system_prompt=system_prompt)
+
+
+ # Global Groq LLM instance
+ _groq_llm: Optional[GroqLLM] = None
+
+
+ def get_groq_llm() -> GroqLLM:
+     """Get or create the global Groq LLM instance."""
+     global _groq_llm
+     if _groq_llm is None:
+         _groq_llm = GroqLLM()
+     return _groq_llm
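
For reference, downstream code can exercise the new client directly. This is a minimal sketch, not part of the commit, assuming the project root is on PYTHONPATH and GROQ_API_KEY is exported; the context string here is made up for illustration:

from src.llm.groq_llm import get_groq_llm

llm = get_groq_llm()  # reuses the module-level singleton

# Tries Groq first; on error or missing API key, falls back to local Phi-3
print(llm.generate("Summarize RAG in one sentence.", max_tokens=64))

# RAG-style call: in the app, `context` comes from the retriever
print(llm.chat_with_context(
    query="What does the warranty cover?",
    context="Section 4: The warranty covers manufacturing defects for 12 months.",
))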
src/rag/pipeline.py CHANGED
@@ -3,7 +3,6 @@
  from typing import Optional, Dict, Any
 
  from src.config import Config
- from src.llm.phi_model import PhiModel
  from src.embeddings.sentence_embeddings import EmbeddingModel
  from src.document_loader.loader import DocumentLoader
  from src.document_loader.splitter import TextSplitter
@@ -24,7 +23,7 @@ class RAGPipeline:
          self.config.ensure_directories()
 
          # Initialize components lazily
-         self._llm: Optional[PhiModel] = None
+         self._llm = None  # Will be GroqLLM with fallback
          self._embedding_model: Optional[EmbeddingModel] = None
          self._vector_store: Optional[VectorStore] = None
          self._retriever: Optional[Retriever] = None
@@ -32,10 +31,11 @@ class RAGPipeline:
          self._text_splitter: Optional[TextSplitter] = None
 
      @property
-     def llm(self) -> PhiModel:
-         """Get LLM instance."""
+     def llm(self):
+         """Get LLM instance (Groq with local fallback)."""
          if self._llm is None:
-             self._llm = PhiModel(self.config.model)
+             from src.llm.groq_llm import get_groq_llm
+             self._llm = get_groq_llm()
          return self._llm
 
      @property
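
Design note: deferring the get_groq_llm import into the llm property keeps groq out of the pipeline's module-level imports, so the pipeline can still be constructed where the package or API key is absent. Callers are unaffected; a minimal sketch, assuming RAGPipeline accepts a Config instance as the constructor code above suggests:

from src.config import Config
from src.rag.pipeline import RAGPipeline

pipeline = RAGPipeline(Config())
# First access lazily builds the shared GroqLLM; with no GROQ_API_KEY set,
# generation transparently falls back to the local Phi-3 model.
print(pipeline.llm.generate("Hello!"))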