capradeepgujaran committed on
Commit
33a1aec
·
verified ·
1 Parent(s): 20ca539

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -40
app.py CHANGED
@@ -4,42 +4,25 @@ import os
4
  from gtts import gTTS
5
  from deep_translator import GoogleTranslator
6
  import logging
7
- from llama_index import VectorStoreIndex, Document, SimpleDirectoryReader
8
- from llama_index.node_parser import SimpleNodeParser
9
- from llama_index.llms import HuggingFaceLLM
10
- from llama_index import ServiceContext, set_global_service_context
11
- import torch
 
 
 
12
 
13
  logging.basicConfig(level=logging.INFO, format='%(asctime)s | %(levelname)s | %(message)s')
14
 
15
- # Initialize the LLM with a smaller context window
16
- try:
17
- llm = HuggingFaceLLM(
18
- context_window=512, # Reduced from 1024
19
- max_new_tokens=256,
20
- generate_kwargs={"temperature": 0.7, "do_sample": False},
21
- tokenizer_name="gpt2",
22
- model_name="gpt2",
23
- device_map="auto",
24
- tokenizer_kwargs={"max_length": 512}, # Reduced from 1024
25
- model_kwargs={"torch_dtype": torch.float32},
26
- )
27
- except ImportError:
28
- # Fallback if Accelerate is not available
29
- llm = HuggingFaceLLM(
30
- context_window=512, # Reduced from 1024
31
- max_new_tokens=256,
32
- generate_kwargs={"temperature": 0.7, "do_sample": False},
33
- tokenizer_name="gpt2",
34
- model_name="gpt2",
35
- tokenizer_kwargs={"max_length": 512}, # Reduced from 1024
36
- model_kwargs={"torch_dtype": torch.float32},
37
- )
38
 
39
- # Initialize the ServiceContext with a chunk size
40
- node_parser = SimpleNodeParser.from_defaults(chunk_size=256) # Adjust chunk size as needed
41
- service_context = ServiceContext.from_defaults(llm=llm, embed_model="local", node_parser=node_parser)
42
- set_global_service_context(service_context)
43
 
44
  # Initialize the index
45
  index = None
