Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import time | |
| import torch | |
| import gradio as gr | |
| from datetime import datetime, timedelta | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
# ----------------------------
# Config and defaults
# ----------------------------
# Dropdown label -> Hugging Face model id. Dict order is the dropdown order.
MODEL_OPTIONS = {
    "Phi-3.5 Mini Instruct (4B)": "microsoft/Phi-3.5-mini-instruct",
    "Phi-3.5 MoE Instruct (42B)": "microsoft/Phi-3.5-MoE-instruct",
    "Phi-3 Mini 4K Instruct (4B)": "microsoft/Phi-3-mini-4k-instruct",
    "Phi-3 Mini 128K Instruct (4B)": "microsoft/Phi-3-mini-128k-instruct",
}

# Sample prompts offered through gr.Examples in the UI.
EXAMPLES = [
    "Read this short passage and tell me the main idea in your own words.",
    "I’ll teach you a concept. Repeat it back to me in simple words: Solar panels turn sunlight into electricity.",
    "Here’s a new phrase: 'The sea is calm today.' Try saying it in Basque.",
    "I’ll give you a style: noir detective. Write one sentence about Gros in that style.",
    "Read a Shakespeare quote and tell me what you think it means.",
    "Read a Dickens passage and explain how it feels.",
    "Translate a short poem line into another language, then tell me what mood it carries.",
    "Summarize this text in two sentences, then say if it sounds optimistic or pessimistic.",
]

# Profile and starter rules used when the blocks file is absent or unreadable.
DEFAULT_PROFILE = dict(
    name="Learner",
    style=["concise", "reflective", "Basque context where relevant"],
    goals=["conversation-first learning", "daily language blocks", "CPU-only"],
)
DEFAULT_BLOCKS = [
    {"type": block_type, "rule": rule}
    for block_type, rule in (
        ("style", "Ask clarifying questions when uncertain."),
        ("vocab", "Use sensory detail + local place anchoring when writing creatively."),
        ("conversation", "Keep answers short and specific; avoid repeating conclusions."),
    )
]

