File size: 1,499 Bytes
765c181
9c8d6cc
826f428
c7ceaa9
 
826f428
b4474c4
 
c7ceaa9
e0aa334
 
c7ceaa9
9c8d6cc
c7ceaa9
 
 
 
 
 
 
 
 
 
9c8d6cc
 
c7ceaa9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
879fd00
c7ceaa9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from transformers import BertForQuestionAnswering, BertTokenizer
from farasa.segmenter import FarasaSegmenter
import streamlit as st
import torch
import os

# Path to a local fine-tuned checkpoint.
# NOTE(review): this variable is never used below — the model is loaded from
# the Hugging Face Hub instead. Confirm whether local loading was intended.
checkpoint_path = os.path.abspath("./checkpoint-2817")

# Load the fine-tuned Arabic QA model and its tokenizer from the Hub.
# (The repo id is a plain string literal; the original `f`-prefixes had no
# placeholders and were removed.)
model = BertForQuestionAnswering.from_pretrained(
    "MarioMamdouh121/arabic-qa-model",
    use_safetensors=True,
    trust_remote_code=True,
)
tokenizer = BertTokenizer.from_pretrained(
    "MarioMamdouh121/arabic-qa-model",
    trust_remote_code=True,
)
model.eval()  # inference only: disable dropout for deterministic answers

# Farasa segmenter — presumably the model was trained on Farasa-segmented
# text, so inputs must be segmented the same way. TODO confirm.
segmenter = FarasaSegmenter(interactive=False)

# Streamlit interface: collect an Arabic context + question, run extractive
# QA, and display the decoded answer span.
st.title("Arabic Question Answering")
st.write("أدخل سياقًا وسؤالًا بالعربية واحصل على الجواب.")

context = st.text_area("السياق", height=150)
question = st.text_input("السؤال")

# Only run when the button is pressed AND both fields are non-empty.
if st.button("احصل على الجواب") and context and question:
    # Preprocess: apply the same Farasa segmentation used at training time.
    context_proc = segmenter.segment(context)
    question_proc = segmenter.segment(question)

    # Tokenize the (question, context) pair; truncate to BERT's 512-token limit.
    inputs = tokenizer(
        question_proc,
        context_proc,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )

    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # Most likely start/end positions of the answer span.
    start_index = torch.argmax(outputs.start_logits)
    end_index = torch.argmax(outputs.end_logits)

    # Robustness fix: the two argmaxes are taken independently, so the span
    # can be invalid (end before start), which would decode to an empty
    # string shown as a "success". Guard against that explicitly.
    if end_index < start_index:
        st.warning("لم يتم العثور على جواب في السياق المعطى.")
    else:
        answer_tokens = inputs["input_ids"][0][start_index : end_index + 1]
        answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
        answer = segmenter.desegment(answer)  # undo Farasa segmentation markers

        if answer.strip():
            st.success(f"الجواب: {answer}")
        else:
            # Decoded span was empty/whitespace (e.g. only special tokens).
            st.warning("لم يتم العثور على جواب في السياق المعطى.")