Frenchizer committed on
Commit
ea96e93
·
1 Parent(s): d89edde

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -11
app.py CHANGED
@@ -8,27 +8,27 @@ session = ort.InferenceSession(MODEL_FILE)
8
  tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
9
 
10
  labels = [
11
- "general", "pharma", "legal", "technical", "UI", "user interface", "medicine",
12
- "information technology", "IT", "marketing", "e-commerce", "programming",
13
- "website", "html", "keywords", "food commerce", "personal development",
14
- "literature", "poetry", "physics", "chemistry", "biology", "theater", "finance",
15
- "sports", "education", "politics", "economics", "art", "history", "music",
16
- "gaming", "aerospace", "engineering", "robotics", "travel", "tourism",
17
- "healthcare", "psychology", "environment", "fashion", "design", "real estate",
18
- "retail", "news", "entertainment", "social media", "automotive", "AI",
19
  "machine learning", "cryptocurrency", "blockchain", "philosophy", "anthropology",
20
  "archaeology", "data science"
21
  ]
22
 
23
  context_pipeline = pipeline("zero-shot-classification", model="MoritzLaurer/deberta-v3-base-mnli-fever-anli")
24
- #context_pipeline = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
25
 
26
  def detect_context(input_text):
27
  result = context_pipeline(input_text, candidate_labels=labels)
28
- contexts = [label for label, score in zip(result["labels"], result["scores"]) if score > 0.1]
29
  print(contexts)
30
  return contexts if contexts else ["general"]
31
 
 
32
  def translate_text(input_text, context):
33
  tokenized_input = tokenizer(
34
  input_text, return_tensors="np", padding=True, truncation=True, max_length=512
@@ -60,14 +60,16 @@ def translate_text(input_text, context):
60
 
61
  return tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
62
 
 
63
  def process_request(input_text):
64
  contexts = detect_context(input_text)
65
  translation = translate_text(input_text, contexts[0]) # Use the first detected context
66
  return translation
67
 
 
68
  gr.Interface(
69
  fn=process_request,
70
  inputs="text",
71
  outputs="text",
72
  live=True
73
- ).launch()
 
8
  tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
9
 
10
  labels = [
11
+ "general", "pharma", "legal", "technical", "UI", "user interface", "medicine", "information technology", "IT", "marketing",
12
+ "e-commerce", "programming", "website", "html", "keywords", "food commerce",
13
+ "personal development", "literature", "poetry", "physics", "chemistry", "biology",
14
+ "theater", "finance", "sports", "education", "politics", "economics", "art",
15
+ "history", "music", "gaming", "aerospace", "engineering", "robotics", "travel",
16
+ "tourism", "healthcare", "psychology", "environment", "fashion", "design",
17
+ "real estate", "retail", "news", "entertainment", "social media", "automotive",
 
18
  "machine learning", "cryptocurrency", "blockchain", "philosophy", "anthropology",
19
  "archaeology", "data science"
20
  ]
21
 
22
  context_pipeline = pipeline("zero-shot-classification", model="MoritzLaurer/deberta-v3-base-mnli-fever-anli")
23
+
24
 
25
  def detect_context(input_text):
26
  result = context_pipeline(input_text, candidate_labels=labels)
27
+ contexts = [label for label, score in zip(result["labels"], result["scores"]) if label in labels and score > 0.1]
28
  print(contexts)
29
  return contexts if contexts else ["general"]
30
 
31
+
32
  def translate_text(input_text, context):
33
  tokenized_input = tokenizer(
34
  input_text, return_tensors="np", padding=True, truncation=True, max_length=512
 
60
 
61
  return tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
62
 
63
+
64
  def process_request(input_text):
65
  contexts = detect_context(input_text)
66
  translation = translate_text(input_text, contexts[0]) # Use the first detected context
67
  return translation
68
 
69
+
70
  gr.Interface(
71
  fn=process_request,
72
  inputs="text",
73
  outputs="text",
74
  live=True
75
+ ).launch()