# JSON file (path relative to the working directory) holding profile + blocks.
BLOCKS_FILE = "blocks.json"
# ----------------------------
# Persistence helpers
# ----------------------------
def load_blocks():
    """Load the user profile and language blocks from ``BLOCKS_FILE``.

    Returns:
        dict with ``"user_profile"`` (dict) and ``"language_blocks"`` (list).
        When the file is missing, unreadable, not valid JSON, or not a JSON
        object, fresh copies of ``DEFAULT_PROFILE`` / ``DEFAULT_BLOCKS`` are
        used. A valid object that lacks either key, or holds the wrong type
        under it, gets the default for that key only.
    """
    import copy
    import logging

    logger = logging.getLogger(__name__)

    data = None
    if os.path.exists(BLOCKS_FILE):
        try:
            with open(BLOCKS_FILE, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as err:
            # Best effort: keep the app usable, but make the problem visible
            # (ValueError covers JSONDecodeError and UnicodeDecodeError).
            logger.warning("Could not read %s (%s); using defaults.", BLOCKS_FILE, err)
            data = None
        else:
            if not isinstance(data, dict):
                logger.warning("%s is not a JSON object; using defaults.", BLOCKS_FILE)

    if not isinstance(data, dict):
        data = {}
    # Deep copies: callers such as add_block() append to the returned lists,
    # which would otherwise mutate the module-level defaults in place.
    if not isinstance(data.get("user_profile"), dict):
        data["user_profile"] = copy.deepcopy(DEFAULT_PROFILE)
    if not isinstance(data.get("language_blocks"), list):
        data["language_blocks"] = copy.deepcopy(DEFAULT_BLOCKS)
    return data
def save_blocks(data):
    """Write *data* to ``BLOCKS_FILE`` as indented UTF-8 JSON.

    The JSON is written to a sibling ``.tmp`` file first, then moved over
    ``BLOCKS_FILE`` with ``os.replace``. If serialization or the write fails
    part-way, the previous ``BLOCKS_FILE`` stays intact instead of being left
    truncated. That matters because ``load_blocks`` treats an unreadable file as
    "use the defaults", so the next save would discard every stored rule.

    Raises:
        Any error from opening/writing the file or from ``json.dump`` (e.g. a
        non-serializable value). The temporary file is removed before the
        error propagates.
    """
    tmp_path = f"{BLOCKS_FILE}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, BLOCKS_FILE)
    except BaseException:
        # Clean up the partial temp file, then re-raise the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # temp file may never have been created
        raise
def add_block(data, rule_text, block_type="conversation"):
    """Append a validated rule to ``data["language_blocks"]`` and persist it.

    The new entry records the stripped rule text, ``"validated": True`` and a
    ``"review_schedule"`` from ``schedule_reviews()``. After appending,
    ``save_blocks(data)`` writes the whole structure.

    Nothing is added or written when the rule text is ``None`` or blank. The
    same holds when a block with the same rule text (compared after stripping,
    regardless of its type) already exists. This keeps repeated
    "Reflect & Save" clicks from stacking identical rules into the prompt.

    Args:
        data: Dict as returned by ``load_blocks()``; mutated in place. A missing
            ``"language_blocks"`` key is created as an empty list.
        rule_text: Rule to add; surrounding whitespace is stripped.
        block_type: Label stored in the entry's ``"type"`` field.

    Returns:
        The same ``data`` object.
    """
    rule = (rule_text or "").strip()
    if not rule:
        return data

    blocks = data.setdefault("language_blocks", [])
    if any(isinstance(b, dict) and str(b.get("rule", "")).strip() == rule for b in blocks):
        return data

    blocks.append({
        "type": block_type,
        "rule": rule,
        "validated": True,
        "review_schedule": schedule_reviews(),
    })
    save_blocks(data)
    return data
def schedule_reviews(intervals=(1, 3, 7), today=None):
    """Return review dates as ``YYYY-MM-DD`` strings.

    Args:
        intervals: Day offsets added to *today*, in output order
            (default: 1, 3 and 7 days).
        today: Base ``datetime.date``; defaults to the current date in UTC.

    Returns:
        list[str], one ISO date per offset.
    """
    if today is None:
        from datetime import timezone

        # datetime.utcnow() is deprecated since Python 3.12; use an aware now.
        today = datetime.now(timezone.utc).date()
    return [str(today + timedelta(days=offset)) for offset in intervals]
# ----------------------------
# Model loading (CPU-only)
# ----------------------------
_loaded = {}  # model_id -> (tokenizer, model); dict order tracks recency (oldest first)


def load_model(model_id, max_cached=1):
    """Return ``(tokenizer, model)`` for *model_id*, loading it on first use.

    Both are fetched with ``trust_remote_code=True``, so custom code shipped in
    the model repository is executed. The model is loaded in float32 for CPU
    inference and switched to eval mode.

    Loaded pairs are cached in ``_loaded``. At most *max_cached* are kept, and
    the least recently used one is evicted before a new model is loaded. A
    ~4B-parameter model in float32 already needs on the order of 16 GB of RAM,
    so keeping every selected model resident quickly exhausts a CPU machine.

    Args:
        model_id: Hugging Face model id (a value from ``MODEL_OPTIONS``).
        max_cached: Maximum number of cached models (values below 1 act as 1).
            ``None`` disables eviction.

    Returns:
        tuple ``(tokenizer, model)``.
    """
    cached = _loaded.pop(model_id, None)
    if cached is not None:
        _loaded[model_id] = cached  # re-insert to mark as most recently used
        return cached

    if max_cached is not None:
        # Evict *before* loading so two large models never coexist in RAM.
        limit = max(1, int(max_cached))
        evicted = False
        while len(_loaded) >= limit:
            _loaded.pop(next(iter(_loaded)))
            evicted = True
        if evicted:
            import gc

            gc.collect()

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.float32,  # CPU friendly
    )
    model.eval()
    _loaded[model_id] = (tokenizer, model)
    return tokenizer, model
