File size: 1,788 Bytes
9a395d1
bba9bff
9a395d1
4ccdbe4
abd6fbd
e79949d
51a091e
4ccdbe4
51a091e
4ccdbe4
51a091e
bba9bff
51a091e
e79949d
51a091e
c99d5db
e79949d
4ccdbe4
e79949d
4ccdbe4
e79949d
4ccdbe4
 
 
 
 
e79949d
4ccdbe4
 
e79949d
51a091e
 
e79949d
4ccdbe4
e79949d
 
 
 
 
 
 
9a395d1
e79949d
9a395d1
9664295
4ccdbe4
b2f98f0
e79949d
51a091e
 
9a395d1
4ccdbe4
9664295
e79949d
5ee5ebc
e79949d
 
 
 
 
 
368865e
e79949d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from transformers import pipeline
import gradio as gr
import torch
import os

# ===== SMART DEVICE CONFIGURATION =====
def get_best_device():
    """Pick the fastest available device and a matching dtype.

    Returns:
        tuple: (pipeline device index, torch dtype) —
        (0, torch.float16) when CUDA is available, else (-1, torch.float32).
    """
    if not torch.cuda.is_available():
        return -1, torch.float32  # CPU fallback, full precision
    # Let cuDNN auto-tune kernels; pays off for repeated same-shape inputs.
    torch.backends.cudnn.benchmark = True
    return 0, torch.float16  # first GPU, half precision

device, dtype = get_best_device()
if device == 0:
    device_name = "GPU: " + torch.cuda.get_device_name(0)
else:
    device_name = "CPU"
print(f"⚡ Running on: {device_name}")

# ===== ERROR-PROOF MODEL LOADING =====
try:
    # Primary configuration: with the pipeline() factory, trust_remote_code
    # belongs in model_kwargs so it reaches the model loader.
    model = pipeline(
        "text-generation",
        model="google/gemma-2b-it",
        device=device,
        torch_dtype=dtype,
        model_kwargs={
            "low_cpu_mem_usage": True,
            "trust_remote_code": True,
        },
    )

    # Pre-warm so the first real request doesn't pay the one-time
    # kernel/cache initialization cost.
    model("Warmup", max_new_tokens=1)

except Exception as e:
    # Fall back to a minimal configuration, but surface the original
    # failure instead of swallowing it silently.
    print(f"⚠️ Primary model load failed ({e}); retrying with defaults")
    model = pipeline(
        "text-generation",
        model="google/gemma-2b-it",
        device=device,
        torch_dtype=dtype
    )

# ===== OPTIMIZED GENERATION =====
def generate(prompt):
    """Run greedy text generation on *prompt* with the global pipeline.

    Args:
        prompt: User-supplied prompt string from the UI.

    Returns:
        str: The generated text (the pipeline includes the prompt in
        'generated_text'), or a human-readable error message on failure.
    """
    try:
        # do_sample=False means greedy decoding; a temperature setting
        # would be ignored (and warned about), so none is passed.
        return model(
            prompt,
            max_new_tokens=60,
            do_sample=False,
            pad_token_id=model.tokenizer.eos_token_id
        )[0]['generated_text']
    except Exception as e:
        # Best-effort UI: report the failure in the answer box rather
        # than crashing the Gradio event handler.
        return f"⚠️ Error: {str(e)}"

# ===== SIMPLE INTERFACE =====
with gr.Blocks() as demo:
    gr.Markdown("## Ask anything (1-2 second responses)")
    # Renamed from `input`/`output` to avoid shadowing the builtins.
    question_box = gr.Textbox(label="Your question")
    answer_box = gr.Textbox(label="Answer")
    # Pressing Enter in the question box runs generate() and fills the answer.
    question_box.submit(generate, question_box, answer_box)

# Bind to all interfaces so the app is reachable inside containers/Spaces.
demo.launch(server_name="0.0.0.0")