Spaces:
Build error
Build error
| import streamlit as st | |
| import torch | |
| from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering | |
def find_answer(question, context, model, tokenizer):
    """Extract an answer span for ``question`` from ``context``.

    Runs extractive QA: encodes the (question, context) pair, picks the
    token positions with the highest start/end logits, and decodes the
    tokens between those positions back into text.

    Args:
        question: The question string.
        context: The passage to search for an answer.
        model: A QA model whose output exposes ``start_logits`` and
            ``end_logits`` (e.g. DistilBertForQuestionAnswering).
        tokenizer: The tokenizer matching ``model``.

    Returns:
        The decoded answer string; ``""`` when the model predicts an
        invalid (end-before-start) span.
    """
    # Encode question + context as a single sequence, truncated to the
    # model's 512-token limit.
    inputs = tokenizer.encode_plus(
        question,
        context,
        add_special_tokens=True,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    )
    input_ids = inputs["input_ids"].tolist()[0]

    # Inference only -- no gradients needed.
    with torch.no_grad():
        output = model(**inputs)
    start_logits = output["start_logits"]
    end_logits = output["end_logits"]

    # Most likely start/end token positions. Batch size is 1, so a flat
    # argmax over the logits is safe.
    start_idx = torch.argmax(start_logits)
    end_idx = torch.argmax(end_logits)

    # Bug fix (made explicit): an inverted span used to decode an empty
    # slice and silently return "". Surface the empty-answer case clearly.
    if end_idx < start_idx:
        return ""

    # Map the winning token span back to text.
    return tokenizer.decode(input_ids[start_idx:end_idx + 1], skip_special_tokens=True)
def read_file(file_path):
    """Read a text file, tolerating unknown encodings.

    Bug fix: the original implementation called ``chardet.detect`` but the
    file never imports ``chardet``, so this function raised ``NameError``
    on every call. Encoding sniffing is replaced with a stdlib-only
    strategy: try UTF-8 first, then fall back to Latin-1, which maps every
    byte to a code point and therefore never fails to decode.

    Args:
        file_path: Path of the file to read.

    Returns:
        The file contents as a ``str``.
    """
    # Read the raw bytes once; decode in memory instead of reopening
    # the file a second time as the original did.
    with open(file_path, "rb") as f:
        raw_data = f.read()
    try:
        return raw_data.decode("utf-8")
    except UnicodeDecodeError:
        # Latin-1 is a lossless byte->char mapping: a safe last resort.
        return raw_data.decode("latin-1")
def main():
    """Streamlit entry point: answer questions over the local knowledge file.

    Reads ``inputed.txt``, splits it into chunks, and runs extractive QA
    on each chunk with the user's question, writing one answer per chunk.
    """
    st.title("Shree5 GPT: By Tech Ninja Group")
    text = st.text_input('Enter your citizenship-related question:')
    if text:
        # Bug fix: the bare "distilbert-base-uncased" checkpoint has a
        # randomly initialized QA head, so its answers are meaningless.
        # Use the official SQuAD-fine-tuned distillation instead.
        model_name = "distilbert-base-uncased-distilled-squad"
        tokenizer = DistilBertTokenizer.from_pretrained(model_name)
        model = DistilBertForQuestionAnswering.from_pretrained(model_name)

        # Knowledge base lives alongside the app.
        context = read_file("inputed.txt")

        # 512-*character* windows: conservative chunks that stay under the
        # model's 512-token limit after tokenization.
        # NOTE(review): fixed windows can split an answer across a chunk
        # boundary; overlapping windows would be more robust.
        segments = [context[i:i + 512] for i in range(0, len(context), 512)]

        answers = []
        for i, segment in enumerate(segments):
            st.write(f"Generating answer for segment {i+1}...")
            answers.append(find_answer(text, segment, model, tokenizer))

        st.write("Answer:")
        for i, answer in enumerate(answers):
            st.write(f"Segment {i+1}: {answer}")


if __name__ == "__main__":
    main()