# ----------------------------
# Prompt construction
# ----------------------------
def format_blocks(blocks):
    """Render blocks as ``- [type] rule`` bullet lines for the system prompt.

    A dict block missing ``"type"`` renders as ``[rule]``, and a missing
    ``"rule"`` renders as empty text. Non-dict entries, such as bare strings
    from a hand-edited JSON array, render as ``- [rule] <text>`` instead of
    raising ``AttributeError``. ``None`` or an empty sequence gives ``""``.
    """
    def _line(block):
        if isinstance(block, dict):
            return f"- [{block.get('type', 'rule')}] {block.get('rule', '')}"
        return f"- [rule] {block}"

    return "\n".join(_line(b) for b in blocks or ())
# System prompt used by build_messages(); str.format fills the {style},
# {goals} and {blocks} placeholders.
SYSTEM_TEMPLATE = (
    "You are a conversation-first learning chatbot.\n"
    "Follow the user's style and goals, reinforce today's blocks, and confirm corrections.\n"
    "User style: {style}\n"
    "Goals: {goals}\n"
    "Active language blocks:\n"
    "{blocks}\n"
    "Guidelines:\n"
    "- Keep responses concise and specific.\n"
    "- Ask for clarification when needed.\n"
    "- Extract new patterns only when validated by the user.\n"
)
def build_messages(user_text, profile, blocks):
    """Build the chat-template messages: a system prompt plus one user turn.

    The system prompt is ``SYSTEM_TEMPLATE`` filled with the profile's style
    and goals (comma-joined) and ``format_blocks(blocks)``.

    Args:
        user_text: The user's message, used as the single ``"user"`` turn.
        profile: Profile dict. ``"style"``/``"goals"`` may be lists or single
            strings; ``None`` is treated as an empty profile.
        blocks: Language blocks passed to ``format_blocks``.

    Returns:
        list of ``{"role", "content"}`` dicts (system first, then user).
    """
    profile = profile or {}

    def _as_text(value):
        # A hand-edited blocks.json may hold a single string rather than a
        # list; ", ".join("concise") would otherwise yield "c, o, n, ...".
        if isinstance(value, str):
            return value
        return ", ".join(str(v) for v in value or [])

    system = SYSTEM_TEMPLATE.format(
        style=_as_text(profile.get("style", [])),
        goals=_as_text(profile.get("goals", [])),
        blocks=format_blocks(blocks),
    )
    return [
        {"role": "system", "content": system},
        {"role": "user", "content": user_text},
    ]
# ----------------------------
# Generate (with token/latency)
# ----------------------------
def chat(user_text, model_label, blocks_json, max_new_tokens=200):
    """Generate one greedy reply to *user_text* and report token/latency metrics.

    Args:
        user_text: The user's message. Blank or ``None`` returns immediately
            without loading a model.
        model_label: Display key into ``MODEL_OPTIONS``.
        blocks_json: Blocks editor text (JSON array or ``"type: rule"`` lines).
            When it is empty or yields no blocks, the persisted blocks are used.
        max_new_tokens: Generation budget for the reply.

    Returns:
        tuple ``(reply_text, metrics_text)``.
    """
    if not user_text or not user_text.strip():
        # Nothing to answer; skip the slow CPU model load and generation.
        return "", "Type a message first."

    data = load_blocks()
    blocks = parse_blocks_editor(blocks_json, data.get("language_blocks", []))

    tokenizer, model = load_model(MODEL_OPTIONS[model_label])
    messages = build_messages(user_text, data.get("user_profile", {}), blocks)

    # return_dict=True yields input_ids *and* attention_mask. Without it,
    # apply_chat_template returns a bare tensor, which cannot be **-unpacked
    # into generate() or indexed by "input_ids".
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to("cpu")
    prompt_len = int(inputs["input_ids"].shape[-1])

    start = time.perf_counter()  # monotonic clock for elapsed time
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            use_cache=False,  # Avoid DynamicCache mismatch issues on some setups
        )
    latency = time.perf_counter() - start

    # Keep only the continuation after the prompt tokens.
    generated = outputs[0][prompt_len:]
    gen_text = tokenizer.decode(generated, skip_special_tokens=True).strip()
    output_tokens = int(generated.shape[-1])

    metrics = f"Input tokens: {prompt_len} | Output tokens: {output_tokens} | Latency: {latency:.2f}s"
    return gen_text, metrics
