Spaces:
Runtime error
Runtime error
Commit ·
dff1399
1
Parent(s): 2b24b35
updates
Browse files
backend/main.py
CHANGED
|
@@ -24,7 +24,7 @@ from langchain.chains import ConversationalRetrievalChain
|
|
| 24 |
from langchain.prompts import PromptTemplate
|
| 25 |
from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings
|
| 26 |
|
| 27 |
-
from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM, AutoModelForSeq2SeqLM
|
| 28 |
from sentence_transformers import CrossEncoder
|
| 29 |
|
| 30 |
from whoosh import index
|
|
@@ -55,12 +55,13 @@ class Settings(BaseSettings):
|
|
| 55 |
vectorstore_path: str = "gs://mda_eu_project/vectorstore_index"
|
| 56 |
# Models
|
| 57 |
embedding_model: str = "sentence-transformers/LaBSE"
|
| 58 |
-
llm_model: str = "
|
|
|
|
| 59 |
cross_encoder_model: str = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
|
| 60 |
# RAG parameters
|
| 61 |
chunk_size: int = 750
|
| 62 |
chunk_overlap: int = 100
|
| 63 |
-
hybrid_k: int =
|
| 64 |
assistant_role: str = (
|
| 65 |
"You are a knowledgeable project analyst. You have access to the following retrieved document snippets (with Project IDs in [brackets])"
|
| 66 |
)
|
|
@@ -644,24 +645,39 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|
| 644 |
# Seq2seq pipeline
|
| 645 |
logger.info("Initializing Pipeline")
|
| 646 |
#full_model=AutoModelForSeq2SeqLM.from_pretrained(settings.llm_model)
|
| 647 |
-
full_model = AutoModelForCausalLM.from_pretrained(settings.llm_model)#, device_map="auto")
|
| 648 |
|
| 649 |
# Apply dynamic quantization to all Linear layers
|
| 650 |
-
llm_model = torch.quantization.quantize_dynamic(
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
)
|
| 655 |
# Create your text-generation pipeline on CPU
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 656 |
gen_pipe = pipeline(
|
| 657 |
-
"
|
| 658 |
-
model=
|
| 659 |
-
tokenizer=
|
| 660 |
-
device=-1,
|
| 661 |
max_new_tokens=256,
|
| 662 |
do_sample=True,
|
| 663 |
temperature=0.7,
|
| 664 |
-
#device_map="auto"
|
| 665 |
)
|
| 666 |
# Wrap in LangChain's HuggingFacePipeline
|
| 667 |
llm = HuggingFacePipeline(pipeline=gen_pipe)
|
|
@@ -688,19 +704,24 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|
| 688 |
logger.info("Initializing Hybrid Retriever")
|
| 689 |
retriever = HybridRetriever(vs=vs, ix=ix, compressor=compressor, cross_encoder=cross_encoder)
|
| 690 |
|
| 691 |
-
prompt = PromptTemplate.from_template(
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 704 |
|
| 705 |
logger.info("Initializing Retrieval Chain")
|
| 706 |
app.state.rag_chain = ConversationalRetrievalChain.from_llm(
|
|
|
|
| 24 |
from langchain.prompts import PromptTemplate
|
| 25 |
from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings
|
| 26 |
|
| 27 |
+
from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM, AutoModelForSeq2SeqLM, T5Tokenizer,T5ForConditionalGeneration
|
| 28 |
from sentence_transformers import CrossEncoder
|
| 29 |
|
| 30 |
from whoosh import index
|
|
|
|
| 55 |
vectorstore_path: str = "gs://mda_eu_project/vectorstore_index"
|
| 56 |
# Models
|
| 57 |
embedding_model: str = "sentence-transformers/LaBSE"
|
| 58 |
+
llm_model: str = "google/flan-t5-base"
|
| 59 |
+
#"google/mt5-base"#"meta-llama/Llama-3.2-1B-Instruct"#"meta-llama/Llama-3.2-3B-Instruct"#"google/flan-t5-base"#"google/mt5-base"#"bigscience/bloomz-560m"#"bigscience/bloom-1b7"#"google/mt5-small"#"bigscience/bloom-3b"#"RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w4a16"
|
| 60 |
cross_encoder_model: str = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
|
| 61 |
# RAG parameters
|
| 62 |
chunk_size: int = 750
|
| 63 |
chunk_overlap: int = 100
|
| 64 |
+
hybrid_k: int = 2
|
| 65 |
assistant_role: str = (
|
| 66 |
"You are a knowledgeable project analyst. You have access to the following retrieved document snippets (with Project IDs in [brackets])"
|
| 67 |
)
|
|
|
|
| 645 |
# Seq2seq pipeline
|
| 646 |
logger.info("Initializing Pipeline")
|
| 647 |
#full_model=AutoModelForSeq2SeqLM.from_pretrained(settings.llm_model)
|
| 648 |
+
#full_model = AutoModelForCausalLM.from_pretrained(settings.llm_model)#, device_map="auto")
|
| 649 |
|
| 650 |
# Apply dynamic quantization to all Linear layers
|
| 651 |
+
#llm_model = torch.quantization.quantize_dynamic(
|
| 652 |
+
# full_model,
|
| 653 |
+
# {torch.nn.Linear},
|
| 654 |
+
# dtype=torch.qint8
|
| 655 |
+
#)
|
| 656 |
# Create your text-generation pipeline on CPU
|
| 657 |
+
#gen_pipe = pipeline(
|
| 658 |
+
# "text-generation",#"text2text-generation",##"text2text-generation",
|
| 659 |
+
# model=llm_model,
|
| 660 |
+
# tokenizer=AutoTokenizer.from_pretrained(settings.llm_model),
|
| 661 |
+
# device=-1, # CPU
|
| 662 |
+
# max_new_tokens=256,
|
| 663 |
+
# do_sample=True,
|
| 664 |
+
# temperature=0.7,
|
| 665 |
+
# #device_map="auto"
|
| 666 |
+
#)
|
| 667 |
+
tokenizer = T5Tokenizer.from_pretrained(settings.llm_model)
|
| 668 |
+
model = T5ForConditionalGeneration.from_pretrained(settings.llm_model)
|
| 669 |
+
model = torch.quantization.quantize_dynamic(
|
| 670 |
+
model, {torch.nn.Linear}, dtype=torch.qint8
|
| 671 |
+
)
|
| 672 |
+
|
| 673 |
gen_pipe = pipeline(
|
| 674 |
+
"text2text-generation",
|
| 675 |
+
model=model,
|
| 676 |
+
tokenizer=tokenizer,
|
| 677 |
+
device=-1,
|
| 678 |
max_new_tokens=256,
|
| 679 |
do_sample=True,
|
| 680 |
temperature=0.7,
|
|
|
|
| 681 |
)
|
| 682 |
# Wrap in LangChain's HuggingFacePipeline
|
| 683 |
llm = HuggingFacePipeline(pipeline=gen_pipe)
|
|
|
|
| 704 |
logger.info("Initializing Hybrid Retriever")
|
| 705 |
retriever = HybridRetriever(vs=vs, ix=ix, compressor=compressor, cross_encoder=cross_encoder)
|
| 706 |
|
| 707 |
+
prompt = PromptTemplate.from_template("""
|
| 708 |
+
{assistant_role}
|
| 709 |
+
|
| 710 |
+
You have the following retrieved document snippets (with Project IDs in [brackets]):
|
| 711 |
+
|
| 712 |
+
{context}
|
| 713 |
+
|
| 714 |
+
User Question:
|
| 715 |
+
{question}
|
| 716 |
+
|
| 717 |
+
Please answer thoroughly, following these rules:
|
| 718 |
+
1. Write at least 4-6 full sentences.
|
| 719 |
+
2. Use clear, technical language in full sentences.
|
| 720 |
+
3. Cite any document you reference by including its ID in [brackets] inline.
|
| 721 |
+
4. Conclude with high-level insights or recommendations.
|
| 722 |
+
|
| 723 |
+
Answer:
|
| 724 |
+
""".strip())
|
| 725 |
|
| 726 |
logger.info("Initializing Retrieval Chain")
|
| 727 |
app.state.rag_chain = ConversationalRetrievalChain.from_llm(
|
frontend/src/components/ProjectDetails.tsx
CHANGED
|
@@ -125,8 +125,8 @@ export default function ProjectDetails({
|
|
| 125 |
<Text fontWeight="bold">Acronym</Text>
|
| 126 |
<Text>{project.acronym}</Text>
|
| 127 |
</Box>
|
| 128 |
-
<Box><Text fontWeight="bold">Start Date</Text><Text>{
|
| 129 |
-
<Box><Text fontWeight="bold">End Date</Text><Text>{
|
| 130 |
<Box><Text fontWeight="bold">Funding (EC max)</Text><Text>€{fmtNum(project.ecMaxContribution)}</Text></Box>
|
| 131 |
<Box><Text fontWeight="bold">Total Cost</Text><Text>€{fmtNum(project.totalCost)}</Text></Box>
|
| 132 |
<Box><Text fontWeight="bold">Funding Scheme</Text><Text>{project.fundingScheme}</Text></Box>
|
|
|
|
| 125 |
<Text fontWeight="bold">Acronym</Text>
|
| 126 |
<Text>{project.acronym}</Text>
|
| 127 |
</Box>
|
| 128 |
+
<Box><Text fontWeight="bold">Start Date</Text><Text>{project.startDate.slice(0, 10)}</Text></Box>
|
| 129 |
+
<Box><Text fontWeight="bold">End Date</Text><Text>{project.endDate.slice(0, 10)}</Text></Box>
|
| 130 |
<Box><Text fontWeight="bold">Funding (EC max)</Text><Text>€{fmtNum(project.ecMaxContribution)}</Text></Box>
|
| 131 |
<Box><Text fontWeight="bold">Total Cost</Text><Text>€{fmtNum(project.totalCost)}</Text></Box>
|
| 132 |
<Box><Text fontWeight="bold">Funding Scheme</Text><Text>{project.fundingScheme}</Text></Box>
|
frontend/src/components/ProjectExplorer.tsx
CHANGED
|
@@ -233,7 +233,7 @@ const ProjectExplorer: React.FC<ProjectExplorerProps> = ({
|
|
| 233 |
<Td w="50%" overflow="hidden" textOverflow="ellipsis">{p.title}</Td>
|
| 234 |
<Td w="10%">{p.status}</Td>
|
| 235 |
<Td w="10%">{p.id}</Td>
|
| 236 |
-
<Td w="10%" whiteSpace="nowrap">{
|
| 237 |
<Td w="10%">{p.fundingScheme || '-'}</Td>
|
| 238 |
<Td w="10%">€{fmtNum(p.ecMaxContribution)}</Td>
|
| 239 |
|
|
@@ -318,7 +318,7 @@ const ProjectExplorer: React.FC<ProjectExplorerProps> = ({
|
|
| 318 |
loadingText="Waiting…"
|
| 319 |
size="md"
|
| 320 |
px={6}
|
| 321 |
-
py={
|
| 322 |
>
|
| 323 |
Send
|
| 324 |
</Button>
|
|
|
|
| 233 |
<Td w="50%" overflow="hidden" textOverflow="ellipsis">{p.title}</Td>
|
| 234 |
<Td w="10%">{p.status}</Td>
|
| 235 |
<Td w="10%">{p.id}</Td>
|
| 236 |
+
<Td w="10%" whiteSpace="nowrap">{p.startDate.slice(0, 10)}</Td>
|
| 237 |
<Td w="10%">{p.fundingScheme || '-'}</Td>
|
| 238 |
<Td w="10%">€{fmtNum(p.ecMaxContribution)}</Td>
|
| 239 |
|
|
|
|
| 318 |
loadingText="Waiting…"
|
| 319 |
size="md"
|
| 320 |
px={6}
|
| 321 |
+
py={3}
|
| 322 |
>
|
| 323 |
Send
|
| 324 |
</Button>
|