Spaces:
Build error
Build error
Load the embedding and generation models directly from the Hugging Face Hub instead of a remote API
Browse files
app.py
CHANGED
|
@@ -6,7 +6,6 @@ import os
|
|
| 6 |
import pickle
|
| 7 |
from typing import List, Tuple, Optional
|
| 8 |
import gradio as gr
|
| 9 |
-
from openai import OpenAI
|
| 10 |
from functools import lru_cache
|
| 11 |
from tenacity import retry, stop_after_attempt, wait_exponential
|
| 12 |
from langchain_community.retrievers import BM25Retriever
|
|
@@ -15,25 +14,31 @@ from langchain_core.embeddings import Embeddings
|
|
| 15 |
from langchain_core.documents import Document
|
| 16 |
from collections import defaultdict
|
| 17 |
import hashlib
|
| 18 |
-
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
# --- Configuration ---
|
| 21 |
FAISS_INDEX_PATH = "faiss_index"
|
| 22 |
BM25_INDEX_PATH = "bm25_index.pkl"
|
| 23 |
CACHE_VERSION = "v1" # Increment when data format changes
|
| 24 |
-
|
| 25 |
-
|
| 26 |
data_file_name = "AskNatureNet_data_enhanced.json"
|
| 27 |
-
API_CONFIG = {
|
| 28 |
-
"api_key": os.getenv("API_KEY"),
|
| 29 |
-
"base_url": "https://chat-ai.academiccloud.de/v1"
|
| 30 |
-
}
|
| 31 |
CHUNK_SIZE = 800
|
| 32 |
OVERLAP = 200
|
| 33 |
EMBEDDING_BATCH_SIZE = 32 # Batch size for embedding API calls
|
| 34 |
|
| 35 |
-
#
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
logging.basicConfig(level=logging.INFO)
|
| 38 |
logger = logging.getLogger(__name__)
|
| 39 |
|
|
@@ -52,12 +57,7 @@ class MistralEmbeddings(Embeddings):
|
|
| 52 |
# Process in batches with progress tracking
|
| 53 |
for i in tqdm(range(0, len(texts), EMBEDDING_BATCH_SIZE), desc="Embedding Progress"):
|
| 54 |
batch = texts[i:i + EMBEDDING_BATCH_SIZE]
|
| 55 |
-
|
| 56 |
-
input=batch,
|
| 57 |
-
model=embedding_model,
|
| 58 |
-
encoding_format="float"
|
| 59 |
-
)
|
| 60 |
-
embeddings.extend([e.embedding for e in response.data])
|
| 61 |
return embeddings
|
| 62 |
except Exception as e:
|
| 63 |
logger.error(f"Embedding Error: {str(e)}")
|
|
@@ -167,16 +167,12 @@ class EnhancedRetriever:
|
|
| 167 |
@lru_cache(maxsize=500)
|
| 168 |
def _hyde_expansion(self, query: str) -> str:
|
| 169 |
try:
|
| 170 |
-
response =
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
"content": f"Generate a technical draft about biomimicry for: {query}\nInclude domain-specific terms."
|
| 175 |
-
}],
|
| 176 |
-
temperature=0.5,
|
| 177 |
-
max_tokens=200
|
| 178 |
)
|
| 179 |
-
return response
|
| 180 |
except Exception as e:
|
| 181 |
logger.error(f"HyDE Error: {str(e)}")
|
| 182 |
return query
|
|
@@ -212,23 +208,18 @@ SYSTEM_PROMPT = """**Biomimicry Expert Guidelines**
|
|
| 212 |
2. Cite sources as [Source]
|
| 213 |
3. **Bold** technical terms
|
| 214 |
4. Include reference links
|
| 215 |
-
|
| 216 |
Context: {context}"""
|
| 217 |
|
| 218 |
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=20))
|
| 219 |
def get_ai_response(query: str, context: str) -> str:
|
| 220 |
try:
|
| 221 |
-
response =
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
{"role": "user", "content": f"Question: {query}\nProvide a detailed technical answer:"}
|
| 226 |
-
],
|
| 227 |
-
temperature=0.4,
|
| 228 |
-
max_tokens=2000 # Increased max_tokens
|
| 229 |
)
|
| 230 |
-
logger.info(f"Raw Response: {response
|
| 231 |
-
return _postprocess_response(response
|
| 232 |
except Exception as e:
|
| 233 |
logger.error(f"Generation Error: {str(e)}")
|
| 234 |
return "I'm unable to generate a response right now. Please try again later."
|
|
|
|
| 6 |
import pickle
|
| 7 |
from typing import List, Tuple, Optional
|
| 8 |
import gradio as gr
|
|
|
|
| 9 |
from functools import lru_cache
|
| 10 |
from tenacity import retry, stop_after_attempt, wait_exponential
|
| 11 |
from langchain_community.retrievers import BM25Retriever
|
|
|
|
| 14 |
from langchain_core.documents import Document
|
| 15 |
from collections import defaultdict
|
| 16 |
import hashlib
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
| 19 |
+
from sentence_transformers import SentenceTransformer
|
| 20 |
+
from huggingface_hub import login
|
| 21 |
|
| 22 |
# --- Configuration ---
FAISS_INDEX_PATH = "faiss_index"
BM25_INDEX_PATH = "bm25_index.pkl"
CACHE_VERSION = "v1"  # Increment when data format changes
embedding_model_name = "intfloat/e5-mistral-7b-instruct"
generation_model_name = "meta-llama/Meta-Llama-3-70B-Instruct"
data_file_name = "AskNatureNet_data_enhanced.json"
CHUNK_SIZE = 800
OVERLAP = 200
EMBEDDING_BATCH_SIZE = 32  # Batch size for embedding calls

# Login to Hugging Face Hub.
# FIX: the previous code passed login(token="llama3") — "llama3" is a model
# name, not an access token, so authentication (and the Space build) failed.
# The token must come from a secret/environment variable; gated models such as
# Meta-Llama-3 require a valid token with accepted license terms.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)

# Initialize models.
# NOTE(review): Meta-Llama-3-70B-Instruct needs ~140 GB of weights in fp16;
# loading it with default settings will OOM on most Spaces hardware — confirm
# the target hardware, or consider the 8B variant / device_map="auto".
embedding_model = SentenceTransformer(embedding_model_name)
tokenizer = AutoTokenizer.from_pretrained(generation_model_name)
generation_model = AutoModelForCausalLM.from_pretrained(generation_model_name)
generation_pipeline = pipeline("text-generation", model=generation_model, tokenizer=tokenizer)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
| 44 |
|
|
|
|
| 57 |
# Process in batches with progress tracking
|
| 58 |
for i in tqdm(range(0, len(texts), EMBEDDING_BATCH_SIZE), desc="Embedding Progress"):
|
| 59 |
batch = texts[i:i + EMBEDDING_BATCH_SIZE]
|
| 60 |
+
embeddings.extend(embedding_model.encode(batch))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
return embeddings
|
| 62 |
except Exception as e:
|
| 63 |
logger.error(f"Embedding Error: {str(e)}")
|
|
|
|
| 167 |
@lru_cache(maxsize=500)
|
| 168 |
def _hyde_expansion(self, query: str) -> str:
|
| 169 |
try:
|
| 170 |
+
response = generation_pipeline(
|
| 171 |
+
f"Generate a technical draft about biomimicry for: {query}\nInclude domain-specific terms.",
|
| 172 |
+
max_length=200,
|
| 173 |
+
temperature=0.5
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
)
|
| 175 |
+
return response[0]['generated_text']
|
| 176 |
except Exception as e:
|
| 177 |
logger.error(f"HyDE Error: {str(e)}")
|
| 178 |
return query
|
|
|
|
| 208 |
2. Cite sources as [Source]
|
| 209 |
3. **Bold** technical terms
|
| 210 |
4. Include reference links
|
|
|
|
| 211 |
Context: {context}"""
|
| 212 |
|
| 213 |
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=20))
def get_ai_response(query: str, context: str) -> str:
    """Generate an answer to *query* grounded in *context*.

    Builds a prompt from SYSTEM_PROMPT plus the question, runs the local
    text-generation pipeline, and post-processes the result. Retries up to
    3 times with exponential backoff; returns a friendly fallback message
    if generation ultimately fails.
    """
    try:
        response = generation_pipeline(
            f"{SYSTEM_PROMPT.format(context=context)}\nQuestion: {query}\nProvide a detailed technical answer:",
            max_new_tokens=2000,     # FIX: max_length includes the (long) prompt; bound only the new tokens
            do_sample=True,          # FIX: temperature has no effect without sampling enabled
            temperature=0.4,
            return_full_text=False,  # FIX: default output echoes the whole prompt back into _postprocess_response
        )
        answer = response[0]['generated_text']
        logger.info(f"Raw Response: {answer}")  # Log raw response
        return _postprocess_response(answer)
    except Exception as e:
        logger.error(f"Generation Error: {str(e)}")
        return "I'm unable to generate a response right now. Please try again later."
|