def parse_blocks_editor(text, fallback):
    """Parse the blocks editor contents into a list of block dicts.

    Accepts either:
      - JSON: an array of blocks, a single block object with a ``"rule"`` key,
        or a whole blocks-file object with a ``"language_blocks"`` array.
        Array items that are objects are kept as-is. Non-blank strings become
        ``{"type": "rule", "rule": <text>}``, and other items are dropped.
      - Plain text: each non-empty line becomes a block, either
        ``"type: rule"`` (split at the first colon) or just ``"rule"``.

    Args:
        text: Raw editor text.
        fallback: Returned when *text* is empty, or when it yields no usable
            blocks (an explicit empty JSON array ``[]`` returns ``[]``).

    Returns:
        list of block dicts.
    """
    if not text or not text.strip():
        return fallback
    text = text.strip()

    try:
        parsed = json.loads(text)
    except (ValueError, RecursionError):
        parsed = None  # not JSON: use line parsing below

    if isinstance(parsed, dict):
        if isinstance(parsed.get("language_blocks"), list):
            parsed = parsed["language_blocks"]
        elif "rule" in parsed:
            parsed = [parsed]

    if isinstance(parsed, list):
        # Normalize so downstream code (format_blocks) can rely on dicts.
        blocks = []
        for item in parsed:
            if isinstance(item, dict):
                blocks.append(item)
            elif isinstance(item, str) and item.strip():
                blocks.append({"type": "rule", "rule": item.strip()})
        return blocks if blocks or not parsed else fallback

    # Fallback: each non-empty line becomes a block
    blocks = []
    for raw in text.splitlines():
        line = raw.strip()
        if not line:
            continue
        kind, sep, rule = line.partition(":")
        if sep:
            blocks.append({"type": kind.strip(), "rule": rule.strip()})
        else:
            blocks.append({"type": "rule", "rule": line})
    return blocks or fallback
# ----------------------------
# Reflection: extract new rule
# ----------------------------
# Prompt for optional model-based rule extraction; str.format fills {user}
# and {assistant}. It deliberately ends without a trailing newline.
REFLECT_TEMPLATE = (
    "From the user's last message and your reply, extract ONE reusable conversation rule.\n"
    "Return only the rule, no preface, max 20 words.\n"
    "Example rules:\n"
    "- Ask clarifying questions when uncertain.\n"
    "- Use sensory detail with local anchors in creative writing.\n"
    "- Summarize then assess tone (optimistic/pessimistic).\n"
    "User said:\n"
    "{user}\n"
    "Assistant replied:\n"
    "{assistant}\n"
    "Now output one new rule:"
)
def reflect_and_save(user_text, assistant_text, blocks_editor_value):
    """Derive one conversation rule from the last exchange and persist it.

    The rule comes from ``heuristic_rule`` (no extra model call). It is saved
    via ``add_block`` unless a persisted block already has the same rule text.

    Args:
        user_text: The user's last message (``None`` is treated as empty).
        assistant_text: The assistant's reply. If blank or ``None``, nothing is
            proposed or saved.
        blocks_editor_value: Current editor text. It is returned unchanged when
            there is nothing to reflect on. Otherwise the editor is refreshed
            from the persisted blocks, so unsaved editor edits are not carried
            over.

    Returns:
        tuple ``(editor_text, status_message)``.
    """
    if not assistant_text or not assistant_text.strip():
        # No reply yet: don't invent and store a rule from an empty exchange.
        return blocks_editor_value, "Nothing to reflect on yet: generate a reply first."

    data = load_blocks()
    # Propose a rule via a simple heuristic (no extra model call, keeps it lean).
    # If you prefer model-based reflection, you can run a generation with REFLECT_TEMPLATE.
    proposal = heuristic_rule(user_text or "", assistant_text)

    existing = data.get("language_blocks", [])
    if any(isinstance(b, dict) and str(b.get("rule", "")).strip() == proposal for b in existing):
        status = f"Rule already saved: {proposal}"
    else:
        data = add_block(data, proposal, block_type="conversation")
        status = f"Saved rule: {proposal}"

    # Return updated blocks as pretty JSON to show in the editor
    pretty = json.dumps(data.get("language_blocks", []), ensure_ascii=False, indent=2)
    return pretty, status
