| from typing import Dict, List, Union, Optional |
| from llms import LLM |
| import json |
| import re |
|
|
def pos_tagging(
    text: str,
    model: str = "en_core_web_sm",
    use_llm: bool = False,
    custom_instructions: str = ""
) -> Dict[str, List[Union[str, List[str]]]]:
    """Tag each token in *text* with its part of speech.

    Dispatches to an LLM-backed tagger when ``use_llm`` is set, otherwise
    to a traditional (spaCy) pipeline.

    Args:
        text: Input text; whitespace-only input short-circuits to empty
            results without touching either backend.
        model: Model identifier -- a spaCy pipeline name on the
            traditional path, or an LLM name (e.g. 'gpt-4') on the LLM
            path.
        use_llm: Select the slower but more flexible LLM-based tagger.
        custom_instructions: Extra prompt text for the LLM tagger;
            ignored on the traditional path.

    Returns:
        Mapping with parallel 'tokens' and 'tags' lists.
    """
    # Guard clause: nothing to tag.
    if not text.strip():
        return {"tokens": [], "tags": []}

    if use_llm:
        return _pos_tagging_with_llm(text, model, custom_instructions)
    return _pos_tagging_traditional(text, model)
|
|
| def _extract_json_array(text: str) -> str: |
| """Extract JSON array from text, handling various formats.""" |
| import re |
| |
| |
| json_match = re.search(r'\[\s*\{.*\}\s*\]', text, re.DOTALL) |
| if json_match: |
| return json_match.group(0) |
| |
| |
| start = text.find('[') |
| end = text.rfind(']') |
| if start >= 0 and end > start: |
| return text[start:end+1] |
| |
| return text |
|
|
def _pos_tagging_with_llm(
    text: str,
    model_name: str,
    custom_instructions: str = ""
) -> Dict[str, List[str]]:
    """Use LLM for more accurate and flexible POS tagging.

    Prompts the model for a JSON array of {"token": ..., "tag": ...}
    objects using Universal Dependencies tags, then extracts, parses and
    validates the response. On any failure (empty response, no JSON
    array, bad JSON, or no usable token/tag pairs) it falls back to
    _pos_tagging_traditional with the default spaCy model.

    Args:
        text: Raw text to tag; embedded verbatim (quoted) in the prompt.
        model_name: Model identifier forwarded to the LLM wrapper
            (assumed accepted by llms.LLM -- confirm there).
        custom_instructions: Optional text prepended to the base prompt.

    Returns:
        Dict with parallel 'tokens' and 'tags' lists.
    """
    # Base prompt: fixed tag inventory plus a worked example of the
    # expected JSON shape. NOTE(review): "CONJ" here -- current UD uses
    # "CCONJ"; confirm which tag set downstream consumers expect.
    prompt = """Analyze the following text and provide Part-of-Speech (POS) tags for each token.
Return the result as a JSON array of objects with 'token' and 'tag' keys.

Use standard Universal Dependencies POS tags:
- ADJ: adjective
- ADP: adposition
- ADV: adverb
- AUX: auxiliary verb
- CONJ: coordinating conjunction
- DET: determiner
- INTJ: interjection
- NOUN: noun
- NUM: numeral
- PART: particle
- PRON: pronoun
- PROPN: proper noun
- PUNCT: punctuation
- SCONJ: subordinating conjunction
- SYM: symbol
- VERB: verb
- X: other

Example output format:
[
  {"token": "Hello", "tag": "INTJ"},
  {"token": "world", "tag": "NOUN"},
  {"token": ".", "tag": "PUNCT"}
]

Text to analyze:
"""

    # Custom instructions go before the base prompt so they can steer it.
    if custom_instructions:
        prompt = f"{custom_instructions}\n\n{prompt}"

    prompt += f'"{text}"'

    try:
        # Low temperature for deterministic tagging; 2000 tokens should
        # cover the JSON for typical inputs (longer texts may truncate).
        llm = LLM(model=model_name, temperature=0.1, max_tokens=2000)

        # NOTE(review): debug print of the raw response; consider logging.
        response = llm.generate(prompt)
        print(f"LLM Raw Response: {response[:500]}...")

        if not response.strip():
            raise ValueError("Empty response from LLM")

        # Pull the JSON array out of any surrounding prose/markdown.
        json_str = _extract_json_array(response)
        if not json_str:
            raise ValueError("No JSON array found in response")

        # Parse, with a one-shot repair pass for common LLM mistakes.
        try:
            pos_tags = json.loads(json_str)
        except json.JSONDecodeError as e:
            # Repair heuristics: single quotes -> double quotes, then
            # quote bare object keys. NOTE(review): the blanket quote
            # swap can corrupt tokens containing apostrophes, and the
            # key regex can misfire on colons inside values (e.g.
            # times like 10:30) -- verify against real responses.
            json_str = json_str.replace("'", '"')
            json_str = re.sub(r'(\w+):', r'"\1":', json_str)
            pos_tags = json.loads(json_str)

        # Validate overall shape before walking the items.
        if not isinstance(pos_tags, list):
            raise ValueError(f"Expected list, got {type(pos_tags).__name__}")

        tokens = []
        tags = []

        # Keep only well-formed entries; skip non-dicts and entries with
        # a missing/empty token or tag.
        for item in pos_tags:
            if not isinstance(item, dict):
                continue

            token = item.get('token', '')
            tag = item.get('tag', '')

            if token and tag:
                tokens.append(str(token).strip())
                tags.append(str(tag).strip())

        if not tokens or not tags:
            raise ValueError("No valid tokens and tags found in response")

        return {
            'tokens': tokens,
            'tags': tags
        }

    except Exception as e:
        # Broad catch is deliberate: any LLM/parse failure degrades to
        # the traditional tagger rather than raising to the caller.
        print(f"Error in LLM POS tagging: {str(e)}")
        print(f"Falling back to traditional POS tagging...")
        return _pos_tagging_traditional(text, "en_core_web_sm")
|
|
| def _pos_tagging_traditional(text: str, model: str) -> Dict[str, List[str]]: |
| """Use traditional POS tagging models.""" |
| try: |
| import spacy |
| |
| |
| try: |
| nlp = spacy.load(model) |
| except OSError: |
| |
| nlp = spacy.load("en_core_web_sm") |
| |
| |
| doc = nlp(text) |
| |
| |
| tokens = [] |
| tags = [] |
| for token in doc: |
| tokens.append(token.text) |
| tags.append(token.pos_) |
| |
| return { |
| 'tokens': tokens, |
| 'tags': tags |
| } |
| |
| except Exception as e: |
| print(f"Error in traditional POS tagging: {str(e)}") |
| return {"tokens": [], "tags": []} |
|
|