TrabbyPatty commited on
Commit
d4a67f1
Β·
verified Β·
1 Parent(s): 72fdcb1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -18
app.py CHANGED
@@ -39,33 +39,59 @@ You are a strict flashcard generator.
39
  - Always follow the requested format exactly.
40
  <</SYS>>"""
41
 
42
- # === Generation function ===
43
- def generate(user_input, max_new_tokens=800, temperature=0.5):
 
 
 
 
 
 
 
 
 
44
  prompt = (
45
  f"<s>[INST] {SYSTEM_MESSAGE}\n\n"
46
- f"Create flashcards, T/F, MCQ, and study guides strictly using only the information provided.\n\n"
47
  f"Input: {user_input}[/INST]\nOutput:"
48
  )
49
- output = pipe(
50
- prompt,
51
- max_new_tokens=max_new_tokens,
52
- temperature=temperature,
53
- repetition_penalty=1.05,
54
- do_sample=True
55
- )
56
- return output[0]["generated_text"]
57
 
58
- # === Gradio UI ===
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  demo = gr.Interface(
60
- fn=generate,
61
  inputs=[
62
- gr.Textbox(label="Topic / Input", lines=6, placeholder="Paste study material here..."),
63
- gr.Slider(50, 800, value=200, step=10, label="Max New Tokens"),
64
- gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature"),
65
  ],
66
  outputs="text",
67
  title="Flashcard Generator (Mistral-7B LoRA)",
68
- description="Generates flashcard-style study aids using only the provided text."
69
  )
70
 
71
- demo.launch()
 
 
39
  - Always follow the requested format exactly.
40
  <</SYS>>"""
41
 
42
+ # βœ… Load model + tokenizer
43
+ model = AutoModelForCausalLM.from_pretrained(
44
+ model_id,
45
+ torch_dtype=torch.float16,
46
+ device_map="auto",
47
+ load_in_4bit=True # helps fit on ZeroGPU
48
+ )
49
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
50
+
51
+ def generate_flashcards(user_input, max_new_tokens=600, temperature=0.5):
52
+ # Format the prompt with system + user input
53
  prompt = (
54
  f"<s>[INST] {SYSTEM_MESSAGE}\n\n"
55
+ f"Create a variety of study aids with 10 items each, strictly using only the information provided.\n\n"
56
  f"Input: {user_input}[/INST]\nOutput:"
57
  )
 
 
 
 
 
 
 
 
58
 
59
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
60
+
61
+ with torch.no_grad():
62
+ outputs = model.generate(
63
+ **inputs,
64
+ max_new_tokens=max_new_tokens,
65
+ temperature=temperature,
66
+ do_sample=False,
67
+ repetition_penalty=1.05,
68
+ pad_token_id=tokenizer.eos_token_id,
69
+ eos_token_id=tokenizer.eos_token_id,
70
+ )
71
+
72
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
73
+
74
+ # Extract only the Output section
75
+ if "Output:" in response:
76
+ final_answer = response.split("Output:")[-1].strip()
77
+ else:
78
+ final_answer = response.strip()
79
+
80
+ return final_answer
81
+
82
+
83
+ # βœ… Gradio UI
84
  demo = gr.Interface(
85
+ fn=generate_flashcards,
86
  inputs=[
87
+ gr.Textbox(label="Enter study text", lines=8, placeholder="Paste your study material here..."),
88
+ gr.Slider(100, 1000, value=600, step=50, label="Max New Tokens"),
89
+ gr.Slider(0.1, 1.0, value=0.5, step=0.1, label="Temperature"),
90
  ],
91
  outputs="text",
92
  title="Flashcard Generator (Mistral-7B LoRA)",
93
+ description="Paste study material and generate flashcards. Model strictly extracts only from input."
94
  )
95
 
96
+ if __name__ == "__main__":
97
+ demo.launch()