Mohamed284 committed on
Commit
cf569da
·
verified ·
1 Parent(s): ba2ae67

trying to load the models from HF

Browse files
Files changed (1) hide show
  1. app.py +27 -36
app.py CHANGED
@@ -6,7 +6,6 @@ import os
6
  import pickle
7
  from typing import List, Tuple, Optional
8
  import gradio as gr
9
- from openai import OpenAI
10
  from functools import lru_cache
11
  from tenacity import retry, stop_after_attempt, wait_exponential
12
  from langchain_community.retrievers import BM25Retriever
@@ -15,25 +14,31 @@ from langchain_core.embeddings import Embeddings
15
  from langchain_core.documents import Document
16
  from collections import defaultdict
17
  import hashlib
18
- from tqdm import tqdm
 
 
 
19
 
20
  # --- Configuration ---
21
  FAISS_INDEX_PATH = "faiss_index"
22
  BM25_INDEX_PATH = "bm25_index.pkl"
23
  CACHE_VERSION = "v1" # Increment when data format changes
24
- embedding_model = "e5-mistral-7b-instruct"
25
- generation_model = "meta-llama-3-70b-instruct"
26
  data_file_name = "AskNatureNet_data_enhanced.json"
27
- API_CONFIG = {
28
- "api_key": os.getenv("API_KEY"),
29
- "base_url": "https://chat-ai.academiccloud.de/v1"
30
- }
31
  CHUNK_SIZE = 800
32
  OVERLAP = 200
33
  EMBEDDING_BATCH_SIZE = 32 # Batch size for embedding API calls
34
 
35
- # Initialize clients
36
- client = OpenAI(**API_CONFIG)
 
 
 
 
 
 
 
37
  logging.basicConfig(level=logging.INFO)
38
  logger = logging.getLogger(__name__)
39
 
@@ -52,12 +57,7 @@ class MistralEmbeddings(Embeddings):
52
  # Process in batches with progress tracking
53
  for i in tqdm(range(0, len(texts), EMBEDDING_BATCH_SIZE), desc="Embedding Progress"):
54
  batch = texts[i:i + EMBEDDING_BATCH_SIZE]
55
- response = client.embeddings.create(
56
- input=batch,
57
- model=embedding_model,
58
- encoding_format="float"
59
- )
60
- embeddings.extend([e.embedding for e in response.data])
61
  return embeddings
62
  except Exception as e:
63
  logger.error(f"Embedding Error: {str(e)}")
@@ -167,16 +167,12 @@ class EnhancedRetriever:
167
  @lru_cache(maxsize=500)
168
  def _hyde_expansion(self, query: str) -> str:
169
  try:
170
- response = client.chat.completions.create(
171
- model=generation_model,
172
- messages=[{
173
- "role": "user",
174
- "content": f"Generate a technical draft about biomimicry for: {query}\nInclude domain-specific terms."
175
- }],
176
- temperature=0.5,
177
- max_tokens=200
178
  )
179
- return response.choices[0].message.content
180
  except Exception as e:
181
  logger.error(f"HyDE Error: {str(e)}")
182
  return query
@@ -212,23 +208,18 @@ SYSTEM_PROMPT = """**Biomimicry Expert Guidelines**
212
  2. Cite sources as [Source]
213
  3. **Bold** technical terms
214
  4. Include reference links
215
-
216
  Context: {context}"""
217
 
218
  @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=20))
219
  def get_ai_response(query: str, context: str) -> str:
220
  try:
221
- response = client.chat.completions.create(
222
- model=generation_model,
223
- messages=[
224
- {"role": "system", "content": SYSTEM_PROMPT.format(context=context)},
225
- {"role": "user", "content": f"Question: {query}\nProvide a detailed technical answer:"}
226
- ],
227
- temperature=0.4,
228
- max_tokens=2000 # Increased max_tokens
229
  )
230
- logger.info(f"Raw Response: {response.choices[0].message.content}") # Log raw response
231
- return _postprocess_response(response.choices[0].message.content)
232
  except Exception as e:
233
  logger.error(f"Generation Error: {str(e)}")
234
  return "I'm unable to generate a response right now. Please try again later."
 
6
  import pickle
7
  from typing import List, Tuple, Optional
