# finalShree5GPT / app.py
# nabin2004's picture
# done done
# 98eb250
import streamlit as st
import torch
from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering
def find_answer(question, context, model, tokenizer):
    """Extract an answer span for ``question`` from ``context``.

    Runs an extractive question-answering model: the question and context are
    encoded as a single pair, the model predicts start/end logits over the
    tokens, and the highest-probability span is decoded back to text.

    Args:
        question: The question string.
        context: The passage to search for an answer.
        model: A QA model returning ``start_logits`` / ``end_logits``
            (e.g. ``DistilBertForQuestionAnswering``).
        tokenizer: The matching tokenizer (``encode_plus`` / ``decode``).

    Returns:
        The decoded answer string; an empty string when the predicted end
        position precedes the predicted start position.
    """
    # Encode question+context as one pair, truncated to the model's
    # 512-token limit.
    inputs = tokenizer.encode_plus(
        question,
        context,
        add_special_tokens=True,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    )
    input_ids = inputs["input_ids"].tolist()[0]

    # Inference only — no gradients needed.
    with torch.no_grad():
        output = model(**inputs)
    start_logits = output["start_logits"]
    end_logits = output["end_logits"]

    # Most probable start and end token positions.
    start_idx = torch.argmax(start_logits)
    end_idx = torch.argmax(end_logits)

    # The two argmaxes are independent, so the end can land before the
    # start. The original silently decoded an empty slice; make that
    # degenerate case explicit.
    if end_idx < start_idx:
        return ""

    # Decode the selected token span back into text.
    answer = tokenizer.decode(
        input_ids[start_idx:end_idx + 1], skip_special_tokens=True
    )
    return answer
def read_file(file_path):
    """Read a text file, detecting its encoding when possible.

    Bug fix: the original body called ``chardet.detect`` but ``chardet`` was
    never imported anywhere in the file, so this function raised ``NameError``
    on every call. The detection library is now imported locally and guarded:
    when ``chardet`` is installed its detected encoding is used (preserving the
    original intent), otherwise the file is read as UTF-8 with undecodable
    bytes replaced rather than crashing.

    Args:
        file_path: Path of the file to read.

    Returns:
        The file contents as a single string.
    """
    try:
        # Local import — this is the import the original file was missing.
        import chardet
        with open(file_path, "rb") as f:
            raw_data = f.read()
        # chardet may return None for an undetectable encoding; fall back
        # to UTF-8 in that case too.
        encoding = chardet.detect(raw_data)["encoding"] or "utf-8"
    except ImportError:
        encoding = "utf-8"
    # errors="replace" keeps a best-effort read instead of raising on a
    # mis-detected or mismatched encoding.
    with open(file_path, "r", encoding=encoding, errors="replace") as file:
        return file.read()
@st.cache_resource
def _load_qa_components(model_name="distilbert-base-uncased"):
    """Load and cache the tokenizer and QA model (once per process).

    The original code constructed both objects inside ``main`` on every
    Streamlit rerun, re-loading (and potentially re-downloading) the model
    for each user interaction; ``st.cache_resource`` makes that a one-time
    cost.
    """
    tokenizer = DistilBertTokenizer.from_pretrained(model_name)
    model = DistilBertForQuestionAnswering.from_pretrained(model_name)
    return tokenizer, model


def main():
    """Streamlit entry point: answer a citizenship question over inputed.txt.

    Reads the context file, splits it into 512-character chunks, and runs
    extractive QA on each chunk, displaying one answer per segment.
    """
    st.title("Shree5 GPT: By Tech Ninja Group")
    text = st.text_input('Enter your citizenship-related question:')
    if text:
        # Cached loader — avoids re-loading the model on every rerun.
        tokenizer, model = _load_qa_components()

        # Read the context from the text file.
        context = read_file("inputed.txt")

        # Split the context into fixed-size character chunks so each stays
        # comfortably under the model's 512-token input limit.
        segments = [context[i:i + 512] for i in range(0, len(context), 512)]

        answers = []
        for i, segment in enumerate(segments):
            st.write(f"Generating answer for segment {i+1}...")
            answers.append(find_answer(text, segment, model, tokenizer))

        st.write("Answer:")
        for i, answer in enumerate(answers):
            st.write(f"Segment {i+1}: {answer}")


if __name__ == "__main__":
    main()