Spaces:
Runtime error
Runtime error
Update fast_app.py
Browse files- fast_app.py +132 -131
fast_app.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
-
|
| 2 |
import os
|
| 3 |
import json
|
|
|
|
| 4 |
from fastapi import FastAPI, Request, Form, Response
|
| 5 |
from fastapi.responses import HTMLResponse
|
| 6 |
from fastapi.templating import Jinja2Templates
|
|
@@ -8,163 +9,163 @@ from fastapi.staticfiles import StaticFiles
|
|
| 8 |
from fastapi.encoders import jsonable_encoder
|
| 9 |
|
| 10 |
from langchain_community.vectorstores import FAISS
|
|
|
|
|
|
|
| 11 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 12 |
-
|
| 13 |
from langchain.chains import RetrievalQA
|
| 14 |
-
|
| 15 |
-
from langchain.llms import OpenAI
|
| 16 |
from langchain import PromptTemplate
|
| 17 |
-
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings
|
| 18 |
-
|
| 19 |
-
from ingest import Ingest
|
| 20 |
-
|
| 21 |
-
# setx OPENAI_API_KEY "your_openai_api_key_here"
|
| 22 |
|
| 23 |
-
#
|
| 24 |
-
#
|
| 25 |
-
#
|
| 26 |
-
#
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
|
|
|
|
|
|
| 32 |
|
|
|
|
|
|
|
|
|
|
| 33 |
app = FastAPI()
|
| 34 |
templates = Jinja2Templates(directory="templates")
|
| 35 |
app.mount("/static", StaticFiles(directory="static"), name="static")
|
| 36 |
-
english_embedding_model = "text-embedding-3-large"
|
| 37 |
-
czech_embedding_model = "Seznam/retromae-small-cs"
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
|
|
|
| 41 |
|
| 42 |
ingestor = Ingest(
|
| 43 |
-
openai_api_key=
|
| 44 |
-
chunk=512,
|
| 45 |
-
overlap=256,
|
| 46 |
-
czech_store=
|
| 47 |
-
english_store=
|
| 48 |
-
czech_embedding_model=
|
| 49 |
-
english_embedding_model=
|
| 50 |
)
|
| 51 |
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
#Pokud odpověď neznáte, prostě řekněte, že to nevíte, nepokoušejte se vymýšlet odpověď.
|
| 81 |
-
|
| 82 |
-
###Kontext: {context}
|
| 83 |
-
###Otázka: {question}
|
| 84 |
-
|
| 85 |
-
Níže vraťte pouze užitečnou odpověď a nic jiného.
|
| 86 |
-
Užitečná odpověď:
|
| 87 |
-
"""
|
| 88 |
-
prompt_cz = PromptTemplate(
|
| 89 |
-
template=prompt_template_cz, input_variables=["context", "question"]
|
| 90 |
-
)
|
| 91 |
-
print("\n Prompt ready... \n\n")
|
| 92 |
-
return prompt_cz
|
| 93 |
-
|
| 94 |
-
|
| 95 |
@app.get("/", response_class=HTMLResponse)
|
| 96 |
-
def
|
| 97 |
return templates.TemplateResponse("index.html", {"request": request})
|
| 98 |
|
| 99 |
-
|
| 100 |
@app.post("/ingest_data")
|
| 101 |
async def ingest_data(folderPath: str = Form(...), language: str = Form(...)):
|
| 102 |
-
|
| 103 |
-
if language == "czech":
|
| 104 |
-
print("\n Czech language selected....\n\n")
|
| 105 |
ingestor.data_czech = folderPath
|
| 106 |
ingestor.ingest_czech()
|
| 107 |
-
message
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
ingestor.ingest_english()
|
| 112 |
-
message = "English data ingestion complete."
|
| 113 |
-
|
| 114 |
-
return {"message": message}
|
| 115 |
-
|
| 116 |
|
| 117 |
@app.post("/get_response")
|
| 118 |
async def get_response(query: str = Form(...), language: str = Form(...)):
|
| 119 |
-
print(language)
|
| 120 |
-
if language == "czech":
|
| 121 |
-
prompt = prompt_cz()
|
| 122 |
-
print("\n Czech language selected....\n\n")
|
| 123 |
-
embedding_model = czech_embedding_model
|
| 124 |
-
persist_directory = czech_store
|
| 125 |
-
model_name = embedding_model
|
| 126 |
-
model_kwargs = {"device": "cpu"}
|
| 127 |
-
encode_kwargs = {"normalize_embeddings": False}
|
| 128 |
-
embedding = HuggingFaceEmbeddings(
|
| 129 |
-
model_name=model_name,
|
| 130 |
-
model_kwargs=model_kwargs,
|
| 131 |
-
encode_kwargs=encode_kwargs,
|
| 132 |
-
)
|
| 133 |
-
else:
|
| 134 |
-
prompt = prompt_en()
|
| 135 |
-
print("\n English language selected....\n\n")
|
| 136 |
-
embedding_model = english_embedding_model # Default to English
|
| 137 |
-
persist_directory = english_store
|
| 138 |
-
embedding = OpenAIEmbeddings(
|
| 139 |
-
openai_api_key=openai_api_key,
|
| 140 |
-
model=embedding_model,
|
| 141 |
-
)
|
| 142 |
-
|
| 143 |
-
vectordb = FAISS.load_local(persist_directory, embedding)
|
| 144 |
-
retriever = vectordb.as_retriever(search_kwargs={"k": 2})
|
| 145 |
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
retriever=retriever,
|
| 151 |
-
return_source_documents=True,
|
| 152 |
-
chain_type_kwargs=chain_type_kwargs,
|
| 153 |
-
verbose=True,
|
| 154 |
-
)
|
| 155 |
-
response = qa_chain(query)
|
| 156 |
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
|
|
|
|
|
|
|
|
|
| 167 |
)
|
| 168 |
|
| 169 |
-
|
| 170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# backend/main.py
|
| 2 |
import os
|
| 3 |
import json
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
from fastapi import FastAPI, Request, Form, Response
|
| 6 |
from fastapi.responses import HTMLResponse
|
| 7 |
from fastapi.templating import Jinja2Templates
|
|
|
|
| 9 |
from fastapi.encoders import jsonable_encoder
|
| 10 |
|
| 11 |
from langchain_community.vectorstores import FAISS
|
| 12 |
+
from langchain_community.llms import HuggingFacePipeline # NEW
|
| 13 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
| 14 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
|
|
| 15 |
from langchain.chains import RetrievalQA
|
|
|
|
|
|
|
| 16 |
from langchain import PromptTemplate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
+
# -------- optional OpenAI imports (kept, but disabled) ----------
|
| 19 |
+
# from langchain.llms import OpenAI
|
| 20 |
+
# from langchain.embeddings import OpenAIEmbeddings
|
| 21 |
+
# ---------------------------------------------------------------
|
| 22 |
|
| 23 |
+
from ingest import Ingest
|
| 24 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
| 25 |
+
import torch
|
| 26 |
+
|
| 27 |
+
# ------------------------------------------------------------------
# 1. ENVIRONMENT
# ------------------------------------------------------------------
# Pull variables from a local .env file into the process environment,
# then fail fast when the mandatory Hugging Face token is absent.
load_dotenv()

HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
if HF_TOKEN is None:
    raise ValueError("HUGGINGFACE_TOKEN not set in the environment.")

# Optional OpenAI path (disabled):
# OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# if OPENAI_API_KEY is None:
#     print("OpenAI key missing – OpenAI path disabled.")
|
| 38 |
+
|
| 39 |
+
# ------------------------------------------------------------------
|
| 40 |
+
# 2. LLM & EMBEDDINGS CONFIGURATION
|
| 41 |
+
# ------------------------------------------------------------------
|
| 42 |
+
DEFAULT_LLM = "google/gemma-3-4b-it"  # change here if desired
EMB_EN = "sentence-transformers/all-MiniLM-L6-v2"
EMB_CZ = "Seznam/retromae-small-cs"


def build_hf_llm(model_id: str = DEFAULT_LLM) -> HuggingFacePipeline:
    """
    Create a local text-generation pipeline for *model_id* wrapped in
    LangChain's ``HuggingFacePipeline`` LLM interface.

    Works on CPU; uses half precision automatically when CUDA is available.

    Parameters
    ----------
    model_id : str
        Hugging Face model repository id to load (defaults to DEFAULT_LLM).

    Returns
    -------
    HuggingFacePipeline
        LangChain-compatible LLM backed by a transformers pipeline.
    """
    # fp16 halves GPU memory; fp32 is required for correct CPU inference.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=HF_TOKEN,
        torch_dtype=dtype,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
    gen_pipe = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        # BUG FIX: transformers ignores temperature/top_p (and decodes
        # greedily) unless sampling is explicitly enabled.
        do_sample=True,
        temperature=0.2,
        top_p=0.95,
    )
    return HuggingFacePipeline(pipeline=gen_pipe)


HF_LLM = build_hf_llm()  # initialise once at import time; reused by every request
# OPENAI_LLM = OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0)  # optional
|
| 71 |
|
| 72 |
+
# ------------------------------------------------------------------
# 3. FASTAPI PLUMBING
# ------------------------------------------------------------------
# On-disk FAISS store locations (512-token chunks).
CZECH_STORE = "stores/czech_512"
ENGLISH_STORE = "stores/english_512"

app = FastAPI()
templates = Jinja2Templates(directory="templates")
app.mount("/static", StaticFiles(directory="static"), name="static")

# Shared ingestion helper: 512-token chunks with 50% overlap, one store
# and one embedding model per language.
ingestor = Ingest(
    chunk=512,
    overlap=256,
    czech_store=CZECH_STORE,
    english_store=ENGLISH_STORE,
    czech_embedding_model=EMB_CZ,
    english_embedding_model=EMB_EN,
)
|
| 92 |
|
| 93 |
+
# ------------------------------------------------------------------
|
| 94 |
+
# 4. PROMPTS
|
| 95 |
+
# ------------------------------------------------------------------
|
| 96 |
+
def prompt_en() -> PromptTemplate:
    """Return the English-language RetrievalQA prompt.

    The template exposes two input variables, ``context`` and
    ``question``, which RetrievalQA fills in at query time.
    """
    template_en = """You are an electrical engineer and you answer users' ###Question.