8
  import gradio as gr
 
9
  from functools import lru_cache
10
  from tenacity import retry, stop_after_attempt, wait_exponential
11
  from langchain_community.retrievers import BM25Retriever
 
14
  from langchain_core.documents import Document
15
  from collections import defaultdict
16
  import hashlib
17
+ from tqdm import tqdm
18
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
19
+ from sentence_transformers import SentenceTransformer
20
+ from huggingface_hub import login
21
 
22
# --- Configuration ---
FAISS_INDEX_PATH = "faiss_index"
BM25_INDEX_PATH = "bm25_index.pkl"
CACHE_VERSION = "v1"  # Increment when data format changes
embedding_model_name = "intfloat/e5-mistral-7b-instruct"
generation_model_name = "meta-llama/Meta-Llama-3-70B-Instruct"
data_file_name = "AskNatureNet_data_enhanced.json"
CHUNK_SIZE = 800
OVERLAP = 200
EMBEDDING_BATCH_SIZE = 32  # Batch size for embedding calls

# Login to Hugging Face Hub.
# SECURITY: never hard-code a credential — the previous literal ("llama3") was
# a placeholder, not a token, and would fail authentication anyway. Read the
# token from the environment (HF_TOKEN is the hub's conventional variable);
# with token=None, login() falls back to any previously cached credential.
login(token=os.getenv("HF_TOKEN"))

# Initialize models once at import time.
# NOTE(review): Meta-Llama-3-70B-Instruct is ~140 GB of fp16 weights; loading
# it without device_map/quantization will exhaust memory on typical Spaces
# hardware — confirm the deployment target before shipping.
embedding_model = SentenceTransformer(embedding_model_name)
tokenizer = AutoTokenizer.from_pretrained(generation_model_name)
generation_model = AutoModelForCausalLM.from_pretrained(generation_model_name)
generation_pipeline = pipeline("text-generation", model=generation_model, tokenizer=tokenizer)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
44
 
 
57
  # Process in batches with progress tracking
58
  for i in tqdm(range(0, len(texts), EMBEDDING_BATCH_SIZE), desc="Embedding Progress"):
59
  batch = texts[i:i + EMBEDDING_BATCH_SIZE]
60
+ embeddings.extend(embedding_model.encode(batch))
 
 
 
 
 
61
  return embeddings
62
  except Exception as e:
63
  logger.error(f"Embedding Error: {str(e)}")
 
167
  @lru_cache(maxsize=500)
168
  def _hyde_expansion(self, query: str) -> str:
169
  try:
170
+ response = generation_pipeline(
171
+ f"Generate a technical draft about biomimicry for: {query}\nInclude domain-specific terms.",
172
+ max_length=200,
173
+ temperature=0.5
 
 
 
 
174
  )
175
+ return response[0]['generated_text']
176
  except Exception as e:
177
  logger.error(f"HyDE Error: {str(e)}")
178
  return query
 
208
  2. Cite sources as [Source]
209
  3. **Bold** technical terms
210
  4. Include reference links
 
211
  Context: {context}"""
212
 
213
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=20))
def get_ai_response(query: str, context: str) -> str:
    """Generate a context-grounded answer for *query* with the local pipeline.

    Retries up to 3 times with exponential backoff (via tenacity). Returns the
    post-processed answer, or a user-facing fallback message on failure.
    """
    try:
        response = generation_pipeline(
            f"{SYSTEM_PROMPT.format(context=context)}\nQuestion: {query}\nProvide a detailed technical answer:",
            # max_new_tokens bounds the completion only; the previous
            # max_length=2000 included the prompt, and the prompt embeds the
            # whole retrieval context, so it could exceed the budget by itself.
            max_new_tokens=2000,
            # temperature is ignored unless sampling is enabled.
            do_sample=True,
            temperature=0.4,
            # Return only the completion; otherwise the echoed system prompt
            # and context would be fed into _postprocess_response.
            return_full_text=False,
        )
        answer = response[0]['generated_text']
        logger.info("Raw Response: %s", answer)  # lazy %-args, log raw response
        return _postprocess_response(answer)
    except Exception as e:
        logger.error(f"Generation Error: {str(e)}")
        return "I'm unable to generate a response right now. Please try again later."