Rom89823974978 committed on
Commit
dff1399
·
1 Parent(s): 2b24b35
backend/main.py CHANGED
@@ -24,7 +24,7 @@ from langchain.chains import ConversationalRetrievalChain
24
  from langchain.prompts import PromptTemplate
25
  from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings
26
 
27
- from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM, AutoModelForSeq2SeqLM
28
  from sentence_transformers import CrossEncoder
29
 
30
  from whoosh import index
@@ -55,12 +55,13 @@ class Settings(BaseSettings):
55
  vectorstore_path: str = "gs://mda_eu_project/vectorstore_index"
56
  # Models
57
  embedding_model: str = "sentence-transformers/LaBSE"
58
- llm_model: str = "meta-llama/Llama-3.2-1B-Instruct"#"meta-llama/Llama-3.2-3B-Instruct"#"google/flan-t5-base"#"google/mt5-base"#"bigscience/bloomz-560m"#"bigscience/bloom-1b7"#"google/mt5-small"#"bigscience/bloom-3b"#"RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w4a16"
 
59
  cross_encoder_model: str = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
60
  # RAG parameters
61
  chunk_size: int = 750
62
  chunk_overlap: int = 100
63
- hybrid_k: int = 4
64
  assistant_role: str = (
65
  "You are a knowledgeable project analyst. You have access to the following retrieved document snippets (with Project IDs in [brackets])"
66
  )
@@ -644,24 +645,39 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
644
  # Seq2seq pipeline
645
  logger.info("Initializing Pipeline")
646
  #full_model=AutoModelForSeq2SeqLM.from_pretrained(settings.llm_model)
647
- full_model = AutoModelForCausalLM.from_pretrained(settings.llm_model)#, device_map="auto")
648
 
649
  # Apply dynamic quantization to all Linear layers
650
- llm_model = torch.quantization.quantize_dynamic(
651
- full_model,
652
- {torch.nn.Linear},
653
- dtype=torch.qint8
654
- )
655
  # Create your text-generation pipeline on CPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
656
  gen_pipe = pipeline(
657
- "text-generation",#"text2text-generation",##"text2text-generation",
658
- model=llm_model,
659
- tokenizer=AutoTokenizer.from_pretrained(settings.llm_model),
660
- device=-1, # CPU
661
  max_new_tokens=256,
662
  do_sample=True,
663
  temperature=0.7,
664
- #device_map="auto"
665
  )
666
  # Wrap in LangChain's HuggingFacePipeline
667
  llm = HuggingFacePipeline(pipeline=gen_pipe)
@@ -688,19 +704,24 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
688
  logger.info("Initializing Hybrid Retriever")
689
  retriever = HybridRetriever(vs=vs, ix=ix, compressor=compressor, cross_encoder=cross_encoder)
690
 
691
- prompt = PromptTemplate.from_template(
692
- f"{settings.assistant_role}\n\n"
693
- "{context}\n"
694
- "Now answer the user's question thoroughly:"
695
- "Question: {question}\n"
696
- "Your answer should: \n"
697
- "1. Be at least **4-6 sentences** long \n"
698
- "2. Explain concepts clearly in full sentences \n"
699
- "3. Cite any document you draw on by including its ID in [brackets] inline \n"
700
- "4. Provide any high-level conclusions or recommendations at the end \n"
701
-
702
- "Begin your answer below:"
703
- )
 
 
 
 
 
704
 
705
  logger.info("Initializing Retrieval Chain")
