Spaces:
Running
Running
| """ | |
| Agent 2: Draft Generator | |
| ------------------------ | |
| Takes the structured context from Agent 1 and rewrites the | |
| original text in natural language while keeping 100% of the | |
| factual content intact. | |
| Uses Mistral-7B-Instruct via HF Inference API. | |
| """ | |
| import os | |
| import logging | |
| from huggingface_hub import InferenceClient | |
| logger = logging.getLogger(__name__) | |
class DraftGenerator:
    """Second link — produce a coherent, natural-sounding draft.

    Wraps a Hugging Face ``InferenceClient`` and asks an instruction-tuned
    model to rewrite text in a given tone for a given audience. Whenever
    generation fails or yields nothing usable, the original text is returned
    so the pipeline never loses content.
    """

    def __init__(self, hf_token=None):
        """Create the generator.

        Parameters
        ----------
        hf_token : str, optional
            Hugging Face API token. When omitted (or empty), the ``HF_TOKEN``
            environment variable is used; if that is unset too,
            ``self.token`` is ``""``.
        """
        self.token = hf_token or os.getenv("HF_TOKEN", "")
        # An empty string is not a meaningful token; pass None instead so
        # huggingface_hub can apply its own credential lookup.
        self.client = InferenceClient(token=self.token or None)
        self.model = "mistralai/Mistral-7B-Instruct-v0.3"

    # ------------------------------------------------------------------
    # public api
    # ------------------------------------------------------------------
    def generate(self, context: dict) -> str:
        """Rewrite ``context['original_text']`` into a natural-language draft.

        Parameters
        ----------
        context : dict
            The output of SemanticAnalyzer.analyze() — expected to contain
            'original_text' and 'analysis' keys. A missing/None
            'original_text' is treated as "", and a missing/None 'analysis'
            as an empty dict (so tone/audience/topic take their defaults).

        Returns
        -------
        str
            The rewritten draft text. The original text is returned unchanged
            when it is blank, when the inference call (or cleanup) raises,
            or when the model output is empty after cleanup.
        """
        original = context.get("original_text", "") or ""
        # ``analysis`` may be present but None; treat that like a missing key.
        analysis = context.get("analysis") or {}
        # ``or`` (rather than a .get default) also covers keys present with
        # None/"" values, which would otherwise leak "None" into the prompt.
        tone = analysis.get("tone") or "neutral"
        audience = analysis.get("target_audience") or "general audience"
        topic = analysis.get("core_topic") or "the given topic"

        # Nothing to rewrite: skip the network round-trip entirely.
        if not original.strip():
            return original

        logger.info("draft generator: rewriting %d chars (tone=%s)", len(original), tone)
        prompt = self._build_prompt(original, tone, audience, topic)
        try:
            raw = self.client.text_generation(
                prompt,
                model=self.model,
                max_new_tokens=1024,
                temperature=0.6,  # moderate creativity
                top_p=0.9,
            )
            draft = self._cleanup(raw)
        except Exception as exc:
            # Deliberate best-effort: a failed draft must not break the
            # pipeline, so fall back to the untouched input.
            logger.error("draft generation failed: %s — returning original", exc)
            return original

        if not draft:
            # An empty rewrite would silently drop all factual content.
            logger.warning("draft generator: empty model output — returning original")
            return original
        return draft

    # ------------------------------------------------------------------
    # internals
    # ------------------------------------------------------------------
    def _build_prompt(self, text, tone, audience, topic):
        """Build the Mistral ``[INST] ... [/INST]`` rewrite prompt.

        Parameters
        ----------
        text : str
            The original text to rewrite (embedded in double quotes).
        tone, audience, topic : str
            Style hints taken from the semantic analysis.

        Returns
        -------
        str
            The full prompt string passed to ``text_generation``.
        """
        return (
            "[INST] You are a skilled writer. Rewrite the text below in clear, "
            "natural language. Follow these rules strictly:\n\n"
            "1. Preserve ALL factual content — do not add or remove information.\n"
            "2. Keep the same overall structure and flow.\n"
            f"3. Match the tone: {tone}\n"
            f"4. Write for this audience: {audience}\n"
            f"5. The core topic is: {topic}\n"
            "6. Use natural phrasing but you can still sound polished at this stage.\n"
            "7. Return ONLY the rewritten text, nothing else.\n\n"
            f"Original text:\n\"{text}\"\n\n"
            "Rewritten version: [/INST]"
        )

    # staticmethod: called as ``self._cleanup(...)`` but uses no instance
    # state. Without the decorator that call raised TypeError (raw bound
    # to self, one positional argument too many).
    @staticmethod
    def _cleanup(raw: str) -> str:
        """Strip surrounding whitespace, markdown code fences and wrapping quotes.

        Parameters
        ----------
        raw : str
            Raw model output (None is treated as "").

        Returns
        -------
        str
            The cleaned text; may be "" if the model produced nothing usable.
        """
        text = (raw or "").strip()
        # remove markdown code fences if the model wrapped its answer
        if text.startswith("```"):
            text = "\n".join(
                line for line in text.split("\n") if not line.strip().startswith("```")
            ).strip()
        # strip one pair of surrounding double quotes: the prompt quotes the
        # original text, so models sometimes echo that framing back. The
        # length check keeps a lone '"' from being turned into "".
        if len(text) >= 2 and text.startswith('"') and text.endswith('"'):
            text = text[1:-1].strip()
        return text
# quick manual smoke test: analyze a sample passage, then print its draft
if __name__ == "__main__":
    from semantic_analyzer import SemanticAnalyzer

    analyzer = SemanticAnalyzer()
    generator = DraftGenerator()
    demo_text = " ".join(
        (
            "The rapid advancement of artificial intelligence presents both",
            "opportunities and challenges for modern society. It is imperative",
            "that we consider the ethical implications of these technologies.",
        )
    )
    analysis_ctx = analyzer.analyze(demo_text)
    rewritten = generator.generate(analysis_ctx)
    print("=== DRAFT ===", rewritten, sep="\n")