Spaces:
Runtime error
Runtime error
File size: 2,476 Bytes
9cd6cbe 746e9b7 9cd6cbe 773b204 9cd6cbe c5552a7 9e5bead c5552a7 9cd6cbe c5552a7 2355649 746e9b7 2355649 9cd6cbe 746e9b7 c5552a7 9cd6cbe 89812f5 9cd6cbe d4a67f1 9cd6cbe 9e5bead 9cd6cbe d4a67f1 9cd6cbe d4a67f1 9cd6cbe d4a67f1 9cd6cbe d4a67f1 9cd6cbe d4a67f1 746e9b7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
# === Model ID and Token ===
model_id = "TrabbyPatty/mistral-7b-instruct-finetuned-flashcards-4bit"
hf_token = os.getenv("alluse")  # Hugging Face token from Space secrets

# === Load tokenizer & model with authentication ===
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    token=hf_token,
    use_fast=False,  # force slow tokenizer (fixes JSON error)
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",  # let HF map to GPU/CPU
    torch_dtype=torch.float16,
    token=hf_token,
)
# === SYSTEM MESSAGE ===
# Llama-2-style system block (<<SYS>> ... <</SYS>>); prepended to every
# prompt by generate_flashcards to keep the model grounded in the input.
SYSTEM_MESSAGE = """<<SYS>>
You are a strict flashcard generator.
- Only extract information from the input.
- Do NOT add outside knowledge, assumptions, or details not mentioned in the input.
- Always follow the requested format exactly.
<</SYS>>"""
def generate_flashcards(user_input, max_new_tokens=600, temperature=0.5):
    """Generate flashcards from study text with the loaded Mistral model.

    Args:
        user_input: Study material to turn into flashcards.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. A value <= 0 (or None) falls
            back to deterministic greedy decoding.

    Returns:
        str: The model's answer with the echoed prompt stripped off.
    """
    # Mistral-instruct prompt: [INST] system message + task + input [/INST].
    prompt = (
        f"<s>[INST] {SYSTEM_MESSAGE}\n\n"
        f"Create flashcards strictly using only the information provided.\n\n"
        f"Input: {user_input}[/INST]\nOutput:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # BUG FIX: the original passed `temperature` together with
    # do_sample=False, so the temperature slider had no effect (greedy
    # decoding ignores it and transformers warns). Sample only when a
    # positive temperature is requested.
    do_sample = temperature is not None and temperature > 0
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "repetition_penalty": 1.05,
        "pad_token_id": tokenizer.eos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        gen_kwargs["temperature"] = temperature

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # The decoded text includes the prompt; keep only what follows "Output:".
    if "Output:" in response:
        final_answer = response.split("Output:")[-1].strip()
    else:
        final_answer = response.strip()
    return final_answer
# === Gradio UI ===
demo = gr.Interface(
    fn=generate_flashcards,
    inputs=[
        gr.Textbox(label="Enter study text", lines=8, placeholder="Paste your study material here..."),
        gr.Slider(100, 1000, value=600, step=50, label="Max New Tokens"),
        gr.Slider(0.1, 1.0, value=0.5, step=0.1, label="Temperature"),
    ],
    outputs="text",
    title="Flashcard Generator (Mistral-7B LoRA)",
    description="Paste study material and generate flashcards. Model strictly extracts only from input."
)

# Launch the app only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()
|