File size: 1,447 Bytes
a9408de
f65b360
 
 
741290e
a9408de
f65b360
 
a9408de
 
f65b360
a9408de
f65b360
 
 
 
 
 
 
 
 
 
 
 
a9408de
f65b360
 
 
 
 
 
 
 
 
 
 
741290e
f65b360
a9408de
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr

# Hub repository id of the checkpoint served by this app.
model_id = "distilbert/distilgpt2"

# Tokenizer and weights are loaded once at import time; generate_text reads
# these module-level objects on every request.
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Full float32 precision, then eval mode (disables dropout). nn.Module.eval()
# returns the module itself, so the call can be chained onto the loader.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
).eval()

def generate_text(prompt, temperature, top_k, top_p, max_tokens, repetition_penalty):
    """Sample a continuation of ``prompt`` from the module-level causal LM.

    Args:
        prompt: Text to continue. An empty (or ``None``) prompt is replaced by
            the tokenizer's BOS token (EOS if no BOS is defined) so generation
            has at least one token to condition on.
        temperature: Sampling temperature.
        top_k: Number of most-likely tokens kept at each sampling step.
        top_p: Nucleus-sampling cumulative-probability cutoff.
        max_tokens: Maximum number of new tokens to generate.
        repetition_penalty: Repetition penalty passed to ``generate``
            (1.0 means no penalty).

    Returns:
        The decoded prompt plus generated continuation, with special tokens
        removed.
    """
    # An empty string tokenizes to zero ids, and the model cannot condition on
    # an empty sequence. Seed with a special token instead; it is dropped again
    # by skip_special_tokens=True when decoding.
    text = prompt if prompt else (tokenizer.bos_token or tokenizer.eos_token)
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        output = model.generate(
            **inputs,
            # UI sliders may deliver floats; the length limit and the top-k
            # warper require ints (top-k raises ValueError otherwise).
            max_new_tokens=int(max_tokens),
            do_sample=True,
            temperature=float(temperature),
            top_k=int(top_k),
            top_p=float(top_p),
            repetition_penalty=float(repetition_penalty),
            # Reuse EOS as the pad id; presumably because this tokenizer
            # defines no pad token of its own — confirm if the model changes.
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(output[0], skip_special_tokens=True)

# UI controls, positionally matched to generate_text's parameters:
# prompt, temperature, top_k, top_p, max_tokens, repetition_penalty.
_controls = [
    gr.Textbox(label="Prompt", placeholder="Type something here...", lines=4),
    gr.Slider(0.1, 1.5, value=1.0, step=0.1, label="Temperature"),
    gr.Slider(1, 100, value=50, step=1, label="Top-K"),
    gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-P"),
    gr.Slider(10, 512, value=128, step=1, label="Max New Tokens"),
    gr.Slider(0.5, 2.0, value=1.0, step=0.1, label="Repetition Penalty"),
]

# Build the single-function interface, then start the server.
demo = gr.Interface(
    fn=generate_text,
    inputs=_controls,
    outputs=gr.Textbox(label="Generated Text"),
    title="🧠 AlphaMindQ Fork distilbert/distilgpt2",
    theme="default",
)
demo.launch()