File size: 4,048 Bytes
d1abcca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
Agent 2: Draft Generator
------------------------
Takes the structured context from Agent 1 and rewrites the
original text in natural language while keeping 100% of the
factual content intact.

Uses Mistral-7B-Instruct (v0.3) via the Hugging Face Inference API.
"""

import os
import logging

from huggingface_hub import InferenceClient

logger = logging.getLogger(__name__)


class DraftGenerator:
    """Second link — produce a coherent, natural-sounding draft.

    Wraps a Hugging Face ``InferenceClient`` and asks an instruction-tuned
    model to rewrite the original text in the tone / for the audience reported
    by the upstream analysis, without adding or dropping facts.

    Parameters
    ----------
    hf_token : str, optional
        Hugging Face API token. Falls back to the ``HF_TOKEN`` environment
        variable; when neither is set, ``None`` is handed to the client so it
        can apply its own credential lookup rather than receiving an empty
        string as an explicit token.
    model : str, optional
        Model id passed to ``text_generation``. Defaults to ``DEFAULT_MODEL``.
    """

    DEFAULT_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"

    def __init__(self, hf_token=None, model: str = DEFAULT_MODEL):
        # "" is not "no token": normalise every unset/empty case to None so the
        # client does not forward an empty credential.
        self.token = hf_token or os.getenv("HF_TOKEN") or None
        self.client = InferenceClient(token=self.token)
        self.model = model

    # ------------------------------------------------------------------
    # public api
    # ------------------------------------------------------------------

    def generate(self, context: dict) -> str:
        """
        Rewrite ``context["original_text"]`` into a natural-language draft.

        Parameters
        ----------
        context : dict
            The output of SemanticAnalyzer.analyze(). Expected keys are
            'original_text' (str) and 'analysis' (dict with optional 'tone',
            'target_audience' and 'core_topic'). Missing, ``None`` or empty
            values fall back to neutral defaults.

        Returns
        -------
        str  —  The rewritten draft text. The original text is returned
        unchanged when it is blank, when the API call fails, or when the
        model output is empty after cleanup.
        """
        original = context.get("original_text") or ""
        # Nothing to rewrite: skip the API call (and the risk of the model
        # inventing content) for blank input.
        if not original.strip():
            return original

        analysis = context.get("analysis") or {}
        tone = analysis.get("tone") or "neutral"
        audience = analysis.get("target_audience") or "general audience"
        topic = analysis.get("core_topic") or "the given topic"

        logger.info("draft generator: rewriting %d chars (tone=%s)", len(original), tone)

        prompt = self._build_prompt(original, tone, audience, topic)

        try:
            raw = self.client.text_generation(
                prompt,
                model=self.model,
                max_new_tokens=1024,
                temperature=0.6,       # moderate creativity
                top_p=0.9,
            )
        except Exception as exc:
            # Best-effort pipeline stage: a failed call must not break the
            # chain, so degrade to the untouched input.
            logger.error("draft generation failed: %s — returning original", exc)
            return original  # safe fallback

        draft = self._cleanup(raw)
        if not draft:
            # An empty rewrite would silently drop all content downstream.
            logger.warning("draft generator: empty model output — returning original")
            return original
        return draft

    # ------------------------------------------------------------------
    # internals
    # ------------------------------------------------------------------

    def _build_prompt(self, text: str, tone: str, audience: str, topic: str) -> str:
        """Assemble the ``[INST] ... [/INST]`` rewrite instruction for the model."""
        return (
            "[INST] You are a skilled writer. Rewrite the text below in clear, "
            "natural language. Follow these rules strictly:\n\n"
            "1. Preserve ALL factual content — do not add or remove information.\n"
            "2. Keep the same overall structure and flow.\n"
            f"3. Match the tone: {tone}\n"
            f"4. Write for this audience: {audience}\n"
            f"5. The core topic is: {topic}\n"
            "6. Use natural phrasing but you can still sound polished at this stage.\n"
            "7. Return ONLY the rewritten text, nothing else.\n\n"
            f"Original text:\n\"{text}\"\n\n"
            "Rewritten version: [/INST]"
        )

    @staticmethod
    def _cleanup(raw: str) -> str:
        """Strip stray quotes, whitespace and markdown fences from model output.

        ``None`` is treated as empty output and yields ``""``.
        """
        text = (raw or "").strip()
        # remove markdown code fences if the model wrapped its answer in them
        if text.startswith("```"):
            kept = [line for line in text.split("\n") if not line.strip().startswith("```")]
            text = "\n".join(kept).strip()
        # strip one pair of surrounding quotes (the prompt quotes the input, so
        # the model may echo them); a lone '"' is not a pair and is kept
        if len(text) >= 2 and text.startswith('"') and text.endswith('"'):
            text = text[1:-1].strip()
        return text


# Manual smoke test: running this module directly analyses a sample paragraph
# and prints the generated draft (makes live calls to the HF Inference API).
if __name__ == "__main__":
    from semantic_analyzer import SemanticAnalyzer

    demo_text = (
        "The rapid advancement of artificial intelligence presents both "
        "opportunities and challenges for modern society. It is imperative "
        "that we consider the ethical implications of these technologies."
    )

    analyzer = SemanticAnalyzer()
    generator = DraftGenerator()

    # analyse first, then rewrite; header and draft go out on separate lines
    print("=== DRAFT ===", generator.generate(analyzer.analyze(demo_text)), sep="\n")