@@ -73,28 +56,40 @@ audio_language_dict = {
73
  def index_text(text: str) -> str:
74
  global index
75
  try:
76
- documents = [Document(text=text)]
77
  if index is None:
78
- index = VectorStoreIndex.from_documents(documents)
79
  else:
80
- index.insert(documents[0])
81
  return "Text indexed successfully."
82
  except Exception as e:
83
  logging.error(f"Error in indexing: {str(e)}")
84
  return f"Error indexing text: {str(e)}"
85
 
86
- def chat_with_context(question: str) -> str:
87
  global index
88
  if index is None:
89
  return "Please index some text first."
90
 
91
  try:
92
  query_engine = index.as_query_engine(
93
- similarity_top_k=2, # Adjust as needed
94
  response_mode="compact"
95
  )
96
- response = query_engine.query(question)
97
- return str(response)
 
 
 
 
 
 
 
 
 
 
 
 
98
  except Exception as e:
99
  logging.error(f"Error in chat: {str(e)}")
100
  return f"Error in chat: {str(e)}"
@@ -140,6 +135,11 @@ with gr.Blocks() as iface:
140
  chat_group = gr.Group(visible=False)
141
  with chat_group:
142
  chat_input = gr.Textbox(label="Ask a question about the indexed text")
 
 
 
 
 
143
  chat_button = gr.Button("Ask")
144
  chat_output = gr.Textbox(label="Answer", interactive=False)
145
 
@@ -171,7 +171,7 @@ with gr.Blocks() as iface:
171
  convert_button.click(convert_text, inputs=[text_input, translation_lang_dropdown], outputs=translated_text)
172
  index_button.click(index_text, inputs=[translated_text], outputs=[index_status])
173
  use_chat.change(update_chat_visibility, inputs=[use_chat], outputs=[chat_group])
174
- chat_button.click(chat_with_context, inputs=[chat_input], outputs=[chat_output])
175
 
176
  generate_button.click(
177
  generate_speech,
 
4
from gtts import gTTS
from deep_translator import GoogleTranslator
import logging
from llama_index import VectorStoreIndex, Document
from llama_index.embeddings import HuggingFaceEmbedding
from llama_index import ServiceContext
from groq import Groq
from dotenv import load_dotenv

# Load environment variables (expects GROQ_API_KEY in a .env file or the process env).
load_dotenv()

logging.basicConfig(level=logging.INFO, format='%(asctime)s | %(levelname)s | %(message)s')

# Initialize Groq client.
# Warn loudly at startup when the key is missing instead of surfacing an
# opaque authentication error on the first chat request.
_groq_api_key = os.getenv("GROQ_API_KEY")
if not _groq_api_key:
    logging.warning("GROQ_API_KEY is not set; chat requests will fail until it is provided.")
groq_client = Groq(api_key=_groq_api_key)

# Initialize the embedding model (local sentence-transformers model; no remote API needed).
embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Initialize the ServiceContext shared by every index built in index_text().
service_context = ServiceContext.from_defaults(embed_model=embed_model)

# Initialize the index lazily: it is built on the first successful indexing call.
index = None
 
56
def index_text(text: str) -> str:
    """Add *text* to the global vector index, creating the index on first use.

    Returns a human-readable status string for display in the UI.
    """
    global index
    try:
        doc = Document(text=text)
        if index is None:
            # First document: build a fresh index with the shared service context.
            index = VectorStoreIndex.from_documents([doc], service_context=service_context)
        else:
            # Subsequent documents are inserted into the existing index.
            index.insert(doc)
    except Exception as e:
        logging.error(f"Error in indexing: {str(e)}")
        return f"Error indexing text: {str(e)}"
    return "Text indexed successfully."
68
 
69
def chat_with_context(question: str, model: str) -> str:
    """Answer *question* using context retrieved from the index, via a Groq LLM.

    *model* is the Groq model identifier chosen in the UI dropdown.
    Returns the model's answer, or a status/error message for the UI.
    """
    global index
    if index is None:
        return "Please index some text first."

    try:
        # Retrieve a compact context string from the two most similar chunks.
        engine = index.as_query_engine(similarity_top_k=2, response_mode="compact")
        retrieved = engine.query(question).response

        # Stuff the retrieved context into a single-turn prompt for Groq.
        prompt = f"Context: {retrieved}\n\nQuestion: {question}\n\nAnswer:"
        completion = groq_client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=model,
        )
        return completion.choices[0].message.content
    except Exception as e:
        logging.error(f"Error in chat: {str(e)}")
        return f"Error in chat: {str(e)}"
 
135
  chat_group = gr.Group(visible=False)
136
  with chat_group:
137
  chat_input = gr.Textbox(label="Ask a question about the indexed text")
138
+ chat_model = gr.Dropdown(
139
+ choices=["llama3-70b-8192", "mixtral-8x7b-32768", "gemma-7b-it"],
140
+ label="Select Chat Model",
141
+ value="llama3-70b-8192"
142
+ )
143
  chat_button = gr.Button("Ask")
144
  chat_output = gr.Textbox(label="Answer", interactive=False)
145
 
 
171
  convert_button.click(convert_text, inputs=[text_input, translation_lang_dropdown], outputs=translated_text)
172
  index_button.click(index_text, inputs=[translated_text], outputs=[index_status])
173
  use_chat.change(update_chat_visibility, inputs=[use_chat], outputs=[chat_group])
174
+ chat_button.click(chat_with_context, inputs=[chat_input, chat_model], outputs=[chat_output])
175
 
176
  generate_button.click(
177
  generate_speech,