import streamlit as st
import torch
from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering
def find_answer(question, context, model, tokenizer):
    """Extract the answer to `question` from `context` with an extractive QA model.

    The (question, context) pair is tokenized as one sequence, the model's
    start/end logits are reduced to the single most probable positions, and
    the token span between them is decoded back into text.

    Args:
        question: The question string.
        context: The passage to extract the answer from.
        model: HuggingFace extractive-QA model producing start/end logits.
        tokenizer: The tokenizer matching `model`.

    Returns:
        str: The decoded answer span (may be empty).
    """
    # Encode question + context together, truncated to the 512-token limit.
    encoded = tokenizer.encode_plus(
        question,
        context,
        add_special_tokens=True,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    )
    token_ids = encoded["input_ids"].tolist()[0]

    # Forward pass without gradient tracking (inference only).
    with torch.no_grad():
        prediction = model(**encoded)

    # Highest-scoring start and end token positions.
    span_start = torch.argmax(prediction['start_logits'])
    span_end = torch.argmax(prediction['end_logits'])

    # Decode the tokens inside the predicted span back into a string.
    return tokenizer.decode(token_ids[span_start:span_end + 1], skip_special_tokens=True)
def read_file(file_path):
    """Read a text file, auto-detecting its encoding when possible.

    The file is read once as raw bytes; if the optional `chardet` package is
    available, its detected encoding is used for decoding, otherwise UTF-8 is
    assumed. Undecodable bytes are replaced rather than raising.

    Args:
        file_path: Path of the text file to read.

    Returns:
        str: The decoded file contents.
    """
    with open(file_path, "rb") as f:
        raw_data = f.read()

    # BUGFIX: the original called chardet.detect() but chardet was never
    # imported anywhere in the file, so this raised NameError at runtime.
    # Import it lazily and fall back to UTF-8 when it is unavailable or
    # its detection is inconclusive (detect() can return encoding=None).
    try:
        import chardet
        encoding = chardet.detect(raw_data)["encoding"] or "utf-8"
    except ImportError:
        encoding = "utf-8"

    # Decode in memory instead of re-opening the file a second time;
    # errors="replace" keeps a best-effort read from crashing the app.
    return raw_data.decode(encoding, errors="replace")
def main():
    """Streamlit entry point: answer a user question over the text in inputed.txt.

    Loads a SQuAD-fine-tuned DistilBERT QA model, splits the context file
    into 512-character segments, and shows one extracted answer per segment.
    """
    st.title("Shree5 GPT: By Tech Ninja Group")
    text = st.text_input('Enter your citizenship-related question:')
    if text:
        # BUGFIX: the original loaded "distilbert-base-uncased", which has no
        # trained question-answering head; DistilBertForQuestionAnswering
        # would attach a randomly initialised head and emit garbage spans.
        # Use the checkpoint fine-tuned on SQuAD instead.
        model_name = "distilbert-base-uncased-distilled-squad"
        tokenizer = DistilBertTokenizer.from_pretrained(model_name)
        model = DistilBertForQuestionAnswering.from_pretrained(model_name)

        # Read the context from the text file.
        context = read_file("inputed.txt")

        # DistilBERT caps sequences at 512 tokens; split the raw text into
        # 512-character windows so each chunk fits within the model limit.
        segments = [context[i:i + 512] for i in range(0, len(context), 512)]

        # Extract one candidate answer per segment.
        answers = []
        for i, segment in enumerate(segments):
            st.write(f"Generating answer for segment {i+1}...")
            answers.append(find_answer(text, segment, model, tokenizer))

        st.write("Answer:")
        for i, answer in enumerate(answers):
            st.write(f"Segment {i+1}: {answer}")


if __name__ == "__main__":
    main()