# Hugging Face Spaces app — this Space was previously failing at startup with "Runtime error".
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_text_splitters import RecursiveCharacterTextSplitter
def build_chain():
    """Build the mini RAG chain: embeddings -> FAISS retriever -> local LLM.

    Returns:
        A LangChain runnable that maps a question string to an answer string,
        grounded in the small in-memory corpus about Kragujevac.
    """
    use_cuda = torch.cuda.is_available()

    # Small sentence-embedding model for indexing the toy corpus.
    embeddings = HuggingFaceEmbeddings(
        model_name="BAAI/bge-small-en-v1.5",
        model_kwargs={"device": "cuda" if use_cuda else "cpu"},
    )

    texts = [
        "Kragujevac is a city in central Serbia founded in the 15th century.",
        "The main industry in Kragujevac includes automotive manufacturing.",
        "Famous landmarks: The Šumarice Memorial Park and the Old Foundry Museum.",
    ]
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=80)
    docs = text_splitter.create_documents(texts)
    vectorstore = FAISS.from_documents(docs, embeddings)
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

    model_id = "Qwen/Qwen2.5-1.5B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Safety: some tokenizers ship without a pad token; generation needs one.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    # FIX: the original forced torch.float16 unconditionally. On a CPU-only
    # Space, fp16 is unsupported/extremely slow for many ops and is a likely
    # cause of the startup "Runtime error" — use fp16 only when CUDA exists.
    # device_map="auto" also needs `accelerate`; skip it on plain CPU.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto" if use_cuda else None,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    )

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=200,
        temperature=0.7,
        do_sample=True,
        return_full_text=False,  # only the generated continuation, not the prompt
    )
    llm = HuggingFacePipeline(pipeline=pipe)

    template = """You are a helpful assistant. Use only the provided context to answer.
If unsure, say "I don't know."
Context: {context}
Question: {question}
Answer:"""
    # FIX: HuggingFacePipeline is a plain-string LLM; ChatPromptTemplate would
    # render the prompt with a "Human: " prefix. PromptTemplate feeds it raw.
    prompt = PromptTemplate.from_template(template)

    def format_docs(docs):
        # Join retrieved chunks into a single context string for the prompt.
        return "\n\n".join(doc.page_content for doc in docs)

    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain
| rag_chain = build_chain() | |
def answer(question: str):
    """Gradio callback: run the RAG chain on *question*.

    Blank or whitespace-only input short-circuits to an empty answer.
    """
    cleaned = question.strip()
    return rag_chain.invoke(question) if cleaned else ""
# Minimal Gradio UI: a question textbox in, the chain's answer out.
demo = gr.Interface(
    fn=answer,
    inputs=gr.Textbox(lines=2, label="Question"),
    outputs=gr.Textbox(lines=8, label="Answer"),
    title="Mini RAG demo (Kragujevac)"
)

if __name__ == "__main__":
    demo.launch()