Spaces:
Runtime error
Runtime error
Initial commit
Browse files
app.py
CHANGED
|
@@ -14,12 +14,10 @@ warnings.filterwarnings("ignore", category=UserWarning, module="torchvision")
|
|
| 14 |
try:
|
| 15 |
tokenizer = AutoTokenizer.from_pretrained("airesearch/wangchanberta-base-att-spm-uncased", use_fast=False)
|
| 16 |
model = AutoModelForMaskedLM.from_pretrained("airesearch/wangchanberta-base-att-spm-uncased")
|
| 17 |
-
model_name = "airesearch/wangchanberta-base-att-spm-uncased"
|
| 18 |
except Exception:
|
| 19 |
st.warning("Switching to xlm-roberta-base model due to compatibility issues.")
|
| 20 |
-
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
|
| 21 |
model = AutoModelForMaskedLM.from_pretrained("xlm-roberta-base")
|
| 22 |
-
model_name = "xlm-roberta-base"
|
| 23 |
|
| 24 |
# Initialize the fill-mask pipeline
|
| 25 |
pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer, framework="pt")
|
|
@@ -68,7 +66,7 @@ Feel free to enter your own sentence with `<mask>` and explore the predictions!
|
|
| 68 |
|
| 69 |
# User input box
|
| 70 |
st.subheader("Input Text")
|
| 71 |
-
input_text = st.text_input("Enter a sentence with `<mask>` to find similar predictions:", "
|
| 72 |
|
| 73 |
# Ensure the input includes a `<mask>`
|
| 74 |
if "<mask>" not in input_text:
|
|
@@ -90,10 +88,8 @@ if input_text:
|
|
| 90 |
result = pipe(input_text)
|
| 91 |
|
| 92 |
for r in result:
|
| 93 |
-
# Adjust based on observed output structure
|
| 94 |
prediction_text = r.get('sequence', '')
|
| 95 |
|
| 96 |
-
# Only proceed if we have a valid prediction text
|
| 97 |
if prediction_text:
|
| 98 |
prediction_embedding = get_embedding(prediction_text)
|
| 99 |
similarity = cosine_similarity(input_embedding, prediction_embedding)[0][0]
|
|
|
|
| 14 |
try:
|
| 15 |
tokenizer = AutoTokenizer.from_pretrained("airesearch/wangchanberta-base-att-spm-uncased", use_fast=False)
|
| 16 |
model = AutoModelForMaskedLM.from_pretrained("airesearch/wangchanberta-base-att-spm-uncased")
|
|
|
|
| 17 |
except Exception:
|
| 18 |
st.warning("Switching to xlm-roberta-base model due to compatibility issues.")
|
| 19 |
+
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base", use_fast=False)
|
| 20 |
model = AutoModelForMaskedLM.from_pretrained("xlm-roberta-base")
|
|
|
|
| 21 |
|
| 22 |
# Initialize the fill-mask pipeline
|
| 23 |
pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer, framework="pt")
|
|
|
|
| 66 |
|
| 67 |
# User input box
|
| 68 |
st.subheader("Input Text")
|
| 69 |
+
input_text = st.text_input("Enter a sentence with `<mask>` to find similar predictions:", "นักท่องเที่ยวจำนวนมากเลือกที่จะไปเยือน <mask> เพื่อสัมผัสธรรมชาติ")
|
| 70 |
|
| 71 |
# Ensure the input includes a `<mask>`
|
| 72 |
if "<mask>" not in input_text:
|
|
|
|
| 88 |
result = pipe(input_text)
|
| 89 |
|
| 90 |
for r in result:
|
|
|
|
| 91 |
prediction_text = r.get('sequence', '')
|
| 92 |
|
|
|
|
| 93 |
if prediction_text:
|
| 94 |
prediction_embedding = get_embedding(prediction_text)
|
| 95 |
similarity = cosine_similarity(input_embedding, prediction_embedding)[0][0]
|