gmustafa413 committed on
Commit
6978a4b
·
verified ·
1 Parent(s): f9be209

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -0
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from functools import lru_cache

import faiss
import gradio as gr
import numpy as np
from datasets import load_dataset
from groq import Groq
from transformers import AutoModel, AutoTokenizer
7
+
8
# Initialize the Groq client.
# Security fix: the original committed a live API key to source control.
# Read it from the environment instead; set GROQ_API_KEY before launching.
groq_api_key = os.environ.get("GROQ_API_KEY")
client = Groq(api_key=groq_api_key)

# Load the dataset used for retrieval (split 'train' is consumed below).
dataset = load_dataset("midrees2806/7K_Dataset")
14
+
15
# Preprocessing function
@lru_cache(maxsize=1)
def _load_encoder():
    """Load and cache the bert-base-uncased tokenizer/model pair.

    Performance fix: the original reloaded both from disk on every call,
    which dominates runtime when embedding the whole dataset.
    """
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    model = AutoModel.from_pretrained('bert-base-uncased')
    return tokenizer, model


def preprocess_data(text):
    """Embed *text* as the mean-pooled last hidden state of bert-base-uncased.

    Returns a numpy array of shape (1, hidden_dim); inputs longer than 512
    tokens are truncated.
    """
    tokenizer, model = _load_encoder()
    inputs = tokenizer(text, return_tensors='pt', max_length=512, truncation=True, padding=True)
    return model(**inputs).last_hidden_state.mean(dim=1).detach().numpy()
21
+
22
# Prepare embeddings: embed each training example once, then stack the
# per-example (1, dim) arrays into a single (num_examples, dim) matrix.
print("Preparing embeddings...")
train_dataset = dataset['train']
chunked_embeddings = np.vstack(
    [preprocess_data(record['text']) for record in train_dataset]
)

# Initialize FAISS index: exact L2-distance search over all embeddings.
dimension = chunked_embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(chunked_embeddings)
35
+
36
# Groq response function
def get_groq_response(query):
    """Send *query* as a single user message to llama3-70b-8192 and return the reply text."""
    completion = client.chat.completions.create(
        messages=[{"role": "user", "content": query}],
        model="llama3-70b-8192",
    )
    return completion.choices[0].message.content
43
+
44
# FAISS search function
def search_in_faiss(query):
    """Return the texts of the 5 dataset entries nearest to *query* in embedding space."""
    embedded_query = preprocess_data(query)
    _, neighbor_ids = index.search(embedded_query, k=5)
    train_split = dataset['train']
    return [train_split[int(i)]['text'] for i in neighbor_ids[0]]
49
+
50
# Gradio Chat Interface
def respond(message, chat_history):
    """Handle one chat turn: retrieve FAISS context, query Groq, update history.

    Returns ("", new_history) so the textbox is cleared after submit.
    On any failure the error text is appended to the chat instead of raised,
    keeping the UI responsive.
    """
    try:
        faiss_results = search_in_faiss(message)
        model_response = get_groq_response(message)

        bot_response = "**Relevant Information from Dataset:**\n\n"
        for result in faiss_results:
            bot_response += f"- {result}\n\n"
        # Bug fix: the original used `=` here, which discarded the dataset
        # results built above; append the model answer instead.
        bot_response += "**Model Response:**\n\n" + model_response
        return "", chat_history + [(message, bot_response)]
    except Exception as e:
        print(f"Error: {str(e)}")
        return "", chat_history + [(message, f"Error processing request: {str(e)}")]
65
+
66
# Create interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# <center>UOE Chatbot</center>")
    gr.Markdown("<center>Ask any question and get answers powered by Groq and FAISS</center>")

    # Chat transcript, user input box, and a button to reset the transcript.
    chat_window = gr.Chatbot(height=500)
    user_box = gr.Textbox(label="Type your message here...", placeholder="Ask me anything...")
    clear_btn = gr.Button("Clear")

    # Submitting the textbox routes through respond(); its first return value
    # ("") clears the textbox, the second replaces the chat history.
    user_box.submit(respond, [user_box, chat_window], [user_box, chat_window])
    clear_btn.click(lambda: None, None, chat_window, queue=False)

# Launch application
if __name__ == "__main__":
    demo.launch()