Spaces:
Sleeping
Sleeping
File size: 1,788 Bytes
9a395d1 bba9bff 9a395d1 4ccdbe4 abd6fbd e79949d 51a091e 4ccdbe4 51a091e 4ccdbe4 51a091e bba9bff 51a091e e79949d 51a091e c99d5db e79949d 4ccdbe4 e79949d 4ccdbe4 e79949d 4ccdbe4 e79949d 4ccdbe4 e79949d 51a091e e79949d 4ccdbe4 e79949d 9a395d1 e79949d 9a395d1 9664295 4ccdbe4 b2f98f0 e79949d 51a091e 9a395d1 4ccdbe4 9664295 e79949d 5ee5ebc e79949d 368865e e79949d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
from transformers import pipeline
import gradio as gr
import torch
import os
# ===== SMART DEVICE CONFIGURATION =====
def get_best_device():
    """Select the fastest available compute device.

    Returns:
        tuple: ``(0, torch.float16)`` when a CUDA GPU is available,
        otherwise ``(-1, torch.float32)``. The integer follows the
        transformers ``pipeline(device=...)`` convention (-1 == CPU).
    """
    if not torch.cuda.is_available():
        return -1, torch.float32  # CPU fallback: full precision
    # Let cuDNN autotune kernel selection for the (fixed-shape) workload.
    torch.backends.cudnn.benchmark = True
    return 0, torch.float16  # first GPU, half precision
device, dtype = get_best_device()
# Build a human-readable label for the startup banner.
if device == 0:
    device_name = "GPU: " + torch.cuda.get_device_name(0)
else:
    device_name = "CPU"
print(f"⚡ Running on: {device_name}")
# ===== ERROR-PROOF MODEL LOADING =====
try:
    # Preferred configuration: trust_remote_code belongs inside model_kwargs
    # (passing it top-level to pipeline() is rejected by transformers).
    model = pipeline(
        "text-generation",
        model="google/gemma-2b-it",
        device=device,
        torch_dtype=dtype,
        model_kwargs={
            "low_cpu_mem_usage": True,
            "trust_remote_code": True,
        },
    )
    # Pre-warm so the first user request doesn't pay the lazy-init cost.
    model("Warmup", max_new_tokens=1)
except Exception as e:
    # Fallback: retry with the minimal configuration. Report WHY the first
    # attempt failed instead of silently discarding the exception, so load
    # problems are diagnosable from the Space logs.
    print(f"⚠️ Primary model load failed ({e}); retrying with minimal config")
    model = pipeline(
        "text-generation",
        model="google/gemma-2b-it",
        device=device,
        torch_dtype=dtype,
    )
# ===== OPTIMIZED GENERATION =====
def generate(prompt):
    """Generate a short greedy completion for *prompt*.

    Returns:
        str: the generated text (prompt included, per the pipeline's
        default), or a user-facing "⚠️ Error" string if generation fails.
    """
    try:
        # do_sample=False means greedy decoding; the original also passed
        # temperature=0.2, which is ignored in greedy mode and makes newer
        # transformers emit a warning — removed as dead configuration.
        return model(
            prompt,
            max_new_tokens=60,
            do_sample=False,
            pad_token_id=model.tokenizer.eos_token_id,
        )[0]['generated_text']
    except Exception as e:
        # Boundary handler: surface the message in the UI instead of crashing.
        return f"⚠️ Error: {str(e)}"
# ===== SIMPLE INTERFACE =====
# (Also drops the stray trailing "|" scrape artifact after demo.launch,
# which was a syntax error as pasted.)
with gr.Blocks() as demo:
    gr.Markdown("## Ask anything (1-2 second responses)")
    # Renamed from `input`/`output`: `input` shadows the Python builtin.
    question = gr.Textbox(label="Your question")
    answer = gr.Textbox(label="Answer")
    # Pressing Enter in the question box runs generate(question) -> answer.
    question.submit(generate, question, answer)
# Bind to all interfaces so the Space's container port mapping works.
demo.launch(server_name="0.0.0.0")