# Your answer must be helpful, relevant and closely related to the user's ###Question.
# Quote literally from the ###Context wherever possible.
# Use your own words only to connect or clarify. If you don't know, say so.
###Context: {context}
###Question: {question}
Helpful answer:
"""
    return PromptTemplate(template=template_en, input_variables=["context", "question"])
|
| 106 |
+
|
| 107 |
+
def prompt_cz() -> PromptTemplate:
    """Return the Czech-language RetrievalQA prompt.

    Mirrors :func:`prompt_en`; same ``context``/``question`` variables.
    """
    template_cz = """Jste elektroinženýr a odpovídáte na ###Otázku.
# Odpověď musí být užitečná, relevantní a úzce souviset s ###Otázkou.
# Citujte co nejvíce doslovně z ###Kontextu.
# Vlastními slovy pouze propojujte nebo vysvětlujte. Nevíte-li, řekněte to.
###Kontext: {context}
###Otázka: {question}
Užitečná odpověď:
"""
    return PromptTemplate(template=template_cz, input_variables=["context", "question"])
|
| 117 |
+
|
| 118 |
+
# ------------------------------------------------------------------
|
| 119 |
+
# 5. ROUTES
|
| 120 |
+
# ------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
@app.get("/", response_class=HTMLResponse)
def home(request: Request):
    """Serve the landing page (templates/index.html)."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
