# RISE / app.py
# Nolsafan: changed model to try to resolve the error (#1) — commit 2d6cd36
import gradio as gr
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
def build_chain():
    """Build the demo RAG chain: embed a tiny corpus into FAISS and wire
    retriever -> prompt -> LLM -> string parser.

    Returns:
        A LangChain runnable that maps a question string to an answer string.
    """
    use_cuda = torch.cuda.is_available()

    embed_model_id = "BAAI/bge-small-en-v1.5"
    embeddings = HuggingFaceEmbeddings(
        model_name=embed_model_id,
        model_kwargs={"device": "cuda" if use_cuda else "cpu"},
    )

    # Tiny in-memory corpus used as the demo knowledge base.
    texts = [
        "Kragujevac is a city in central Serbia founded in the 15th century.",
        "The main industry in Kragujevac includes automotive manufacturing.",
        "Famous landmarks: The Šumarice Memorial Park and the Old Foundry Museum."
    ]
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=80)
    docs = text_splitter.create_documents(texts)
    vectorstore = FAISS.from_documents(docs, embeddings)
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

    model_id = "Qwen/Qwen2.5-1.5B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Safety: some checkpoints ship without a pad token; fall back to EOS.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        # BUG FIX: float16 was hard-coded regardless of device. Many float16
        # ops are unsupported or very slow on CPU, so use half precision only
        # when a CUDA device is actually available.
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    )
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=200,
        temperature=0.7,
        do_sample=True,
        return_full_text=False,  # emit only the generated continuation, not the prompt
    )
    llm = HuggingFacePipeline(pipeline=pipe)

    template = """You are a helpful assistant. Use only the provided context to answer.
If unsure, say "I don't know."
Context: {context}
Question: {question}
Answer:"""
    prompt = ChatPromptTemplate.from_template(template)

    def format_docs(docs):
        # Join the retrieved documents into one context string for the prompt.
        return "\n\n".join(doc.page_content for doc in docs)

    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain
# Build the chain once at import time so every request reuses it.
rag_chain = build_chain()


def answer(question: str):
    """Run the RAG chain on *question*; return "" for blank input."""
    if question.strip():
        return rag_chain.invoke(question)
    return ""
# Gradio UI: one question box wired straight onto the RAG chain.
question_box = gr.Textbox(lines=2, label="Question")
answer_box = gr.Textbox(lines=8, label="Answer")

demo = gr.Interface(
    fn=answer,
    inputs=question_box,
    outputs=answer_box,
    title="Mini RAG demo (Kragujevac)",
)

if __name__ == "__main__":
    demo.launch()