File size: 2,476 Bytes
9cd6cbe
746e9b7
9cd6cbe
773b204
9cd6cbe
c5552a7
9e5bead
c5552a7
9cd6cbe
c5552a7
2355649
 
 
746e9b7
2355649
 
9cd6cbe
 
746e9b7
 
c5552a7
9cd6cbe
89812f5
9cd6cbe
 
 
 
 
 
 
 
d4a67f1
 
9cd6cbe
 
9e5bead
9cd6cbe
 
 
d4a67f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cd6cbe
d4a67f1
9cd6cbe
d4a67f1
 
 
9cd6cbe
 
 
d4a67f1
9cd6cbe
 
d4a67f1
746e9b7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os

# === Model ID and Token ===
# Hub repository of the fine-tuned checkpoint served by this app.
model_id = "TrabbyPatty/mistral-7b-instruct-finetuned-flashcards-4bit"
# Access token read from the "alluse" environment variable (a Space secret);
# os.getenv yields None when the variable is not set.
hf_token = os.getenv("alluse")

# === Load tokenizer & model with authentication ===
# use_fast=False selects the slow tokenizer implementation; per the original
# note, this works around a JSON error raised when loading the fast one.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token, use_fast=False)

# device_map="auto" lets transformers place the weights on the available
# GPU(s)/CPU. torch_dtype=torch.float16 is requested for loading.
# NOTE(review): the repo name suggests a 4-bit checkpoint — confirm how its
# saved quantization config interacts with this dtype.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_token,
    torch_dtype=torch.float16,
    device_map="auto",
)

# === SYSTEM MESSAGE ===
# Instruction block wrapped in <<SYS>> ... <</SYS>> markers and inserted at
# the start of every [INST] prompt. Lines are joined with "\n" and there is
# no trailing newline.
SYSTEM_MESSAGE = "\n".join(
    [
        "<<SYS>>",
        "You are a strict flashcard generator.",
        "- Only extract information from the input.",
        "- Do NOT add outside knowledge, assumptions, or details not mentioned in the input.",
        "- Always follow the requested format exactly.",
        "<</SYS>>",
    ]
)

def generate_flashcards(user_input, max_new_tokens=600, temperature=0.5, do_sample=False):
    """Generate flashcards from *user_input* with the fine-tuned model.

    Args:
        user_input: Study text supplied by the user.
        max_new_tokens: Upper bound on generated tokens. Cast to ``int``
            because UI sliders may deliver a float.
        temperature: Sampling temperature. Only used when ``do_sample`` is
            True; greedy decoding has no use for it.
        do_sample: Opt-in sampling. Defaults to False, so the default call
            keeps using deterministic greedy decoding.

    Returns:
        The model's generated text (prompt excluded), with surrounding
        whitespace and a leading "Output:" marker removed. Returns "" for
        empty or whitespace-only input without calling the model.
    """
    # Nothing to extract from: skip an expensive (and hallucination-prone)
    # generation call.
    if not user_input or not user_input.strip():
        return ""

    # Format the prompt with system + user input. The literal "<s>" already
    # provides the BOS token, so the tokenizer must not prepend a second one.
    prompt = (
        f"<s>[INST] {SYSTEM_MESSAGE}\n\n"
        "Create flashcards strictly using only the information provided.\n\n"
        f"Input: {user_input}[/INST]\nOutput:"
    )

    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    gen_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        repetition_penalty=1.05,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    if do_sample:
        # temperature only affects sampling; passing it alongside greedy
        # decoding is ignored and makes transformers emit a warning.
        gen_kwargs["temperature"] = temperature

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    # Decode only the newly generated tokens. Decoding the whole sequence and
    # splitting on the last "Output:" would truncate the answer whenever the
    # model itself writes "Output:".
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

    # The model may repeat the marker the prompt ends with; drop it.
    if response.startswith("Output:"):
        response = response[len("Output:"):].strip()

    return response


# === Gradio UI ===
# The three input components map positionally onto
# generate_flashcards(user_input, max_new_tokens, temperature), and the
# returned string is displayed as plain text.
# NOTE(review): the Temperature value is forwarded as `temperature`; whether it
# influences decoding depends on the sampling settings inside
# generate_flashcards — check before relying on the slider.
demo = gr.Interface(
    generate_flashcards,
    [
        gr.Textbox(
            label="Enter study text",
            lines=8,
            placeholder="Paste your study material here...",
        ),
        gr.Slider(100, 1000, value=600, step=50, label="Max New Tokens"),
        gr.Slider(0.1, 1.0, value=0.5, step=0.1, label="Temperature"),
    ],
    "text",
    title="Flashcard Generator (Mistral-7B LoRA)",
    description="Paste study material and generate flashcards. Model strictly extracts only from input.",
)

if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    demo.launch()