|
| 124 |
|
|
|
|
| 125 |
@app.post("/ingest_data")
async def ingest_data(folderPath: str = Form(...), language: str = Form(...)):
    """Ingest documents from *folderPath* into the store for *language*.

    Any language value other than "czech" (case-insensitive) is treated
    as English. Returns a completion message for the UI.
    """
    if language.lower() == "czech":
        ingestor.data_czech = folderPath
        ingestor.ingest_czech()
        message = "Czech data ingestion complete."
    else:
        ingestor.data_english = folderPath
        ingestor.ingest_english()
        message = "English data ingestion complete."
    return {"message": message}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
@app.post("/get_response")
async def get_response(query: str = Form(...), language: str = Form(...)):
    """Answer *query* with RetrievalQA over the vector store for *language*.

    Any language other than "czech" (case-insensitive) falls back to
    English. Responds with a JSON body containing "answer",
    "source_document" (page content of the top retrieved chunk) and
    "doc" (that chunk's source path).
    """
    is_czech = language.lower() == "czech"
    prompt = prompt_cz() if is_czech else prompt_en()
    store_path = CZECH_STORE if is_czech else ENGLISH_STORE
    embed_name = EMB_CZ if is_czech else EMB_EN

    # NOTE(review): embeddings + FAISS index are rebuilt on every request;
    # consider caching one per language if latency matters.
    embeddings = HuggingFaceEmbeddings(
        model_name=embed_name,
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": False},
    )
    vectordb = FAISS.load_local(store_path, embeddings)
    retriever = vectordb.as_retriever(search_kwargs={"k": 2})

    qa_chain = RetrievalQA.from_chain_type(
        llm=HF_LLM,               # default open-source model
        # llm=OPENAI_LLM,         # optional paid model
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt},
        verbose=True,
    )

    result = qa_chain(query)
    answer = result["result"]

    # BUG FIX: an empty or missing store yields no source documents;
    # indexing [0] unconditionally raised IndexError before.
    sources = result.get("source_documents") or []
    src_doc = sources[0].page_content if sources else ""
    src_path = sources[0].metadata.get("source", "") if sources else ""

    # BUG FIX: the old code wrapped json.dumps() output in jsonable_encoder
    # (a no-op on a str) and returned it without a content type, so clients
    # received JSON labeled text/plain. Serialize once and label it as JSON;
    # ensure_ascii=False keeps Czech diacritics readable.
    payload = json.dumps(
        {"answer": answer, "source_document": src_doc, "doc": src_path},
        ensure_ascii=False,
    )
    return Response(content=payload, media_type="application/json")
|