Spaces:
Build error
Build error
Commit
·
ea96e93
1
Parent(s):
d89edde
Update app.py
Browse files
app.py
CHANGED
|
@@ -8,27 +8,27 @@ session = ort.InferenceSession(MODEL_FILE)
|
|
| 8 |
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
|
| 9 |
|
| 10 |
labels = [
|
| 11 |
-
"general", "pharma", "legal", "technical", "UI", "user interface", "medicine",
|
| 12 |
-
"
|
| 13 |
-
"
|
| 14 |
-
"
|
| 15 |
-
"
|
| 16 |
-
"
|
| 17 |
-
"
|
| 18 |
-
"retail", "news", "entertainment", "social media", "automotive", "AI",
|
| 19 |
"machine learning", "cryptocurrency", "blockchain", "philosophy", "anthropology",
|
| 20 |
"archaeology", "data science"
|
| 21 |
]
|
| 22 |
|
| 23 |
context_pipeline = pipeline("zero-shot-classification", model="MoritzLaurer/deberta-v3-base-mnli-fever-anli")
|
| 24 |
-
|
| 25 |
|
| 26 |
def detect_context(input_text):
|
| 27 |
result = context_pipeline(input_text, candidate_labels=labels)
|
| 28 |
-
contexts = [label for label, score in zip(result["labels"], result["scores"]) if score > 0.1]
|
| 29 |
print(contexts)
|
| 30 |
return contexts if contexts else ["general"]
|
| 31 |
|
|
|
|
| 32 |
def translate_text(input_text, context):
|
| 33 |
tokenized_input = tokenizer(
|
| 34 |
input_text, return_tensors="np", padding=True, truncation=True, max_length=512
|
|
@@ -60,14 +60,16 @@ def translate_text(input_text, context):
|
|
| 60 |
|
| 61 |
return tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
|
| 62 |
|
|
|
|
| 63 |
def process_request(input_text):
|
| 64 |
contexts = detect_context(input_text)
|
| 65 |
translation = translate_text(input_text, contexts[0]) # Use the first detected context
|
| 66 |
return translation
|
| 67 |
|
|
|
|
| 68 |
gr.Interface(
|
| 69 |
fn=process_request,
|
| 70 |
inputs="text",
|
| 71 |
outputs="text",
|
| 72 |
live=True
|
| 73 |
-
).launch()
|
|
|
|
| 8 |
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
|
| 9 |
|
| 10 |
labels = [
|
| 11 |
+
"general", "pharma", "legal", "technical", "UI", "user interface", "medicine", "information technology", "IT", "marketing",
|
| 12 |
+
"e-commerce", "programming", "website", "html", "keywords", "food commerce",
|
| 13 |
+
"personal development", "literature", "poetry", "physics", "chemistry", "biology",
|
| 14 |
+
"theater", "finance", "sports", "education", "politics", "economics", "art",
|
| 15 |
+
"history", "music", "gaming", "aerospace", "engineering", "robotics", "travel",
|
| 16 |
+
"tourism", "healthcare", "psychology", "environment", "fashion", "design",
|
| 17 |
+
"real estate", "retail", "news", "entertainment", "social media", "automotive",
|
|
|
|
| 18 |
"machine learning", "cryptocurrency", "blockchain", "philosophy", "anthropology",
|
| 19 |
"archaeology", "data science"
|
| 20 |
]
|
| 21 |
|
| 22 |
context_pipeline = pipeline("zero-shot-classification", model="MoritzLaurer/deberta-v3-base-mnli-fever-anli")
|
| 23 |
+
|
| 24 |
|
| 25 |
def detect_context(input_text):
|
| 26 |
result = context_pipeline(input_text, candidate_labels=labels)
|
| 27 |
+
contexts = [label for label, score in zip(result["labels"], result["scores"]) if label in labels and score > 0.1]
|
| 28 |
print(contexts)
|
| 29 |
return contexts if contexts else ["general"]
|
| 30 |
|
| 31 |
+
|
| 32 |
def translate_text(input_text, context):
|
| 33 |
tokenized_input = tokenizer(
|
| 34 |
input_text, return_tensors="np", padding=True, truncation=True, max_length=512
|
|
|
|
| 60 |
|
| 61 |
return tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
|
| 62 |
|
| 63 |
+
|
| 64 |
def process_request(input_text):
|
| 65 |
contexts = detect_context(input_text)
|
| 66 |
translation = translate_text(input_text, contexts[0]) # Use the first detected context
|
| 67 |
return translation
|
| 68 |
|
| 69 |
+
|
| 70 |
gr.Interface(
|
| 71 |
fn=process_request,
|
| 72 |
inputs="text",
|
| 73 |
outputs="text",
|
| 74 |
live=True
|
| 75 |
+
).launch()
|