lingua_env / server /lingua_env_environment.py
ykaitao's picture
Upload folder using huggingface_hub
491cae8 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Lingua Env Environment Implementation.
A simple test environment that echoes back messages sent to it.
Perfect for testing HTTP server infrastructure.
"""
import json
import random
import uuid
from pathlib import Path
import urllib.request
from openenv.core.env_server import Environment
from lingua_env.models import LinguaAction, LinguaObservation, LinguaState
data_path = Path("/tmp/ambiguous_examples.jsonl")
DATA_SOURCE = "https://huggingface.co/datasets/ykaitao/lingua_env_data/resolve/main/ambiguous_examples.jsonl"
if not data_path.exists():
print(f"Downloading dataset from {DATA_SOURCE}")
urllib.request.urlretrieve(DATA_SOURCE, data_path)
PYMORPHY_FEATURE_MAP = {
"nomn": ("Case", "Nom"),
"gent": ("Case", "Gen"),
"datv": ("Case", "Dat"),
"accs": ("Case", "Acc"),
"ablt": ("Case", "Ins"),
"loct": ("Case", "Loc"),
"loc2": ("Case", "Loc"),
"sing": ("Number", "Sing"),
"plur": ("Number", "Plur"),
"masc": ("Gender", "Masc"),
"femn": ("Gender", "Fem"),
"neut": ("Gender", "Neut"),
"anim": ("Animacy", "Anim"),
"inan": ("Animacy", "Inan"),
"past": ("Tense", "Past"),
"pres": ("Tense", "Pres"),
"futr": ("Tense", "Fut"),
"1per": ("Person", "1"),
"2per": ("Person", "2"),
"3per": ("Person", "3"),
"COMP": ("Degree", "Cmp"),
"Supr": ("Degree", "Sup"),
}
POS_MAP = {
"NOUN": ["NOUN"],
"ADJ": ["ADJF", "ADJS"],
"VERB": ["VERB", "INFN", "GRND"],
"AUX": ["VERB", "INFN"],
"ADV": ["ADVB", "COMP", "PRED"],
"PRON": ["NPRO"],
"DET": ["ADJF", "NPRO"],
"NUM": ["NUMR"],
"ADP": ["PREP"],
"CCONJ": ["CONJ"],
"SCONJ": ["CONJ"],
"PART": ["PRCL"],
"INTJ": ["INTJ"],
}
def pymorphy_tag_to_ud_feats(tag_str: str) -> dict[str, str]:
feats = {}
for tok in tag_str.replace(",", " ").split():
if tok in PYMORPHY_FEATURE_MAP:
k, v = PYMORPHY_FEATURE_MAP[tok]
feats[k] = v
return feats
def parse_pymorphy_pos(tag_str: str) -> str:
return tag_str.split()[0].split(",")[0]
def score_candidate(candidate: dict, gold: dict) -> dict:
cand_lemma = candidate["lemma"]
cand_tag = candidate["tag"]
lemma_match = cand_lemma == gold["lemma"]
pymorphy_pos = parse_pymorphy_pos(cand_tag)
allowed_pos = POS_MAP.get(gold["upos"], [gold["upos"]])
pos_match = pymorphy_pos in allowed_pos
cand_feats = pymorphy_tag_to_ud_feats(cand_tag)
gold_feats = gold.get("feats", {})
reward = 0
if lemma_match:
reward += 1
if pos_match:
reward += 1
feature_matches = {}
for feat_name, gold_value in gold_feats.items():
cand_value = cand_feats.get(feat_name)
match = cand_value == gold_value
feature_matches[feat_name] = {
"candidate": cand_value,
"gold": gold_value,
"match": match,
}
if match:
reward += 1
return {
"reward": reward,
"lemma_match": lemma_match,
"pos_match": pos_match,
"candidate_feats": cand_feats,
"gold_feats": gold_feats,
"feature_matches": feature_matches,
}
class LinguaEnvironment(Environment):
def __init__(self):
super().__init__()
with open(data_path, "r", encoding="utf-8") as f:
self.examples = [json.loads(line) for line in f]
self.current = None
self._state = LinguaState(gold={})
def reset(self) -> LinguaObservation:
self.current = random.choice(self.examples)
self._state = LinguaState(
episode_id=str(uuid.uuid4()),
step_count=0,
gold=self.current["gold"],
)
return LinguaObservation(
sentence=list(self.current["sentence"]),
target_index=int(self.current["target_index"]),
word=str(self.current["word"]),
candidate_parses=list(self.current["candidate_parses"]),
)
def step(self, action: LinguaAction) -> LinguaObservation:
if self.current is None:
raise RuntimeError("Environment must be reset before step().")
choice = int(action.choice)
if choice < 0 or choice >= len(self.current["candidate_parses"]):
raise ValueError(f"Invalid choice index: {choice}")
self._state.step_count += 1
chosen = self.current["candidate_parses"][choice]
details = score_candidate(chosen, self.current["gold"])
self._state.gold = self.current["gold"]
obs = LinguaObservation(
sentence=list(self.current["sentence"]),
target_index=int(self.current["target_index"]),
word=str(self.current["word"]),
candidate_parses=list(self.current["candidate_parses"]),
)
obs.reward = float(details["reward"])
obs.done = True
obs.metadata = {
"chosen": chosen,
"gold": self.current["gold"],
"reward_breakdown": details,
}
return obs
@property
def state(self) -> LinguaState:
return self._state