Spaces:
Paused
Paused
| import os | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from sentence_transformers import SentenceTransformer | |
| from datasets import load_dataset | |
| import faiss | |
| import gradio as gr | |
| from accelerate import Accelerator | |
# --- Model, tokenizer, embedder, and retrieval-index setup (runs at import time) ---
hf_api_key = os.getenv('HF_API_KEY')  # Hugging Face access token; may be None for public models
model_id = "microsoft/phi-2"
# model_id = "microsoft/Phi-3-mini-128k-instruct"
# Tokenizer and model setup
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key, trust_remote_code=True)
# Padding-token setup: phi-2's tokenizer ships without a pad token, so reuse EOS
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_api_key,
    trust_remote_code=True,
    torch_dtype=torch.float32  # full precision — safe on CPU; halve to fp16/bf16 on GPU if memory-bound
)
accelerator = Accelerator()
model = accelerator.prepare(model)  # moves the model onto whatever device Accelerate detects
# Sentence embedder used to encode queries for retrieval
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
# Wikipedia snapshot whose "embedded" revision already carries precomputed embeddings
dataset = load_dataset("not-lain/wikipedia", revision="embedded")
data = dataset["train"]
# Build a FAISS nearest-neighbour index over the precomputed "embeddings" column
data = data.add_faiss_index("embeddings")
def generate(formatted_prompt):
    """Generate an answer from the LLM for an already-formatted RAG prompt.

    Prepends the module-level SYS_PROMPT system instruction, tokenizes with
    truncation, samples a completion, and returns only the newly generated
    text (the echoed prompt is stripped).

    Args:
        formatted_prompt: question + retrieved context, built by format_prompt().

    Returns:
        The decoded model answer as a string (special tokens removed).
    """
    prompt_text = f"{SYS_PROMPT} {formatted_prompt}"
    # Bug fix: the original used padding="max_length", which right-pads the
    # single prompt to 512 tokens; generation then continues after a run of
    # pad tokens, degrading output. A single sequence needs no padding at
    # all — truncate only.
    encoding = tokenizer(prompt_text, return_tensors="pt", max_length=512, truncation=True)
    input_ids = encoding['input_ids'].to(accelerator.device)
    attention_mask = encoding['attention_mask'].to(accelerator.device)
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=1024,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,  # explicit: pad == eos here; silences the HF warning
        do_sample=True,
        temperature=0.6,
        top_p=0.9
    )
    # outputs[0] contains prompt + completion; decode only the new tokens so
    # the chatbot returns just the answer rather than echoing the question.
    generated_tokens = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True)
def search(query: str, k: int = 3):
    """Retrieve the k nearest Wikipedia rows for *query* via the FAISS index.

    Encodes the query with the sentence embedder, then asks the dataset's
    FAISS index over the "embeddings" column for the closest examples.

    Returns:
        A (scores, retrieved_examples) pair as produced by
        Dataset.get_nearest_examples.
    """
    query_vector = ST.encode(query)
    return data.get_nearest_examples("embeddings", query_vector, k=k)
def format_prompt(prompt, retrieved_documents, k):
    """Build the LLM prompt: the user question followed by retrieved context.

    Args:
        prompt: the user's question.
        retrieved_documents: mapping with a 'text' list of document strings
            (as returned by Dataset.get_nearest_examples).
        k: maximum number of documents to include.

    Returns:
        "Question:<prompt>\\nContext:" followed by up to k documents, one per
        line (each terminated with a newline).
    """
    # Bug fix: slice instead of indexing range(k) — the original raised
    # IndexError whenever fewer than k documents were actually retrieved.
    docs = retrieved_documents['text'][:k]
    return f"Question:{prompt}\nContext:" + "".join(f"{doc}\n" for doc in docs)
def rag_chatbot_interface(prompt: str, k: int = 2):
    """End-to-end RAG pipeline: retrieve k documents, then generate an answer.

    Args:
        prompt: the user's question from the Gradio text box.
        k: number of documents to retrieve and include as context.

    Returns:
        The model's generated answer as a string.
    """
    _, documents = search(prompt, k)
    rag_prompt = format_prompt(prompt, documents, k)
    return generate(rag_prompt)
# System instruction prepended to every prompt by generate(). Defined after
# generate() in the file, but that is safe: generate() reads the global only
# at call time, which happens after the whole module has executed.
SYS_PROMPT = "You are an assistant for answering questions. Provide a conversational answer."
# Gradio UI: a single text box in, a single text box out, wired to the RAG pipeline.
iface = gr.Interface(
    fn=rag_chatbot_interface,
    inputs="text",
    outputs="text",
    title="Retrieval-Augmented Generation Chatbot",
    description="This chatbot provides more accurate answers by searching relevant documents and generating responses."
)
# share=True exposes a temporary public gradio.live URL in addition to localhost.
iface.launch(share=True)