HipFil98 commited on
Commit
0bf764e
·
verified ·
1 Parent(s): 17196b9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +248 -0
app.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ from qdrant_client import QdrantClient
6
+ from sentence_transformers import SentenceTransformer
7
+
8
# Configure environment variables and paths.
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers
# releases in favour of HF_HOME; set both so either version honours it.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
os.environ.setdefault("HF_HOME", "/tmp/transformers_cache")
os.environ["HF_TOKEN"] = os.environ.get("HF_TOKEN", "")  # Gets the token from Spaces secrets


def get_qdrant_path():
    """Return the directory that holds the local Qdrant database.

    Resolution order:
      1. The ``QDRANT_PATH`` environment variable, when set — a
         backward-compatible override so deployments are not tied to the
         two hard-coded paths below.
      2. The Hugging Face Spaces path, when running on Spaces.
      3. The original local development path.
    """
    env_path = os.environ.get("QDRANT_PATH")
    if env_path:
        return env_path
    if os.path.exists("/home/user/app"):  # We're on HF Spaces
        return "/home/user/app/qdrant_data"
    # Local development environment fallback.
    return "/home/filippo/Scrivania/ELAN_bot/qdrant_data"


QDRANT_PATH = get_qdrant_path()
# Function to perform vector search using the existing Qdrant database.
# The encoder and client are cached across calls: re-creating a
# SentenceTransformer per request is slow, and a file-based QdrantClient
# takes an exclusive lock on its storage directory, so opening a second
# client against the same path fails.
_search_cache = {}


def vector_search(query, encoder_model="nomic-ai/nomic-embed-text-v1.5", client_path=None):
    """
    Perform vector search on the Qdrant database and return the relevant context.

    Args:
        query: Natural-language question to embed and search with.
        encoder_model: SentenceTransformer model id used for embedding.
        client_path: Qdrant storage directory; defaults to QDRANT_PATH.

    Returns:
        The concatenated 'content' payloads of the top 3 hits, or an
        explanatory message when nothing is found or the search fails.
    """
    if client_path is None:
        client_path = QDRANT_PATH

    try:
        # Reuse the encoder/client created on the first call (see note above).
        cache_key = (encoder_model, client_path)
        if cache_key not in _search_cache:
            _search_cache[cache_key] = (
                SentenceTransformer(encoder_model, trust_remote_code=True),
                QdrantClient(path=client_path),
            )
        encoder, client = _search_cache[cache_key]

        # Encode the query
        query_vector = encoder.encode(query).tolist()

        # Perform the search
        hits = client.query_points(
            collection_name="ELAN_docs_pages",
            query=query_vector,
            limit=3,
        ).points

        # Get the context content
        if hits:
            return "\n".join([hit.payload['content'] for hit in hits])
        return "No relevant documentation found."
    except Exception as e:
        print(f"Vector search error: {str(e)}")
        # Fall back to a message if the search fails
        return f"Unable to perform vector search: {str(e)}"
# Function to get the model and tokenizer
def get_llm():
    """
    Initialize and return the Llama model and tokenizer.

    Returns:
        (model, tokenizer) tuple — note the order: model first.

    Raises:
        Whatever transformers raises when the model is unavailable or the
        token lacks access (callers wrap this call in try/except).
    """
    # This loads the model from Hugging Face Hub using your token
    model_id = "meta-llama/Llama-3.2-3B-Instruct"

    # FIX: an empty-string token (unset Space secret) must become None so the
    # Hub falls back to anonymous/cached credentials instead of rejecting the
    # request with an invalid-token error.
    hf_token = os.environ.get("HF_TOKEN") or None

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        token=hf_token
    )

    # Load model with memory optimizations. 8-bit quantization is requested
    # through BitsAndBytesConfig: the bare `load_in_8bit=True` kwarg is
    # deprecated, and combining it with torch_dtype is ambiguous. float16
    # applies to the modules that stay un-quantized.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=hf_token,
        device_map="auto",
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        torch_dtype=torch.float16,
    )

    return model, tokenizer
# Cache model and tokenizer so repeated requests reuse one loaded instance.
_model = None
_tokenizer = None


def get_cached_llm():
    """Get or initialize the cached model and tokenizer.

    Returns:
        (model, tokenizer) — same order as get_llm().
    """
    global _model, _tokenizer
    if _model is None or _tokenizer is None:
        # BUG FIX: the original unpacked `_tokenizer, _model = get_llm()`,
        # swapping the two objects — get_llm() returns (model, tokenizer),
        # so every caller received them reversed.
        _model, _tokenizer = get_llm()
    return _model, _tokenizer
