# app.py
import os
import requests
import wikipedia
import gradio as gr
import torch
from functools import lru_cache
from concurrent.futures import ThreadPoolExecutor
from typing import List
from transformers import (
    SeamlessM4TTokenizer,
    SeamlessM4TProcessor,
    SeamlessM4TForTextToText,
    pipeline as hf_pipeline
)
# ── 1) Model setup ─────────────────────────────────────────────────────────────
MODEL = "facebook/hf-seamless-m4t-medium"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = SeamlessM4TTokenizer.from_pretrained(MODEL, use_fast=False)
processor = SeamlessM4TProcessor.from_pretrained(MODEL, tokenizer=tokenizer)
m4t_model = SeamlessM4TForTextToText.from_pretrained(MODEL).to(device)
if device == "cuda":
    m4t_model = m4t_model.half()  # FP16 for faster inference on GPU
m4t_model.eval()
def translate_m4t(text: str, src_iso3: str, tgt_iso3: str, auto_detect=False) -> str:
    src = None if auto_detect else src_iso3
    inputs = processor(text=text, src_lang=src, return_tensors="pt").to(device)
    tokens = m4t_model.generate(**inputs, tgt_lang=tgt_iso3)
    return processor.decode(tokens[0].tolist(), skip_special_tokens=True)
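# Usage sketch (illustrative, not part of the original file): SeamlessM4T expects
# 3-letter language codes, so translate_m4t("Where is the Eiffel Tower?", "eng", "fra")
# should return a French rendering of the sentence; exact wording depends on the model.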
def translate_m4t_batch(
    texts: List[str], src_iso3: str, tgt_iso3: str, auto_detect=False
) -> List[str]:
    src = None if auto_detect else src_iso3
    inputs = processor(
        text=texts, src_lang=src, return_tensors="pt", padding=True
    ).to(device)
    tokens = m4t_model.generate(
        **inputs,
        tgt_lang=tgt_iso3,
        max_new_tokens=60,
        num_beams=1
    )
    return processor.batch_decode(tokens, skip_special_tokens=True)
# ── 2) NER pipeline (updated for deprecation) ──────────────────────────────────
ner = hf_pipeline(
    "ner",
    model="dslim/bert-base-NER-uncased",
    aggregation_strategy="simple"
)
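# With aggregation_strategy="simple" the pipeline returns grouped entities, e.g.
#   ner("Paris is the capital of France")
#   -> [{"entity_group": "LOC", "word": ..., "score": ..., "start": ..., "end": ...}, ...]
# process_entity() below relies on the "word" and "entity_group" keys.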
# ── 3) CACHING helpers ─────────────────────────────────────────────────────────
@lru_cache(maxsize=None)  # memoize geocoding so repeated place names hit Nominatim once
def geocode_cache(place: str):
    r = requests.get(
        "https://nominatim.openstreetmap.org/search",
        params={"q": place, "format": "json", "limit": 1},
        headers={"User-Agent": "iVoiceContext/1.0"}
    ).json()
    if not r:
        return None
    return {"lat": float(r[0]["lat"]), "lon": float(r[0]["lon"])}
# Not wrapped in lru_cache: the returned dicts are mutated in place by the
# translation step in get_context(), which would corrupt shared cache entries.
def fetch_osm_cache(lat: float, lon: float, osm_filter: str, limit: int = 5):
    # Overpass QL: nodes and ways matching `osm_filter` within 1000 m of (lat, lon),
    # returning at most `limit` elements with their center coordinates.
    payload = f"""
    [out:json][timeout:25];
    (
      node{osm_filter}(around:1000,{lat},{lon});
      way{osm_filter}(around:1000,{lat},{lon});
    );
    out center {limit};
    """
    resp = requests.post(
        "https://overpass-api.de/api/interpreter",
        data={"data": payload}
    )
    elems = resp.json().get("elements", [])
    return [
        {"name": e["tags"]["name"]}
        for e in elems
        if e.get("tags", {}).get("name")
    ]
@lru_cache(maxsize=None)
def wiki_summary_cache(name: str) -> str:
    try:
        return wikipedia.summary(name, sentences=2)
    except Exception:  # missing page, disambiguation error, network failure, etc.
        return "No summary available."
# ── 4) Per-entity worker ───────────────────────────────────────────────────────
def process_entity(ent) -> dict:
    w = ent["word"]
    lbl = ent["entity_group"]

    if lbl == "LOC":
        geo = geocode_cache(w)
        if not geo:
            return {
                "text": w,
                "label": lbl,
                "type": "location",
                "error": "could not geocode"
            }
        restaurants = fetch_osm_cache(geo["lat"], geo["lon"], '["amenity"="restaurant"]')
        attractions = fetch_osm_cache(geo["lat"], geo["lon"], '["tourism"="attraction"]')
        return {
            "text": w,
            "label": lbl,
            "type": "location",
            "geo": geo,
            "restaurants": restaurants,
            "attractions": attractions
        }

    # PERSON / ORG / MISC → Wikipedia
    summary = wiki_summary_cache(w)
    return {"text": w, "label": lbl, "type": "wiki", "summary": summary}
# ── 5) Main function ───────────────────────────────────────────────────────────
def get_context(
    text: str,
    source_lang: str,
    output_lang: str,
    auto_detect: bool
):
    # a) Ensure English for NER
    if auto_detect or source_lang != "eng":
        en = translate_m4t(text, source_lang, "eng", auto_detect=auto_detect)
    else:
        en = text

    # b) Run NER + dedupe
    ner_out = ner(en)
    seen = set()
    unique_ents = []
    for ent in ner_out:
        w = ent["word"]
        if w in seen:
            continue
        seen.add(w)
        unique_ents.append(ent)

    # c) Parallel I/O
    entities = []
    with ThreadPoolExecutor(max_workers=8) as exe:
        futures = [exe.submit(process_entity, ent) for ent in unique_ents]
        for fut in futures:
            entities.append(fut.result())

    # d) Batch-translate non-English fields
    if output_lang != "eng":
        to_translate = []
        translations_info = []
        for i, e in enumerate(entities):
            if e["type"] == "wiki":
                translations_info.append(("summary", i))
                to_translate.append(e["summary"])
            elif e["type"] == "location":
                # .get() keeps entities that failed to geocode (and so have no
                # restaurants/attractions keys) from raising a KeyError here
                for j, r in enumerate(e.get("restaurants", [])):
                    translations_info.append(("restaurant", i, j))
                    to_translate.append(r["name"])
                for j, a in enumerate(e.get("attractions", [])):
                    translations_info.append(("attraction", i, j))
                    to_translate.append(a["name"])

        if to_translate:  # skip the batch call when there is nothing to translate
            translated = translate_m4t_batch(to_translate, "eng", output_lang)
            for txt, info in zip(translated, translations_info):
                kind = info[0]
                if kind == "summary":
                    _, ei = info
                    entities[ei]["summary"] = txt
                elif kind == "restaurant":
                    _, ei, ri = info
                    entities[ei]["restaurants"][ri]["name"] = txt
                elif kind == "attraction":
                    _, ei, ai = info
                    entities[ei]["attractions"][ai]["name"] = txt

    return {"entities": entities}
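# For reference, the returned structure looks like this (values illustrative):
#   {"entities": [
#       {"text": "paris", "label": "LOC", "type": "location", "geo": {"lat": ..., "lon": ...},
#        "restaurants": [{"name": ...}], "attractions": [{"name": ...}]},
#       {"text": "macron", "label": "PER", "type": "wiki", "summary": "..."}
#   ]}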
# ── 6) Gradio interface ────────────────────────────────────────────────────────
iface = gr.Interface(
    fn=get_context,
    inputs=[
        gr.Textbox(lines=3, placeholder="Enter text…"),
        gr.Textbox(label="Source Language (ISO 639-3)"),
        gr.Textbox(label="Target Language (ISO 639-3)"),
        gr.Checkbox(label="Auto-detect source language")
    ],
    outputs="json",
    title="iVoice Context-Aware",
    description="Returns only the detected entities and their related info."
).queue()  # ← removed unsupported kwargs
if __name__ == "__main__":
    iface.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", 7860)),
        share=True
    )
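# A hedged usage sketch (not part of the app): once the Space is running, it can be
# called programmatically with the gradio_client package, assuming gr.Interface's
# default "/predict" endpoint:
#
#   from gradio_client import Client
#   client = Client("http://localhost:7860")  # or the public Space URL
#   result = client.predict(
#       "I visited the Louvre in Paris.",  # text
#       "eng",                             # source language (ISO 639-3)
#       "fra",                             # target language (ISO 639-3)
#       False,                             # auto-detect source language
#       api_name="/predict",
#   )
#   print(result)  # {"entities": [...]}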