NaderAfshar committed on
Commit
7e69835
·
1 Parent(s): cf05b38

replaced Groq LLM with Cohere Command-r-plus and it works far better

Browse files
Files changed (3) hide show
  1. app.py +7 -4
  2. gen_package_versions.py +11 -0
  3. moduler_interface.py +350 -0
app.py CHANGED
@@ -5,7 +5,8 @@ from llama_index.core.tools import FunctionTool
5
  from llama_index.core.agent import FunctionCallingAgent
6
  from llama_index.core import Settings
7
  from llama_parse import LlamaParse
8
- from llama_index.llms.groq import Groq
 
9
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
10
  from llama_index.core import (
11
  VectorStoreIndex,
@@ -36,10 +37,12 @@ nest_asyncio.apply()
36
 
37
  load_dotenv()
38
  llama_cloud_api_key = os.getenv("LLAMA_CLOUD_API_KEY")
39
- GROQ_API_KEY = os.getenv("GROQ_API_KEY")
40
  LLAMA_CLOUD_BASE_URL = os.getenv("LLAMA_CLOUD_BASE_URL")
 
 
41
 
42
- global_llm = Groq(api_key=GROQ_API_KEY, model="llama3-70b-8192")
 
43
  global_embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
44
  Settings.embed_model = global_embed_model
45
 
@@ -68,7 +71,7 @@ class GenerateQuestionsEvent(Event):
68
 
69
  class RAGWorkflow(Workflow):
70
  storage_dir = "./storage"
71
- llm: Groq
72
  query_engine: VectorStoreIndex
73
 
74
  @step
 
5
  from llama_index.core.agent import FunctionCallingAgent
6
  from llama_index.core import Settings
7
  from llama_parse import LlamaParse
8
+ #from llama_index.llms.groq import Groq
9
+ from llama_index.llms.cohere import Cohere
10
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
11
  from llama_index.core import (
12
  VectorStoreIndex,
 
37
 
38
  load_dotenv()
39
  llama_cloud_api_key = os.getenv("LLAMA_CLOUD_API_KEY")
 
40
  LLAMA_CLOUD_BASE_URL = os.getenv("LLAMA_CLOUD_BASE_URL")
41
+ #GROQ_API_KEY = os.getenv("GROQ_API_KEY")
42
+ CO_API_KEY = os.getenv("COHERE_API_KEY")
43
 
44
+ #global_llm = Groq(api_key=GROQ_API_KEY, model="llama3-70b-8192")
45
+ global_llm = Cohere(api_key=CO_API_KEY, model="command-r-plus")
46
  global_embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
47
  Settings.embed_model = global_embed_model
48
 
 
71
 
72
  class RAGWorkflow(Workflow):
73
  storage_dir = "./storage"
74
+ llm: Cohere
75
  query_engine: VectorStoreIndex
76
 
77
  @step
gen_package_versions.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
from importlib.metadata import version, PackageNotFoundError


def report_versions(requirements_path="requirements.txt"):
    """Print (and return) the installed version of each requirement.

    Reads *requirements_path* line by line, skipping blanks and ``#`` comments,
    and looks each package name up in the installed distribution metadata.

    Args:
        requirements_path: path to a pip-style requirements file.

    Returns:
        list[str]: one line per package, either ``"<pkg>==<version>"`` or
        ``"<pkg> not installed"`` — the same text that is printed.
    """
    report = []
    with open(requirements_path) as f:
        for raw_line in f:
            pkg = raw_line.strip()
            # skip blank lines and comments
            if not pkg or pkg.startswith("#"):
                continue
            try:
                report.append(f"{pkg}=={version(pkg)}")
            except PackageNotFoundError:
                # package is listed but not importable/installed in this env
                report.append(f"{pkg} not installed")
    for line in report:
        print(line)
    return report


if __name__ == "__main__":
    # preserve the original script behavior: run against ./requirements.txt
    report_versions()
moduler_interface.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from helper import extract_html_content
2
+ from IPython.display import display, HTML
3
+ from llama_index.utils.workflow import draw_all_possible_flows
4
+ from llama_index.core.tools import FunctionTool
5
+ from llama_index.core.agent import FunctionCallingAgent
6
+ from llama_index.core import Settings
7
+ from llama_parse import LlamaParse
8
+ from llama_index.llms.groq import Groq
9
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
10
+ from llama_index.core import (
11
+ VectorStoreIndex,
12
+ StorageContext,
13
+ load_index_from_storage
14
+ )
15
+ import nest_asyncio
16
+ from llama_index.core.workflow import InputRequiredEvent, HumanResponseEvent
17
+ from llama_index.core.workflow import (
18
+ StartEvent,
19
+ StopEvent,
20
+ Workflow,
21
+ step,
22
+ Event,
23
+ Context
24
+ )
25
+ from pathlib import Path
26
+ from queue import Queue
27
+ import gradio as gr
28
+ import whisper
29
+ from dotenv import load_dotenv
30
+ import os, json
31
+ import asyncio
32
+
33
# Module-level configuration shared by the workflow steps below.
storage_dir = "./storage"  # persisted vector-index location (reused across runs)
application_file = "./data/fake_application_form.pdf"  # NOTE(review): main() passes its own path; this looks unused here — confirm
nest_asyncio.apply()  # allow nested event loops (needed inside notebooks / Gradio)

load_dotenv()
llama_cloud_api_key = os.getenv("LLAMA_CLOUD_API_KEY")  # presumably read implicitly by LlamaParse — confirm
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
LLAMA_CLOUD_BASE_URL = os.getenv("LLAMA_CLOUD_BASE_URL")

# NOTE(review): app.py was switched to Cohere command-r-plus in this commit,
# but this module still builds a Groq LLM — confirm whether it should match.
global_llm = Groq(api_key=GROQ_API_KEY, model="llama3-70b-8192")
global_embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
Settings.embed_model = global_embed_model  # make this embedder the llama_index default
45
+
46
+
47
class ParseFormEvent(Event):
    """Carries the application-form path from set_up to the form-parsing step."""
    # path to the application form document (PDF) to be parsed
    application_form: str
49
+
50
+
51
class QueryEvent(Event):
    """One question to ask the resume index about a single form field."""
    query: str  # natural-language question posed to the query engine
    field: str  # the form field this query is meant to fill
54
+
55
+
56
class ResponseEvent(Event):
    """Answer produced by the query engine for one form field.

    NOTE(review): ask_question also passes ``field=`` when constructing this
    event; workflow Events accept extra attributes, but declaring ``field: str``
    here would make that contract explicit — confirm before changing.
    """
    response: str  # the query engine's answer text for the field
58
+
59
+
60
# new!
class FeedbackEvent(Event):
    """Human feedback that re-triggers question generation."""
    feedback: str  # raw feedback text captured from the user
63
+
64
+
65
class GenerateQuestionsEvent(Event):
    """Marker event: form fields are stored in context; generate queries next."""
    pass
67
+
68
+
69
class RAGWorkflow(Workflow):
    """Fill in a job application form from a resume, with a human feedback loop.

    Pipeline: index the resume (set_up) -> parse the form into a field list
    (parse_form) -> fan out one query per field (generate_questions /
    ask_question) -> combine answers with the LLM (fill_in_application) ->
    ask a human for feedback (get_feedback) -> stop, or loop back to
    generate_questions carrying the feedback.
    """
    storage_dir = "./storage"  # default; overwritten in set_up from the module global
    llm: Groq  # assigned in set_up from the module-level global_llm
    query_engine: VectorStoreIndex  # NOTE(review): actually holds a query engine, not an index — annotation looks wrong; confirm

    @step
    async def set_up(self, ctx: Context, ev: StartEvent) -> ParseFormEvent:
        """Validate inputs, build or reload the resume index, create the query engine.

        Expects ``ev.resume_file`` and ``ev.application_form`` on the StartEvent;
        raises ValueError if either is missing.
        """
        self.llm = global_llm
        self.storage_dir = storage_dir
        if not ev.resume_file:
            raise ValueError("No resume file provided")

        if not ev.application_form:
            raise ValueError("No application form provided")

        # ingest the data and set up the query engine
        if os.path.exists(self.storage_dir):
            # you've already ingested the resume document
            storage_context = StorageContext.from_defaults(persist_dir=self.storage_dir)
            index = load_index_from_storage(storage_context)
        else:
            # parse and load the resume document
            documents = LlamaParse(
                result_type="markdown",
                content_guideline_instruction="This is a resume, gather related facts together and format it as "
                "bullet points with headers"
            ).load_data(ev.resume_file)
            # embed and index the documents
            index = VectorStoreIndex.from_documents(
                documents,
                embed_model=global_embed_model
            )
            # persist so the next run hits the cached branch above
            index.storage_context.persist(persist_dir=self.storage_dir)

        # create a query engine
        self.query_engine = index.as_query_engine(llm=self.llm, similarity_top_k=5)

        # you no longer need a query to be passed in,
        # you'll be generating the queries instead
        # let's pass the application form to a new step to parse it
        return ParseFormEvent(application_form=ev.application_form)

    # new - separated the form parsing from the question generation
    @step
    async def parse_form(self, ctx: Context, ev: ParseFormEvent) -> GenerateQuestionsEvent:
        """Parse the application form and store its field list in the context.

        Stores the list under the ``"fields_to_fill"`` context key.
        Raises (via json.loads/KeyError) if the LLM does not return the
        expected ``{ fields: [...] }`` JSON — no fallback handling here.
        """
        parser = LlamaParse(
            result_type="markdown",
            content_guideline_instruction="This is a job application form. Create a list of all the fields "
            "that need to be filled in.",
            formatting_instruction="Return a bulleted list of the fields ONLY."
        )

        # get the LLM to convert the parsed form into JSON
        result = parser.load_data(ev.application_form)[0]
        raw_json = self.llm.complete(
            f"""
            This is a parsed form.
            Convert it into a JSON object containing only the list
            of fields to be filled in, in the form {{ fields: [...] }}.
            <form>{result.text}</form>.
            Return JSON ONLY, no markdown.
            """)
        fields = json.loads(raw_json.text)["fields"]

        await ctx.set("fields_to_fill", fields)
        print("\n DEBUG: all fields written to Context >>>>>>>>>>>>>>>>>>>>>>>>>>\n")

        return GenerateQuestionsEvent()

    # new - this step can get triggered either by GenerateQuestionsEvent or a FeedbackEvent
    @step
    async def generate_questions(self, ctx: Context, ev: GenerateQuestionsEvent | FeedbackEvent) -> QueryEvent:
        """Emit one QueryEvent per form field; append human feedback if present.

        Uses ``ctx.send_event`` to fan out, then records ``"total_fields"`` so
        fill_in_application knows how many ResponseEvents to collect. Returns
        None itself — the QueryEvent annotation describes the sent events.
        """

        # get the list of fields to fill in
        fields = await ctx.get("fields_to_fill")
        print("\n DEBUG:all fields Read from Context >>>>>>>>>>>>>>>>>>>>>>>>>>\n")

        # generate one query for each of the fields, and fire them off
        for field in fields:
            question = f"How would you answer this question about the candidate? <field>{field}</field>"
            # Is there feedback? If so, add it to the query:
            # (only FeedbackEvent carries a .feedback attribute)
            if hasattr(ev, "feedback"):
                question += f"""
                \nWe previously got feedback about how we answered the questions.
                It might not be relevant to this particular field, but here it is:
                <feedback>{ev.feedback}</feedback>
                """
            print("\n question : ", question)

            ctx.send_event(QueryEvent(
                field=field,
                query=question
            ))

        # store the number of fields, so we know how many to wait for later
        await ctx.set("total_fields", len(fields))
        print(f"\n DEBUG: total fields from Context : {len(fields)}")

        return

    @step
    async def ask_question(self, ctx: Context, ev: QueryEvent) -> ResponseEvent:
        """Answer one field's question against the resume index."""
        response = self.query_engine.query(
            f"This is a question about the specific resume we have in our database: {ev.query}")
        # carry the field name through so fill_in_application can pair them up
        return ResponseEvent(field=ev.field, response=response.response)

    # new - we now emit an InputRequiredEvent
    @step
    async def fill_in_application(self, ctx: Context, ev: ResponseEvent) -> InputRequiredEvent:
        """Collect all field answers, merge them with the LLM, ask a human to review.

        Buffers ResponseEvents via ``ctx.collect_events`` until ``total_fields``
        have arrived (returns None while still waiting), then stores the merged
        form under ``"filled_form"`` and emits an InputRequiredEvent.
        """
        # get the total number of fields to wait for
        total_fields = await ctx.get("total_fields")

        responses = ctx.collect_events(ev, [ResponseEvent] * total_fields)
        if responses is None:
            return None  # do nothing if there's nothing to do yet

        # we've got all the responses!
        responseList = "\n".join("Field: " + r.field + "\n" + "Response: " + r.response for r in responses)
        print("\n DEBUG: got all responses :\n")

        result = self.llm.complete(f"""
            You are given a list of fields in an application form and responses to
            questions about those fields from a resume. Combine the two into a list of
            fields and succinct, factual answers to fill in those fields.

            <responses>
            {responseList}
            </responses>
            """)

        print("\n DEBUG: llm combined the fields and responses from resume")

        # new! save the result for later
        await ctx.set("filled_form", str(result))

        print("\n DEBUG: Write all form fields to context. Now will emit InputRequiredEvent")

        # new! Let's get a human in the loop
        return InputRequiredEvent(
            prefix="How does this look? Give me any feedback you have on any of the answers.",
            result=result
        )

    # new! Accept the feedback.
    @step
    async def get_feedback(self, ctx: Context, ev: HumanResponseEvent) -> FeedbackEvent | StopEvent:
        """Classify human feedback: 'OKAY' stops the workflow, anything else loops.

        NOTE(review): the comparison is an exact match on 'OKAY'; any extra
        punctuation/markdown from the LLM falls through to the feedback branch
        (a safe failure mode, but worth confirming with the chosen model).
        """

        result = self.llm.complete(f"""
            You have received some human feedback on the form-filling task you've done.
            Does everything look good, or is there more work to be done?
            <feedback>
            {ev.response}
            </feedback>
            If everything is fine, respond with just the word 'OKAY'.
            If there's any other feedback, respond with just the word 'FEEDBACK'.
            """)

        verdict = result.text.strip()

        print(f"LLM says the verdict was {verdict}")
        if (verdict == "OKAY"):
            # done: surface the saved form as the workflow result
            return StopEvent(result=await ctx.get("filled_form"))
        else:
            # loop: hand the raw human feedback back to generate_questions
            return FeedbackEvent(feedback=ev.response)
233
+
234
+
235
def transcribe_speech(filepath):
    """Transcribe an audio file to text using Whisper's ``base`` model.

    Args:
        filepath: path to the recorded audio file, or None when no audio
            was captured.

    Returns:
        The transcribed text, or ``""`` when *filepath* is None.
    """
    if filepath is None:
        gr.Warning("No audio found, please retry.")
        # BUGFIX: previously fell through and called model.transcribe(None),
        # which raises; bail out with an empty transcription instead.
        return ""

    model = whisper.load_model("base")
    # fp16=False keeps inference on CPU-friendly float32
    result = model.transcribe(filepath, fp16=False)

    return result["text"]
243
+
244
+
245
+ # New! Transcription handler.
246
class TranscriptionHandler:
    """Record speech via a Gradio UI and hand one transcription back to the caller.

    Launches a Blocks app with a log panel and a microphone tab; finished
    transcriptions are pushed onto an internal queue, which get_transcription
    polls until the first result arrives.
    """

    # we create a queue to hold transcription values
    def __init__(self):
        self.transcription_queue = Queue()  # transcriptions produced by the mic tab
        self.interface = None  # the launched Gradio Blocks app (set in create_interface)
        self.log_display = None  # Textbox used for progress messages

    # every time we record something we put it in the queue
    def store_transcription(self, output):
        """Queue a finished transcription and echo it back to the UI."""
        self.transcription_queue.put(output)
        return output

    # This is the same interface and transcription logic as before
    # except it stores the result in a queue instead of a global
    def create_interface(self):
        """Build the Blocks UI: a log panel plus a microphone-transcription tab."""
        # Initial Log Display (Textbox with logs)
        log_box = gr.Textbox(
            label="Log Output",
            interactive=False,
            value="Waiting for user interaction...\n",
            height=200
        )

        # Transcription area that gets activated after form input
        mic_transcribe = gr.Interface(
            fn=lambda x: self.store_transcription(transcribe_speech(x)),
            inputs=gr.Audio(sources=["microphone"], type="filepath"),
            outputs=gr.Textbox(label="Transcription")
        )

        # Creating a Block interface
        self.interface = gr.Blocks()
        with self.interface:
            with gr.Row():
                self.log_display = log_box  # Display log
            with gr.Row():
                # A Tabbed Interface, initially showing the log, then the microphone input
                # NOTE(review): TabbedInterface normally expects Interface/Blocks
                # children; passing a bare Textbox (log_box) may not render as
                # intended — confirm against the installed Gradio version.
                gr.TabbedInterface([log_box, mic_transcribe], ["Log", "Transcribe Microphone"])

        return self.interface

    # Launches the interface with dynamic transition based on events
    async def get_transcription(self):
        """Launch the UI, poll until one transcription arrives, close it, return the text."""
        self.interface = self.create_interface()
        self.interface.launch(
            share=True,  # Remove when running on Hugging Face Spaces
            ssr_mode=False,
            prevent_thread_lock=True  # keep this coroutine running alongside the server
        )

        # Poll every 1.5 seconds, checking if transcription has been queued
        while True:
            if not self.transcription_queue.empty():
                result = self.transcription_queue.get()
                if self.interface is not None:
                    self.interface.close()
                return result
            await asyncio.sleep(1.5)  # yield to the event loop between polls

    # Update log display dynamically as the workflow progresses
    def update_log(self, message):
        """Replace the log textbox content with *message* (plus a trailing newline)."""
        if self.log_display:
            # NOTE(review): calling .update() on a component instance is
            # deprecated in recent Gradio and may be a no-op outside an
            # event handler — confirm this actually refreshes the UI.
            self.log_display.update(value=f"{message}\n")
310
+
311
+
312
async def main():
    """Run the form-filling workflow end-to-end with voice-based human feedback.

    Streams workflow events; whenever the workflow asks for input
    (InputRequiredEvent), launches the transcription UI, records the user's
    spoken feedback, and sends it back as a HumanResponseEvent. Finally prints
    the filled form and renders a diagram of the workflow.
    """
    w = RAGWorkflow(timeout=600, verbose=True)
    handler = w.run(
        resume_file="data/fake_resume.pdf",
        application_form="data/fake_application_form.pdf"
    )

    print("DEBUG: Starting event stream...")
    async for event in handler.stream_events():
        print(f"DEBUG: Received event type {type(event).__name__}")
        if isinstance(event, InputRequiredEvent):
            print("We've filled in your form! Here are the results:\n")
            print(event.result)

            # Get transcription: block here until the user records feedback
            transcription_handler = TranscriptionHandler()
            response = await transcription_handler.get_transcription()

            # feed the spoken feedback back into the paused workflow
            handler.ctx.send_event(
                HumanResponseEvent(
                    response=response
                )
            )
        else:
            print("\n handler received event ", event)

    # await the workflow's final StopEvent result (the filled form text)
    response = await handler
    print("Agent complete! Here's your final result:")
    print(str(response))

    # Display of the workflow
    # NOTE(review): assumes a ./workflows directory exists next to this file;
    # draw_all_possible_flows will fail otherwise — confirm.
    workflow_file = Path(__file__).parent / "workflows" / "form_parsing_workflow.html"
    draw_all_possible_flows(w, filename=str(workflow_file))
    html_content = extract_html_content(str(workflow_file))
    # display(HTML(...)) is an IPython call — only visible in notebook contexts
    display(HTML(html_content), metadata=dict(isolated=True))


if __name__ == "__main__":
    asyncio.run(main())