import spaces
import gradio as gr
from transformers import pipeline, TextIteratorStreamer
import torch
import threading
import os

# Load model and tokenizer.
# NOTE(review): if MODEL_ID is unset, `pipeline` receives model=None and falls
# back to its task default, which may not ship a chat template — confirm the
# deployment always sets MODEL_ID.
model_name = os.getenv("MODEL_ID")
pipe = pipeline("text-generation", model=model_name, device=0)
tokenizer = pipe.tokenizer
model = pipe.model

# Fixed generation config
MAX_TOKENS = 3000
TEMPERATURE = 0.1
TOP_P = 0.9


@spaces.GPU
def respond_stream(summary, title, abstract):
    """Stream the model's evaluation of an abstract against a PICOS summary.

    Args:
        summary: PICOS summary text (required).
        title: Study title (required).
        abstract: Study abstract (required).

    Yields:
        The cumulative generated text after each streamed chunk. If any input
        is empty, yields a single error message instead. If generation fails,
        yields the text produced so far followed by an error message.
    """
    # Treat None like an empty field so validation reports it instead of
    # raising AttributeError on `.strip()`.
    summary = (summary or "").strip()
    title = (title or "").strip()
    abstract = (abstract or "").strip()

    # Validate mandatory fields. This function is a generator, so a plain
    # `return "<msg>"` would end the stream without displaying anything;
    # the message must be yielded.
    if not summary or not title or not abstract:
        yield "❌ Error: PICOS Summary, Title, and Abstract are all required."
        return

    # Build prompt
    prompt = (
        f"Instruction: Use the following PICOS summary to evaluate the abstract.\n"
        f"\nPICOS Summary: {summary}"
        f"\n\nTitle: {title}\nAbstract: {abstract}"
    )

    # Wrap into message for chat template
    messages = [{"role": "user", "content": prompt}]
    prompt_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # The rendered chat template already contains the model's special tokens
    # (e.g. BOS), so don't let the tokenizer add them a second time.
    inputs = tokenizer(
        prompt_text, return_tensors="pt", add_special_tokens=False
    ).to("cuda")
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        # Pass the mask explicitly; generate() warns and guesses it otherwise.
        attention_mask=inputs.get("attention_mask"),
        streamer=streamer,
        max_new_tokens=MAX_TOKENS,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    errors = []

    def _generate():
        """Run generation in the background; end the stream if it fails."""
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # surfaced to the user below
            errors.append(exc)
            # Without this, the consumer loop below would block forever
            # waiting for a stop signal that generate() never sent.
            streamer.end()

    thread = threading.Thread(target=_generate)
    thread.start()

    partial_text = ""
    for token in streamer:
        partial_text += token
        yield partial_text
    thread.join()

    if errors:
        yield f"{partial_text}\n\n❌ Error during generation: {errors[0]}"


# Build Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Study design screener")

    with gr.Column():
        summary = gr.Textbox(label="PICOS Summary", lines=4, placeholder="Required")
        title = gr.Textbox(label="Title", lines=2, placeholder="Required")
        abstract = gr.Textbox(label="Abstract", lines=10, placeholder="Required")
        output_box = gr.Textbox(label="Model Response", lines=15, interactive=False)
        generate_btn = gr.Button("Generate")

    generate_btn.click(
        fn=respond_stream,
        inputs=[summary, title, abstract],
        outputs=[output_box],
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()