def heuristic_rule(user_text, assistant_text):
    """Propose one reusable conversation rule from the last exchange.

    Checks are applied in order; the first match wins:
      1. the assistant's reply contains "?" -> reinforce clarifying questions;
      2. the user's message mentions translate/translation/translating;
      3. the user's message mentions "style" or "noir";
      4. otherwise -> reinforce short, specific answers.

    Matching on the user's message is case-insensitive. ``None`` for either
    argument is treated as empty text.

    Returns:
        str: one of four fixed rule sentences.
    """
    if "?" in (assistant_text or ""):
        return "Ask clarifying questions when uncertain."

    low = (user_text or "").lower()
    # "translat" also catches "translation"/"translating", not only "translate".
    if "translat" in low:
        return "Confirm translation intent and target tone before translating."
    if "style" in low or "noir" in low:
        return "Confirm style constraints before writing and keep it concise."
    return "Keep answers short, specific, and avoid repeating conclusions."
# ----------------------------
# Gradio UI
# ----------------------------
def launch():
    """Build the Gradio interface and start serving it.

    The blocks editor is pre-filled with the persisted language blocks as
    indented JSON. "Generate (CPU)" is wired to ``chat`` and
    "Reflect & Save Rule" to ``reflect_and_save``. Both write their second
    return value into the metrics Markdown area.
    """
    seed_text = json.dumps(
        load_blocks()["language_blocks"], ensure_ascii=False, indent=2
    )

    with gr.Blocks(title="Conversation Learning Lab (CPU)") as app:
        gr.Markdown("# 🗣️ Conversation Learning Lab (CPU-friendly)")
        gr.Markdown(
            "Focus on daily dialogue. Reinforce validated language blocks. "
            "Transparent tokens and latency."
        )

        with gr.Row():
            model_picker = gr.Dropdown(
                label="Choose a model",
                choices=list(MODEL_OPTIONS),
                value="Phi-3.5 Mini Instruct (4B)",
            )
        with gr.Row():
            message_box = gr.Textbox(
                label="Your message",
                placeholder="Start a conversation or choose an example below...",
                lines=3,
            )
        with gr.Row():
            rules_editor = gr.Textbox(
                label="Today's blocks (JSON array or 'type: rule' lines)",
                value=seed_text,
                lines=10,
            )
        with gr.Row():
            generate_button = gr.Button("Generate (CPU)")
            reflect_button = gr.Button("Reflect & Save Rule")
        with gr.Row():
            reply_box = gr.Textbox(label="Assistant", lines=8)
        with gr.Row():
            status_md = gr.Markdown("")

        gr.Markdown("### 🧪 Try an example prompt:")
        gr.Examples(examples=EXAMPLES, inputs=message_box)

        # (button, handler, inputs, outputs); registered in this order.
        wiring = (
            (generate_button, chat,
             [message_box, model_picker, rules_editor], [reply_box, status_md]),
            (reflect_button, reflect_and_save,
             [message_box, reply_box, rules_editor], [rules_editor, status_md]),
        )
        for button, handler, ins, outs in wiring:
            button.click(fn=handler, inputs=ins, outputs=outs)

    app.launch()


# Start the app only when run as a script, not when imported.
if __name__ == "__main__":
    launch()