# Function to generate response to ELAN questions
def generate_response(query):
    """
    Generate a response to a question about ELAN.

    Retrieves relevant documentation through vector search, then asks the
    cached Llama model to reformulate it as an answer.

    Args:
        query: The user's natural-language question about ELAN.

    Returns:
        The assistant's answer, or a human-readable error message string
        when model loading or generation fails.
    """
    # Get context through vector search
    context = vector_search(query)

    # Get model and tokenizer
    try:
        model, tokenizer = get_cached_llm()
    except Exception as e:
        return f"Error loading model: {str(e)}. Make sure you have set up the HF_TOKEN in your Space secrets and have been granted access to the model."

    # Create the system message
    system_prompt = "You are a virtual assistant that helps the user in using an annotation software called ELAN. Your task is to summarize information and guide the user in the usage of the software."

    # Create the user message
    user_prompt = f"""Context: {context}
question: {query}

Use exclusively the information contained in the provided context to reformulate the text in about 120 words.
take into consideration the provided question as a reference for the formulation of the answer.
To be more clear and coincise use numbered lists when giving instructions.
Make sure the reformulation maintains the original meaning.
In the output, check that there are no grammatical errors. If you find errors, correct them.
Do not add information that is not present in the original text.
In the output, never say that you are summarizing the text."""

    # Format inputs for Llama-3 chat format
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt}
    ]

    try:
        # Convert messages to model input format.
        # FIX: add_generation_prompt=True appends the assistant header that
        # instruct chat templates require before the model will produce an
        # answer rather than continuing the user turn.
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(model.device)

        # Generate response
        with torch.no_grad():
            output = model.generate(
                inputs,
                max_new_tokens=500,
                temperature=0.1,
                do_sample=True,
            )

        # FIX: decode only the tokens generated after the prompt. Slicing by
        # prompt length is robust, unlike the previous approach of splitting
        # the full decoded text on the word "assistant", which broke whenever
        # the answer itself contained that word.
        assistant_response = tokenizer.decode(
            output[0][inputs.shape[-1]:],
            skip_special_tokens=True,
        ).strip()

        return assistant_response

    except Exception as e:
        return f"Error generating response: {str(e)}"
# Function to modify XML code
def modify_xml(xml_code, instructions):
    """
    Modify XML code according to user instructions.

    Args:
        xml_code: The XML snippet (e.g. an ELAN annotation fragment) to edit.
        instructions: Plain-text description of the desired changes.

    Returns:
        The modified XML produced by the model, or a human-readable error
        message string when model loading or generation fails.
    """
    # Get model and tokenizer
    try:
        model, tokenizer = get_cached_llm()
    except Exception as e:
        return f"Error loading model: {str(e)}. Make sure you have set up the HF_TOKEN in your Space secrets and have been granted access to the model."

    # Create the system message
    system_prompt = "You are a virtual assistant that helps the user in using an annotation software called ELAN. Your task is to modify the given XML code according to the instructions given by the user."

    # Create the user message
    user_prompt = f"""XML code: {xml_code}
Instructions: {instructions}

Modify the provided code according to the instructions given above.
The output should be the modified XML code.
Don't add any additional information or explanations."""

    # Format inputs for Llama-3 chat format
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt}
    ]

    try:
        # Convert messages to model input format.
        # FIX: add_generation_prompt=True appends the assistant header that
        # instruct chat templates require before the model will produce an
        # answer rather than continuing the user turn.
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(model.device)

        # Generate response. do_sample=False makes decoding greedy and
        # deterministic, which suits XML. FIX: the original also passed
        # temperature=0.1, which is ignored (with a warning) when sampling
        # is disabled, so it is dropped here.
        with torch.no_grad():
            output = model.generate(
                inputs,
                max_new_tokens=2000,  # Allow for longer XML outputs
                do_sample=False,
            )

        # FIX: decode only the tokens generated after the prompt — robust
        # against the answer containing the word "assistant", which broke
        # the old split-based extraction.
        assistant_response = tokenizer.decode(
            output[0][inputs.shape[-1]:],
            skip_special_tokens=True,
        ).strip()

        return assistant_response

    except Exception as e:
        return f"Error modifying XML: {str(e)}"
# ---- Gradio UI -----------------------------------------------------------
# Two tabs: free-form Q&A (generate_response) and XML editing (modify_xml).
with gr.Blocks(title="ELAN Assistant") as demo:
    gr.Markdown("# ELAN Assistant")
    gr.Markdown("This tool helps you with ELAN annotation software. You can ask questions about ELAN or modify XML code.")

    # Tab 1: questions about using ELAN.
    with gr.Tab("Ask about ELAN"):
        gr.Markdown("Ask any question about how to use ELAN annotation software.")
        with gr.Row():
            ask_box = gr.Textbox(label="Your question about ELAN", placeholder="How can I export files in ELAN?", lines=3)
            answer_box = gr.Textbox(label="Answer", lines=10)
        ask_btn = gr.Button("Get Answer")
        ask_btn.click(fn=generate_response, inputs=ask_box, outputs=answer_box)

    # Tab 2: instruction-driven XML modification.
    with gr.Tab("Modify XML"):
        gr.Markdown("Paste your XML code and provide instructions for modifications.")
        with gr.Row():
            xml_in = gr.Textbox(label="Your XML code", placeholder="<annotation>...</annotation>", lines=10)
        with gr.Row():
            steps_in = gr.Textbox(label="Modification instructions", placeholder="Change the tier name from 'T1' to 'Speech'", lines=3)
        with gr.Row():
            xml_out = gr.Textbox(label="Modified XML", lines=10)
        modify_btn = gr.Button("Modify XML")
        modify_btn.click(fn=modify_xml, inputs=[xml_in, steps_in], outputs=xml_out)

    gr.Markdown("### About")
    gr.Markdown("""This application uses Meta's Llama-3.2-3B-Instruct model and vector search to provide accurate information about ELAN annotation software.

**Note:** This application requires access to the Meta-Llama model. Make sure your Hugging Face account has been granted access to the model and you've added your HF_TOKEN to the Space secrets.""")

# Launch the app only when executed as a script.
if __name__ == "__main__":
    demo.queue().launch()