# Adrega AI Help: a Streamlit Q&A app over the "andreska/adregadocs" dataset.
# Splits the context into chunks to prevent overloading the model.
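# To run (a minimal sketch; the file name is an assumption, not shown in the repo view):
#   export HF_API_KEY=<your Hugging Face token>
#   streamlit run app.py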
import os
import streamlit as st
from datasets import load_dataset
from huggingface_hub import InferenceClient

# Read the API key from the environment and create the inference client.
api_key = os.getenv("HF_API_KEY")
client = InferenceClient(api_key=api_key)

# Load the documentation dataset that supplies the answer context.
dataset = load_dataset("andreska/adregadocs", split="test")

def read_dataset(dataset):
    """Concatenate the 'text' field of every record into a single string."""
    text = []
    for item in dataset:
        text.append(item['text'])
    return "\n".join(text)

def chunk_text(text, max_length):
    """Yield chunks of at most max_length words (word count as a rough token proxy)."""
    words = text.split()
    for i in range(0, len(words), max_length):
        yield " ".join(words[i:i + max_length])

# Build the full context once; 30000 is a word count, not a token count.
context = read_dataset(dataset)
max_chunk_length = 30000

# Streamlit UI
st.title("Adrega AI Help")
user_input = st.text_input('Ask me a question')

if st.button("Submit"):
    if user_input:
        responses = []
        # Query the model once per chunk so no single request carries the whole context.
        for chunk in chunk_text(context, max_chunk_length):
            messages = [
                # Pass only the current chunk (not the full context) as the system prompt.
                {"role": "system", "content": f"Context: {chunk}"},
                {"role": "user", "content": user_input}
            ]
            completion = client.chat.completions.create(
                model="Qwen/Qwen2.5-72B-Instruct",
                # model="Qwen/Qwen2.5-Coder-32B-Instruct",
                # model="HuggingFaceTB/SmolLM2-1.7B-Instruct",
                messages=messages,
                max_tokens=500
            )
            answer = completion.choices[0].message.content
            responses.append(answer)
        # Join the per-chunk answers into one response.
        final_response = " ".join(responses)
        st.write(f"Adrega AI: {final_response}")
    else:
        st.write("Please enter a question.")