Update app.py
Browse files
app.py
CHANGED
|
@@ -5,21 +5,31 @@ import requests
|
|
| 5 |
import wikipedia
|
| 6 |
import gradio as gr
|
| 7 |
import torch
|
|
|
|
| 8 |
from transformers import (
|
| 9 |
SeamlessM4TProcessor,
|
| 10 |
SeamlessM4TForTextToText,
|
|
|
|
| 11 |
pipeline as hf_pipeline
|
| 12 |
)
|
| 13 |
|
| 14 |
# ────────────────────
|
| 15 |
-
# 1) SeamlessM4T
|
| 16 |
MODEL_NAME = "facebook/hf-seamless-m4t-medium"
|
| 17 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
m4t_model = SeamlessM4TForTextToText.from_pretrained(MODEL_NAME).to(device).eval()
|
| 20 |
|
| 21 |
def translate_m4t(text, src_iso3, tgt_iso3, auto_detect=False):
|
| 22 |
-
# src_iso3: e.g. "eng", "fra", etc. If auto_detect=True, pass None
|
| 23 |
src = None if auto_detect else src_iso3
|
| 24 |
inputs = processor(text=text, src_lang=src, return_tensors="pt").to(device)
|
| 25 |
tokens = m4t_model.generate(**inputs, tgt_lang=tgt_iso3)
|
|
@@ -41,7 +51,8 @@ def geocode(place: str):
|
|
| 41 |
params={"q": place, "format": "json", "limit": 1},
|
| 42 |
headers={"User-Agent":"iVoiceContext/1.0"}
|
| 43 |
).json()
|
| 44 |
-
if not resp:
|
|
|
|
| 45 |
return float(resp[0]["lat"]), float(resp[0]["lon"])
|
| 46 |
|
| 47 |
def fetch_osm(lat, lon, osm_filter, limit=5):
|
|
@@ -63,16 +74,16 @@ def fetch_osm(lat, lon, osm_filter, limit=5):
|
|
| 63 |
|
| 64 |
# ────────────────────
|
| 65 |
def get_context(text: str,
|
| 66 |
-
source_lang: str, # always 3
|
| 67 |
-
output_lang: str, # always 3
|
| 68 |
auto_detect: bool):
|
| 69 |
-
# 1) Ensure English for NER
|
| 70 |
if auto_detect or source_lang != "eng":
|
| 71 |
en_text = translate_m4t(text, source_lang, "eng", auto_detect=auto_detect)
|
| 72 |
else:
|
| 73 |
en_text = text
|
| 74 |
|
| 75 |
-
# 2)
|
| 76 |
ner_out = ner(en_text)
|
| 77 |
ents = { ent["word"]: ent["entity_group"] for ent in ner_out }
|
| 78 |
|
|
@@ -84,25 +95,22 @@ def get_context(text: str,
|
|
| 84 |
results[ent_text] = {"type":"location","error":"could not geocode"}
|
| 85 |
else:
|
| 86 |
lat, lon = geo
|
| 87 |
-
rest = fetch_osm(lat, lon, '["amenity"="restaurant"]')
|
| 88 |
-
attr = fetch_osm(lat, lon, '["tourism"="attraction"]')
|
| 89 |
results[ent_text] = {
|
| 90 |
"type": "location",
|
| 91 |
-
"restaurants":
|
| 92 |
-
"attractions":
|
| 93 |
}
|
| 94 |
else:
|
| 95 |
-
# PERSON, ORG, MISC → Wikipedia
|
| 96 |
try:
|
| 97 |
-
|
| 98 |
except Exception:
|
| 99 |
-
|
| 100 |
-
results[ent_text] = {"type":"wiki","summary":
|
| 101 |
|
| 102 |
if not results:
|
| 103 |
return {"error":"no entities found"}
|
| 104 |
|
| 105 |
-
# 3) Translate
|
| 106 |
if output_lang != "eng":
|
| 107 |
for info in results.values():
|
| 108 |
if info["type"] == "wiki":
|
|
@@ -110,13 +118,11 @@ def get_context(text: str,
|
|
| 110 |
info["summary"], "eng", output_lang, auto_detect=False
|
| 111 |
)
|
| 112 |
elif info["type"] == "location":
|
| 113 |
-
for
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
translated.append({"name": tr})
|
| 119 |
-
info[poi_list] = translated
|
| 120 |
|
| 121 |
return results
|
| 122 |
|
|
@@ -133,9 +139,9 @@ iface = gr.Interface(
|
|
| 133 |
title="iVoice Translate + Context-Aware",
|
| 134 |
description=(
|
| 135 |
"1) Translate your text → English (if needed)\n"
|
| 136 |
-
"2)
|
| 137 |
"3) Geocode LOC → fetch nearby restaurants & attractions\n"
|
| 138 |
-
"4) Fetch Wikipedia summaries
|
| 139 |
"5) Translate **all** results → your target language"
|
| 140 |
)
|
| 141 |
).queue()
|
|
|
|
import wikipedia
import gradio as gr
import torch

from transformers import (
    SeamlessM4TProcessor,
    SeamlessM4TForTextToText,
    SeamlessM4TTokenizer,  # <<< import the tokenizer class explicitly (used below)
    pipeline as hf_pipeline
)

# ────────────────────
# 1) Load SeamlessM4T tokenizer (slow) and processor
MODEL_NAME = "facebook/hf-seamless-m4t-medium"
# Run on GPU when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the slow (non-Rust) tokenizer directly; no fast-tokenizer
# conversion is attempted this way.
tokenizer = SeamlessM4TTokenizer.from_pretrained(MODEL_NAME, use_fast=False)

# Pass the pre-built tokenizer into the processor so it reuses it
# instead of trying to build/convert one itself.
processor = SeamlessM4TProcessor.from_pretrained(
    MODEL_NAME,
    tokenizer=tokenizer
)

# Text-to-text translation model, moved to `device` and put in eval mode
# (inference only — disables dropout etc.).
m4t_model = SeamlessM4TForTextToText.from_pretrained(MODEL_NAME).to(device).eval()
| 32 |
def translate_m4t(text, src_iso3, tgt_iso3, auto_detect=False):
|
|
|
|
| 33 |
src = None if auto_detect else src_iso3
|
| 34 |
inputs = processor(text=text, src_lang=src, return_tensors="pt").to(device)
|
| 35 |
tokens = m4t_model.generate(**inputs, tgt_lang=tgt_iso3)
|
|
|
|
| 51 |
params={"q": place, "format": "json", "limit": 1},
|
| 52 |
headers={"User-Agent":"iVoiceContext/1.0"}
|
| 53 |
).json()
|
| 54 |
+
if not resp:
|
| 55 |
+
return None
|
| 56 |
return float(resp[0]["lat"]), float(resp[0]["lon"])
|
| 57 |
|
| 58 |
def fetch_osm(lat, lon, osm_filter, limit=5):
|
|
|
|
| 74 |
|
| 75 |
# ────────────────────
|
| 76 |
def get_context(text: str,
|
| 77 |
+
source_lang: str, # always ISO639-3, e.g. "eng"
|
| 78 |
+
output_lang: str, # always ISO639-3, e.g. "fra"
|
| 79 |
auto_detect: bool):
|
| 80 |
+
# 1) Ensure English text for NER
|
| 81 |
if auto_detect or source_lang != "eng":
|
| 82 |
en_text = translate_m4t(text, source_lang, "eng", auto_detect=auto_detect)
|
| 83 |
else:
|
| 84 |
en_text = text
|
| 85 |
|
| 86 |
+
# 2) Run NER
|
| 87 |
ner_out = ner(en_text)
|
| 88 |
ents = { ent["word"]: ent["entity_group"] for ent in ner_out }
|
| 89 |
|
|
|
|
| 95 |
results[ent_text] = {"type":"location","error":"could not geocode"}
|
| 96 |
else:
|
| 97 |
lat, lon = geo
|
|
|
|
|
|
|
| 98 |
results[ent_text] = {
|
| 99 |
"type": "location",
|
| 100 |
+
"restaurants": fetch_osm(lat, lon, '["amenity"="restaurant"]'),
|
| 101 |
+
"attractions": fetch_osm(lat, lon, '["tourism"="attraction"]'),
|
| 102 |
}
|
| 103 |
else:
|
|
|
|
| 104 |
try:
|
| 105 |
+
summ = wikipedia.summary(ent_text, sentences=2)
|
| 106 |
except Exception:
|
| 107 |
+
summ = "No summary available."
|
| 108 |
+
results[ent_text] = {"type":"wiki","summary": summ}
|
| 109 |
|
| 110 |
if not results:
|
| 111 |
return {"error":"no entities found"}
|
| 112 |
|
| 113 |
+
# 3) Translate all text fields → output_lang
|
| 114 |
if output_lang != "eng":
|
| 115 |
for info in results.values():
|
| 116 |
if info["type"] == "wiki":
|
|
|
|
| 118 |
info["summary"], "eng", output_lang, auto_detect=False
|
| 119 |
)
|
| 120 |
elif info["type"] == "location":
|
| 121 |
+
for key in ("restaurants","attractions"):
|
| 122 |
+
info[key] = [
|
| 123 |
+
{"name": translate_m4t(item["name"], "eng", output_lang)}
|
| 124 |
+
for item in info[key]
|
| 125 |
+
]
|
|
|
|
|
|
|
| 126 |
|
| 127 |
return results
|
| 128 |
|
|
|
|
| 139 |
title="iVoice Translate + Context-Aware",
|
| 140 |
description=(
|
| 141 |
"1) Translate your text → English (if needed)\n"
|
| 142 |
+
"2) Extract LOC/PERSON/ORG via BERT-NER\n"
|
| 143 |
"3) Geocode LOC → fetch nearby restaurants & attractions\n"
|
| 144 |
+
"4) Fetch Wikipedia summaries\n"
|
| 145 |
"5) Translate **all** results → your target language"
|
| 146 |
)
|
| 147 |
).queue()
|