706
  app.state.rag_chain = ConversationalRetrievalChain.from_llm(
 
24
  from langchain.prompts import PromptTemplate
25
  from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings
26
 
27
+ from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM, AutoModelForSeq2SeqLM, T5Tokenizer,T5ForConditionalGeneration
28
  from sentence_transformers import CrossEncoder
29
 
30
  from whoosh import index
 
55
  vectorstore_path: str = "gs://mda_eu_project/vectorstore_index"
56
  # Models
57
  embedding_model: str = "sentence-transformers/LaBSE"
58
+ llm_model: str = "google/flan-t5-base"
59
+ #"google/mt5-base"#"meta-llama/Llama-3.2-1B-Instruct"#"meta-llama/Llama-3.2-3B-Instruct"#"google/flan-t5-base"#"google/mt5-base"#"bigscience/bloomz-560m"#"bigscience/bloom-1b7"#"google/mt5-small"#"bigscience/bloom-3b"#"RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w4a16"
60
  cross_encoder_model: str = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
61
  # RAG parameters
62
  chunk_size: int = 750
63
  chunk_overlap: int = 100
64
+ hybrid_k: int = 2
65
  assistant_role: str = (
66
  "You are a knowledgeable project analyst. You have access to the following retrieved document snippets (with Project IDs in [brackets])"
67
  )
 
645
  # Seq2seq pipeline
646
  logger.info("Initializing Pipeline")
647
  #full_model=AutoModelForSeq2SeqLM.from_pretrained(settings.llm_model)
648
+ #full_model = AutoModelForCausalLM.from_pretrained(settings.llm_model)#, device_map="auto")
649
 
650
  # Apply dynamic quantization to all Linear layers
651
+ #llm_model = torch.quantization.quantize_dynamic(
652
+ # full_model,
653
+ # {torch.nn.Linear},
654
+ # dtype=torch.qint8
655
+ #)
656
  # Create your text-generation pipeline on CPU
657
+ #gen_pipe = pipeline(
658
+ # "text-generation",#"text2text-generation",##"text2text-generation",
659
+ # model=llm_model,
660
+ # tokenizer=AutoTokenizer.from_pretrained(settings.llm_model),
661
+ # device=-1, # CPU
662
+ # max_new_tokens=256,
663
+ # do_sample=True,
664
+ # temperature=0.7,
665
+ # #device_map="auto"
666
+ #)
667
+ tokenizer = T5Tokenizer.from_pretrained(settings.llm_model)
668
+ model = T5ForConditionalGeneration.from_pretrained(settings.llm_model)
669
+ model = torch.quantization.quantize_dynamic(
670
+ model, {torch.nn.Linear}, dtype=torch.qint8
671
+ )
672
+
673
  gen_pipe = pipeline(
674
+ "text2text-generation",
675
+ model=model,
676
+ tokenizer=tokenizer,
677
+ device=-1,
678
  max_new_tokens=256,
679
  do_sample=True,
680
  temperature=0.7,
 
681
  )
682
  # Wrap in LangChain's HuggingFacePipeline
683
  llm = HuggingFacePipeline(pipeline=gen_pipe)
 
704
  logger.info("Initializing Hybrid Retriever")
705
  retriever = HybridRetriever(vs=vs, ix=ix, compressor=compressor, cross_encoder=cross_encoder)
706
 
707
+ prompt = PromptTemplate.from_template("""
708
+ {assistant_role}
709
+
710
+ You have the following retrieved document snippets (with Project IDs in [brackets]):
711
+
712
+ {context}
713
+
714
+ User Question:
715
+ {question}
716
+
717
+ Please answer thoroughly, following these rules:
718
+ 1. Write at least 4-6 full sentences.
719
+ 2. Use clear, technical language in full sentences.
720
+ 3. Cite any document you reference by including its ID in [brackets] inline.
721
+ 4. Conclude with high-level insights or recommendations.
722
+
723
+ Answer:
724
+ """.strip())
725
 
726
  logger.info("Initializing Retrieval Chain")
727
  app.state.rag_chain = ConversationalRetrievalChain.from_llm(
frontend/src/components/ProjectDetails.tsx CHANGED
@@ -125,8 +125,8 @@ export default function ProjectDetails({
125
  <Text fontWeight="bold">Acronym</Text>
126
  <Text>{project.acronym}</Text>
127
  </Box>
128
- <Box><Text fontWeight="bold">Start Date</Text><Text>{new Date(project.startDate).toISOString().slice(0,10)}</Text></Box>
129
- <Box><Text fontWeight="bold">End Date</Text><Text>{new Date(project.endDate).toISOString().slice(0,10)}</Text></Box>
130
  <Box><Text fontWeight="bold">Funding (EC max)</Text><Text>€{fmtNum(project.ecMaxContribution)}</Text></Box>
131
  <Box><Text fontWeight="bold">Total Cost</Text><Text>€{fmtNum(project.totalCost)}</Text></Box>
132
  <Box><Text fontWeight="bold">Funding Scheme</Text><Text>{project.fundingScheme}</Text></Box>
 
125
  <Text fontWeight="bold">Acronym</Text>
126
  <Text>{project.acronym}</Text>
127
  </Box>
128
+ <Box><Text fontWeight="bold">Start Date</Text><Text>{project.startDate.slice(0, 10)}</Text></Box>
129
+ <Box><Text fontWeight="bold">End Date</Text><Text>{project.endDate.slice(0, 10)}</Text></Box>
130
  <Box><Text fontWeight="bold">Funding (EC max)</Text><Text>€{fmtNum(project.ecMaxContribution)}</Text></Box>
131
  <Box><Text fontWeight="bold">Total Cost</Text><Text>€{fmtNum(project.totalCost)}</Text></Box>
132
  <Box><Text fontWeight="bold">Funding Scheme</Text><Text>{project.fundingScheme}</Text></Box>
frontend/src/components/ProjectExplorer.tsx CHANGED
@@ -233,7 +233,7 @@ const ProjectExplorer: React.FC<ProjectExplorerProps> = ({
233
  <Td w="50%" overflow="hidden" textOverflow="ellipsis">{p.title}</Td>
234
  <Td w="10%">{p.status}</Td>
235
  <Td w="10%">{p.id}</Td>
236
- <Td w="10%" whiteSpace="nowrap">{new Date(p.startDate).toISOString().slice(0,10)}</Td>
237
  <Td w="10%">{p.fundingScheme || '-'}</Td>
238
  <Td w="10%">€{fmtNum(p.ecMaxContribution)}</Td>
239
 
@@ -318,7 +318,7 @@ const ProjectExplorer: React.FC<ProjectExplorerProps> = ({
318
  loadingText="Waiting…"
319
  size="md"
320
  px={6}
321
- py={4}
322
  >
323
  Send
324
  </Button>
 
233
  <Td w="50%" overflow="hidden" textOverflow="ellipsis">{p.title}</Td>
234
  <Td w="10%">{p.status}</Td>
235
  <Td w="10%">{p.id}</Td>
236
+ <Td w="10%" whiteSpace="nowrap">{p.startDate.slice(0, 10)}</Td>
237
  <Td w="10%">{p.fundingScheme || '-'}</Td>
238
  <Td w="10%">€{fmtNum(p.ecMaxContribution)}</Td>
239
 
 
318
  loadingText="Waiting…"
319
  size="md"
320
  px={6}
321
+ py={3}
322
  >
323
  Send
324
  </Button>