"""
Minimal end-to-end example for sui-1-24b summarization.

Usage:
    # Summarize a file
    uv run example.py document.txt

    # Summarize inline text
    uv run example.py --text "Your long text here..."

    # With custom parameters
    uv run example.py document.txt --words 300 --tags 8 --language en
"""

import argparse
import hashlib
import json
import re
import sys
from pathlib import Path


def main():
    """Run the summarization example from the command line.

    Steps, in order:

    1. Parse the CLI arguments and obtain the input text, either from
       ``--text`` or from the positional file path.
    2. Split the text into sentences with the spaCy pipeline for the chosen
       language and wrap every non-empty sentence in an XML tag named after
       an 8-character MD5-derived identifier.
    3. Build the summarization prompt around the tagged text and run it
       through the model with vLLM.
    4. Print the raw model output (``--raw``), or parse the JSON object in
       the output and print the summary followed by the cited sentences.

    Exits through ``parser.error`` (status 2) when no input is supplied or
    the input file cannot be read, and with status 1 when the spaCy model
    for the requested language cannot be loaded. If the model output does
    not contain parseable JSON, the raw output is printed instead.
    """
    parser = argparse.ArgumentParser(
        description="Summarize text using sui-1-24b with source grounding",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument("input", nargs="?", help="Input file path (or use --text)")
    parser.add_argument("--text", "-t", help="Input text directly")
    # %(default)s lets argparse print the real default, so the help text
    # cannot drift away from the value actually used.
    parser.add_argument(
        "--words", "-w", type=int, default=250,
        help="Target word count (default: %(default)s)",
    )
    parser.add_argument(
        "--tags", "-n", type=int, default=4,
        help="Number of XML tags to cite (default: %(default)s)",
    )
    parser.add_argument(
        "--language", "-l", default="en", choices=["en", "de", "es", "fr", "it"],
        help="Language (default: %(default)s)",
    )
    parser.add_argument("--model", "-m", default="ellamind/sui-1-24b", help="Model path or HF repo")
    parser.add_argument(
        "--tensor-parallel", "-tp", type=int, default=1,
        help="Tensor parallel size (default: %(default)s)",
    )
    parser.add_argument("--raw", action="store_true", help="Print raw JSON output instead of formatted")
    args = parser.parse_args()

    # --text takes precedence over the positional file when both are given.
    if args.text:
        text = args.text
    elif args.input:
        try:
            # Decode explicitly as UTF-8 instead of the platform's locale encoding.
            text = Path(args.input).read_text(encoding="utf-8")
        except OSError as exc:
            parser.error(f"Cannot read input file '{args.input}': {exc}")
    else:
        parser.error("Provide input file or --text")

    # Imported only after argument handling, so --help and usage errors do
    # not require these packages to be importable.
    import spacy
    from vllm import LLM, SamplingParams

    spacy_models = {
        "en": "en_core_web_sm",
        "de": "de_core_news_sm",
        "es": "es_core_news_sm",
        "fr": "fr_core_news_sm",
        "it": "it_core_news_sm",
    }
    model_name = spacy_models[args.language]
    try:
        nlp = spacy.load(model_name)
    except OSError:
        print(f"Error: spaCy model '{model_name}' not found.")
        print("For English, this should be bundled automatically.")
        print("For other languages, install the model first:")
        print(f" pip install https://github.com/explosion/spacy-models/releases/download/{model_name}-3.8.0/{model_name}-3.8.0-py3-none-any.whl")
        sys.exit(1)

    print("Tagging sentences...")
    doc = nlp(text)
    tag_mapping = {}
    tagged_parts = []
    for i, sent in enumerate(doc.sents):
        sentence = sent.text.strip()
        if not sentence:
            continue
        # The tag is derived from the sentence index plus its first 50
        # characters, so it stays the same for the same input text.
        tag = hashlib.md5(f"{i}_{sentence[:50]}".encode()).hexdigest()[:8]
        tag_mapping[tag] = sentence
        tagged_parts.append(f"<{tag}>{sentence}</{tag}>")
    tagged_text = "".join(tagged_parts)
    print(f"Tagged {len(tag_mapping)} sentences")

    language_names = {"en": "English", "de": "German", "es": "Spanish", "fr": "French", "it": "Italian"}
    prompt = f"""You are a professional summarizer, following all given instructions with the utmost care.

<text>
{tagged_text}
</text>

# Output Format
The output must be in JSON format with the following structure:
1. A "structure" string containing your thoughts about the content and structure of the summary
2. An "xml_tags" list containing the XML tag identifiers from the tagged text (e.g., "<a1b2c3d4>")
3. A "summary" string containing the actual summary with inline XML tag references

# Instructions
1. Start by thinking about and explaining the structure and content of your summary. Select {args.tags} XML tags from the tagged text that capture the most significant data and facts.
2. Begin with an executive summary introducing the title, author (if available), and key findings.
3. Structure the summary in coherent paragraphs. Every paragraph should contain at least one XML tag reference.
4. Reference XML tags inline in square brackets (e.g., [<a1b2c3d4>]) immediately after the statement they support.
5. Each XML tag must appear exactly once in the summary.
6. Avoid a concluding paragraph that merely restates points.
7. Do not use bullet points or headings unless explicitly requested.

Parameters:
- Word count (excl. XML tags): {args.words}
- Number of XML tags: {args.tags}
- Language: {language_names[args.language]}
"""

    print(f"Loading model: {args.model}")
    llm = LLM(
        model=args.model,
        tensor_parallel_size=args.tensor_parallel,
        dtype="bfloat16",
        tokenizer_mode="mistral",
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 0},
    )

    print("Generating summary...")
    sampling_params = SamplingParams(max_tokens=4096, temperature=0.0)
    outputs = llm.chat([[{"role": "user", "content": prompt}]], sampling_params)
    result = outputs[0].outputs[0].text

    if args.raw:
        print(result)
        return

    # Greedy match from the first "{" to the last "}" so that any text the
    # model emits around the JSON object is ignored.
    json_match = re.search(r'\{[\s\S]*\}', result)
    if not json_match:
        print("Could not parse JSON response:")
        print(result)
        return

    # Only json.loads can raise JSONDecodeError, so only it sits in the try.
    try:
        data = json.loads(json_match.group())
    except json.JSONDecodeError as e:
        print(f"JSON parse error: {e}")
        print(result)
        return

    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60 + "\n")

    summary = data.get("summary", "")
    # Turn inline "[<a1b2c3d4>]" references into the shorter "[a1b2c3d4]".
    clean_summary = re.sub(r'\[<([a-f0-9]{8})>\]', r'[\1]', summary)
    print(clean_summary)

    print("\n" + "-" * 60)
    print("SOURCES")
    print("-" * 60)

    for tag in data.get("xml_tags", []):
        # Entries are accepted either as plain strings ("<a1b2c3d4>") or as
        # objects carrying the identifier under "xml_tag"; others are skipped.
        if isinstance(tag, str):
            clean_tag = tag.strip("<>")
        elif isinstance(tag, dict) and "xml_tag" in tag:
            clean_tag = tag["xml_tag"].strip("<>")
        else:
            continue
        source = tag_mapping.get(clean_tag, "Not found")
        if len(source) > 100:
            source = source[:97] + "..."
        print(f"[{clean_tag}] {source}")


if __name__ == "__main__":
    main()