"""Gradio Space that screens study titles/abstracts against PICOS criteria.

Loads the text-generation model named by the ``MODEL_ID`` environment variable
and streams the model's screening verdict back to the UI token by token.
"""

import os
import threading

import gradio as gr
import spaces
import torch
from transformers import TextIteratorStreamer, pipeline

# Load model and tokenizer once at startup.
model_name = os.getenv("MODEL_ID")
pipe = pipeline("text-generation", model=model_name, device=0)
tokenizer = pipe.tokenizer
model = pipe.model

# Fixed generation config.
MAX_TOKENS = 3000
TEMPERATURE = 0.1
TOP_P = 0.9


@spaces.GPU
def respond_stream(population, intervention, comparison, outcome, study_design,
                   summary, title, abstract):
    """Stream the model's PICOS screening response for one title/abstract.

    All parameters are free-text strings from the UI. Title and abstract are
    required; at least one of the five PICOS criteria must be non-empty.
    Yields the cumulative generated text so Gradio can render it live.

    BUG FIX: this function is a generator (it contains ``yield``), so a bare
    ``return "❌ Error..."`` never reached the UI — the returned value becomes
    the discarded StopIteration payload. Errors are now yielded instead.
    """
    # Validate required fields.
    if not title.strip() or not abstract.strip():
        yield "❌ Error: Title and Abstract are required."
        return

    # Build the "label = value" criteria lines from whichever fields are set.
    criteria_parts = []
    labeled_criteria = [
        ("Population of interest", population),
        ("Intervention/exposure of interest", intervention),
        ("Comparison of interest", comparison),
        ("Outcome of interest", outcome),
        ("Study design of interest", study_design),
    ]
    for label, value in labeled_criteria:
        if value.strip():
            criteria_parts.append(f"{label} = {value.strip()}")

    if not criteria_parts:
        yield "❌ Error: At least one of the five PICOS criteria must be filled."
        return

    # Build instruction section.
    instruction = "Instruction: " + "\n".join(criteria_parts)

    # Construct full prompt.
    prompt = instruction
    if summary.strip():
        prompt += f"\n\nPICOS Summary: {summary.strip()}"
    prompt += f"\n\nTitle: {title.strip()}\nAbstract: {abstract.strip()}"

    # Wrap into a single-turn message for the model's chat template.
    messages = [{"role": "user", "content": prompt}]
    prompt_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Tokenize and prepare the streamer; generation runs on a worker thread
    # while this generator drains the streamer and yields partial text.
    inputs = tokenizer(prompt_text, return_tensors="pt").to("cuda")
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        # Fix: pass the attention mask explicitly — required for correct
        # masking when pad_token_id is aliased to eos_token_id.
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=MAX_TOKENS,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    partial_text = ""
    for token in streamer:
        partial_text += token
        yield partial_text
    # Fix: don't leak the worker thread once the streamer is exhausted.
    thread.join()


# Build Gradio interface.
with gr.Blocks() as demo:
    gr.Markdown("## PICO screener")
    with gr.Column():
        population = gr.Textbox(label="Population of interest", lines=1)
        intervention = gr.Textbox(label="Intervention/exposure of interest", lines=1)
        comparison = gr.Textbox(label="Comparison of interest", lines=1)
        outcome = gr.Textbox(label="Outcome of interest", lines=1)
        study_design = gr.Textbox(label="Study design of interest", lines=1)
    with gr.Column():
        summary = gr.Textbox(label="PICOS Summary (optional)", lines=4)
        title = gr.Textbox(label="Title", lines=2, placeholder="Required")
        abstract = gr.Textbox(label="Abstract", lines=10, placeholder="Required")
    output_box = gr.Textbox(label="Model Response", lines=15, interactive=False)
    generate_btn = gr.Button("Generate")
    generate_btn.click(
        fn=respond_stream,
        inputs=[population, intervention, comparison, outcome, study_design,
                summary, title, abstract],
        outputs=[output_box],
    )

# Launch the app.
if __name__ == "__main__":
    demo.launch()