Frenchizer commited on
Commit
64248a1
·
1 Parent(s): ec784c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -34
app.py CHANGED
@@ -24,10 +24,9 @@ context_pipeline = pipeline("zero-shot-classification", model="facebook/bart-lar
24
  def detect_context(input_text):
25
  result = context_pipeline(input_text, candidate_labels=labels)
26
  contexts = [label for label, score in zip(result["labels"], result["scores"]) if score > 0.1]
27
- print(contexts)
28
- return contexts if contexts else ["general"]
29
 
30
- def translate_text(input_text, context):
31
  tokenized_input = tokenizer(
32
  input_text, return_tensors="np", padding=True, truncation=True, max_length=512
33
  )
@@ -37,35 +36,4 @@ def translate_text(input_text, context):
37
  decoder_start_token_id = tokenizer.cls_token_id or tokenizer.pad_token_id
38
  decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)
39
 
40
- for _ in range(512):
41
- outputs = session.run(
42
- None,
43
- {
44
- "input_ids": input_ids,
45
- "attention_mask": attention_mask,
46
- "decoder_input_ids": decoder_input_ids,
47
- }
48
- )
49
-
50
- logits = outputs[0]
51
- next_token_id = np.argmax(logits[:, -1, :], axis=-1).item()
52
- decoder_input_ids = np.concatenate(
53
- [decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)], axis=1
54
- )
55
-
56
- if next_token_id == tokenizer.eos_token_id:
57
- break
58
-
59
- return tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
60
-
61
- def process_request(input_text):
62
- contexts = detect_context(input_text)
63
- translation = translate_text(input_text, contexts[0]) # Use the first detected context
64
- return translation, contexts[0]
65
-
66
- gr.Interface(
67
- fn=process_request,
68
- inputs="text",
69
- outputs="text",
70
- live=True
71
  ).launch()
 
24
  def detect_context(input_text):
25
  result = context_pipeline(input_text, candidate_labels=labels)
26
  contexts = [label for label, score in zip(result["labels"], result["scores"]) if score > 0.1]
27
+ return contexts[0] if contexts else "general"
 
28
 
29
+ def translate_text(input_text):
30
  tokenized_input = tokenizer(
31
  input_text, return_tensors="np", padding=True, truncation=True, max_length=512
32
  )
 
36
  decoder_start_token_id = tokenizer.cls_token_id or tokenizer.pad_token_id
37
  decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  ).launch()