"""Mini RAG demo: a tiny FAISS-backed retrieval chain over three facts about
Kragujevac, answered by a local Qwen2.5 instruct model, served via Gradio."""

import gradio as gr
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser


def build_chain():
    """Build and return the RAG chain (retriever -> prompt -> LLM -> str).

    Side effects: downloads the embedding and LLM weights on first call and
    builds an in-memory FAISS index over a small hard-coded corpus.
    """
    # Decide the device once; reused for both the embedder and the LLM dtype.
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"

    embed_model_id = "BAAI/bge-small-en-v1.5"
    embeddings = HuggingFaceEmbeddings(
        model_name=embed_model_id,
        model_kwargs={"device": device},
    )

    # Tiny demo corpus; in a real app this would come from documents.
    texts = [
        "Kragujevac is a city in central Serbia founded in the 15th century.",
        "The main industry in Kragujevac includes automotive manufacturing.",
        "Famous landmarks: The Šumarice Memorial Park and the Old Foundry Museum.",
    ]
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=80)
    docs = text_splitter.create_documents(texts)

    vectorstore = FAISS.from_documents(docs, embeddings)
    # k=3 returns the whole corpus here; fine for a demo this small.
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

    model_id = "Qwen/Qwen2.5-1.5B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Safety: some causal-LM tokenizers ship without a pad token; fall back
    # to EOS so batched generation does not crash.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        # fp16 is only safe/fast on GPU; CPU-only hosts get fp32 to avoid
        # unsupported-op errors and severe slowdowns.
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    )

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=200,
        temperature=0.7,
        do_sample=True,
        # Return only the generated continuation, not the prompt echo.
        return_full_text=False,
    )
    llm = HuggingFacePipeline(pipeline=pipe)

    template = """You are a helpful assistant. Use only the provided context to answer. If unsure, say "I don't know." 
Context: {context} Question: {question} Answer:"""
    prompt = ChatPromptTemplate.from_template(template)

    def format_docs(docs):
        # Join retrieved chunks into a single context string for the prompt.
        return "\n\n".join(doc.page_content for doc in docs)

    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain


# Built once at import time so every Gradio request reuses the same chain.
rag_chain = build_chain()


def answer(question: str) -> str:
    """Answer *question* with the RAG chain; empty/whitespace input yields ""."""
    if not question.strip():
        return ""
    return rag_chain.invoke(question)


demo = gr.Interface(
    fn=answer,
    inputs=gr.Textbox(lines=2, label="Question"),
    outputs=gr.Textbox(lines=8, label="Answer"),
    title="Mini RAG demo (Kragujevac)",
)

if __name__ == "__main__":
    demo.launch()